🚧 This website is under construction. 🚧
AtCoder
蟻本
4-7

4-7 文字列を華麗に扱う

文字列に対する動的計画法

単一文字列の場合

複数文字列の場合

文字列検索

長さ nn の文字列 SS に含まれる、長さ mm の文字列 TT の場所を探したり、含まれる回数を探したりすることを文字列検索という

ナイーブな解法では、SS の開始位置を全て試し、一致しているかを調べる O(nm)O(nm) の解法が存在するが、ローリングハッシュではこれを O(n)O(n) で行う。

ローリングハッシュ

  • ナイーブ解では一回の一致判定に O(m)O(m) であるが、これをハッシュ値を用いることにより O(1)O(1) で行うことで目指す
  • そのままではハッシュ値の計算に O(m)O(m) かかるものの、直前の比較に用いたハッシュ値を利用することで計算を高速化 O(1)O(1) する

互いに素な適当な定数 b,h(1<b<h)b,h(1\lt{b}\lt{h}) を用意し、文字列 C=c1c2...cmC=c_1c_2...c_m のハッシュ値を、

H(C)=(c1×bm1+c2×bm2+...+cm×b0)modhH(C)=(c_1\times b^{m-1}+c{2}\times b^{m-2}+...+c_{m}\times b^{0})\mod{h}

とする。すると、文字列 S=s1s2...snS=s_1s_2...s_nk+1k+1 文字目からの mm 文字の部分文字列 sk+1...sk+ms_{k+1}...s_{k+m} に対するハッシュ値は、kk 文字目からの部分文字列 sksk+m1s_{k}\ldots s_{k+m-1} のハッシュ値により、以下のようにすぐ計算できる。

H(sk+1...sk+m)=(H(sk...sk+m1)×bsk×bm+sk+m)modhH(s_{k+1}...s_{k+m})=(H(s_{k}...s_{k+m-1})\times b-s_{k}\times b^{m}+s_{k+m})\mod{h}

h=264h=2^{64} と定めると mod を取る操作を省く事が出来ることを利用して、

H(sk+1...sk+m)=sk+1×bm1+sk+2×bm2+...+sk+m1×b1+sk+m×b0=sk×bm+sk+1×bm1+sk+2×bm2+...+sk+m1×b1+sk+m×b0sk×bm=(sk×bm1+sk+1×bm2+sk+2×bm3+...+sk+m1×b0)×b+sk+m×b0sk×bm=H(sk...sk+m1)×b+sk+msk×bm\begin{aligned} H(s_{k+1}...s_{k+m})&=s_{k+1}\times b^{m-1}+s_{k+2}\times b^{m-2}+...+s_{k+m-1}\times b^{1}+s_{k+m}\times b^{0} \\ &=s_{k}\times b^{m}+s_{k+1}\times b^{m-1}+s_{k+2}\times b^{m-2}+...+s_{k+m-1}\times b^{1}+s_{k+m}\times b^{0}-s_{k}\times b^{m} \\ &=(s_{k}\times b^{m-1}+s_{k+1}\times b^{m-2}+s_{k+2}\times b^{m-3}+...+s_{k+m-1}\times b^{0}) \times b+s_{k+m}\times b^{0}-s_{k}\times b^{m} \\ &=H(s_{k}...s_{k+m-1})\times b+s_{k+m}-s_{k}\times b^{m} \end{aligned}
# ローリングハッシュ
 
B = 1000000007
H = 998244353
 
 
def rolling_hash(a: str, b: str) -> bool:
    """return b in a"""
    n = len(a)
    m = len(b)
    powers = [1]
    for _ in range(m):
        powers.append(powers[-1] * B % H)
    a_hash = sum(map(lambda i: ord(a[i]) * powers[m - i - 1], range(m))) % H
    b_hash = sum(map(lambda i: ord(b[i]) * powers[m - i - 1], range(m))) % H
    if a_hash == b_hash:
        return True
    for i in range(n - m):
        a_hash = (a_hash * B + ord(a[i + m]) - ord(a[i]) * powers[m]) % H
        if a_hash == b_hash:
            return True
    return False
 

(Apple M2 (macOS 14.4.1 (23E224)) での実行結果)

>>> import random, timeit
>>> n = 1_000_000
>>> m = 500_000
>>> s = "".join(chr(random.randint(ord("a"), ord("z"))) for _ in range(n))
>>> t = "".join(chr(random.randint(ord("a"), ord("z"))) for _ in range(m))
>>> s = s[: n // 3] + t + s[n // 3 + m :]
>>> assert(t in s and rolling_hash(s, t))
>>> timeit.timeit("t in s", globals=globals(), number=100)
0.08759437501430511
>>> timeit.timeit("rolling_hash(s, t)", globals=globals(), number=100)
14.977220457978547  # ??