KMP 算法

kmp-01
失配函数计算,根据之前的分析,对于某个位置 P(i)P(i)T[k]T[\cdots k \cdots] 失配
尝试去匹配前缀子串 P[0i1]P[0 \cdots i-1],预处理失配函数,即尝试让串 PP 自己匹配自己

  1. f(i)=jf(i) = j 表示在 P(i)P(i) 处失配时,应该跳到前缀的哪个位置
    如果下标从 00 开始,jj 表示状态机上已经匹配了 jj 个字符,即前缀 [0j1][0\cdots j-1]
    即对于当前位置 P(i)P(i)尝试P(i)P(i) 匹配 P(j)P(j)
  2. f(i)=jf(i) = j,可以根据 jj 来建立状态机,状态机节点编号 jj,同样表示已匹配前缀 [0j1][0\cdots j-1]
    如果下标从 00 开始,还可以表示当前尝试匹配 P(i),P(j)P(i), P(j)
  3. 算法设计如下,遍历 PP,对于 P(i)P(i),其失配后应该要跳到 j=f(i)j = f(i)
    • 如果 P(j)P(i)P(j) \neq P(i),不断地往失配边走,即不断地令 jf(j)j' \leftarrow f(j)
      直到 P(j)P(j')P(i)P(i) 匹配上
    • 如果找不到这样的 jj',那么接下来应该从 00 开始重新匹配
  4. 可以发现一个递推结构,即进行第 ii 次匹配的时候,我们可以知道第 i+1i+1 次应该从哪里匹配
    • P(i)=P(j)P(i) = P(j'),那么第 i+1i+1 次匹配应该考虑 j+1j'+1,即 f(i+1)=j+1f(i+1) = j'+1
    • 如果找不到这样的 jj',即 j=0j'=0 仍然有 P(i)P(j)P(i) \neq P(j'),令 f(i+1)=jf(i+1) = j'
      kmp-02

有了失配函数,主算法就比较好理解了

  • j=0j = 0,表示最开始在模式串 P(0)P(0)
  • 遍历文本串 TT,对于位置 T(i)T(i)
    • 如果 T(i)P(j)T(i) \neq P(j),那么就不断沿着失配边 jf(j)j \leftarrow f(j) 走,直到匹配为止
    • 如果 T(i)=P(j)T(i) = P(j),那么 jj+1j \leftarrow j+1
  • 如果循环中 jj 走完了模式串,即 j=length(P)j = \text{length}(P),那么成功匹配
    此时 T[im+1,i]T[i-m+1, i] 就是文本串和模式串匹配上的部分,退出循环
    如果 ii 遍历完整个 TT 之后,jlength(P)j \neq \text{length}(P),那么整个串都无法和 PP 匹配
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
const int maxn = 1e5 + 10;
int f[maxn];

void getFail(const char *P) {
int n = strlen(P);
f[0] = 0, f[1] = 0;
for (int i = 1; i < n; i++) {
int j = f[i];
while (j && P[j] != P[i]) j = f[j];
f[i+1] = (P[j] == P[i] ? j+1 : 0);
}
}

vector<int> ans;
void KMP(const char *T, const char *P) {
int n = strlen(T), m = strlen(P);
getFail(P);

int j = 0;
for (int i = 0; i < n; i++) {
while (j && T[i] != P[j]) j = f[j];
if (T[i] == P[j]) j++;
if (j == m) ans.push_back(i-m+1);
}
}

最短循环节
Acwing141

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
const int maxn = 1000000 + 10;
int f[maxn], n;
char str[maxn];

void getFail() {
memset(f, 0, sizeof f);
for (int i = 1; i < n; i++) {
int j = f[i];
while (j && str[j] != str[i]) j = f[j];
f[i+1] = (str[i] == str[j] ? j+1 : 0);
}
}

int main() {
freopen("input.txt", "r", stdin);
int T = 0;
while (scanf("%d", &n) == 1 && n) {
printf("Test case #%d\n", ++T);
scanf("%s", str);

getFail();
for (int i = 2; i <= n; i++) {
if (f[i] > 0 && i % (i-f[i]) == 0) printf("%d %d\n", i, i / (i-f[i]));
}
puts("");
}
}

值得注意的是,对于字符串 str[0n1]\text{str}[0\cdots n-1]
KMP\text{KMP} 失配函数 f(i)f(i)i[1,n]i \in [1, n]ii 表示的是前缀长度

Trie

Trie 是基于前缀的多叉树数据结构,Trie 的构建是采用动态开点的

  • 指针 pp 指向节点编号,初始时候指向根节点 11
  • Trie\text{Trie} 的根节点编号为 11
  • 在构建的时候,同时标记每个编号idx\text{idx} 的节点,是否为串的末尾?
    如果是,用 end[idx]\text{end}[idx] 存储字符串信息

插入
指针 pp 初始化为根节点 11,然后扫描字符串中的每个字符 cc

  • 如果 trie(p,c)=q\text{trie}(p, c) = q,那么 trie\text{trie} 中已经有这个字符信息了,pqp \leftarrow q
  • 如果 trie(p,c)=0\text{trie}(p, c) = 0,那么新建一个节点,trie(p,c)=++tot=q\text{trie}(p, c) = ++tot = q,然后令 pqp \leftarrow q
  • 字符串扫描完成后,标记末尾节点信息 end(p)end(p)注意,一般情况下,只在表示字符串末尾的节点,打标记

查询

  • 同样指针 pp 初始化为根节点 11,然后扫描每个字符 cc,如果 trie(p,c)=0\text{trie}(p, c) = 0
    那么 trie\text{trie} 树中不存在这个字符串,返回
  • 否则的话,p=trie(p,c)p = \text{trie}(p, c)
  • 字符串扫描完毕之后,如果 end(p)\text{end}(p) 存在,那么 trie\text{trie} 中存在该字符串

前缀统计

只有在字符串末尾,cnt(p)0\text{cnt}(p) \neq 0,考虑 pp 从根节点 11 开始
沿着模式串 str\text{str} 的字符往下走,对路径上的 cnt\text{cnt} 求和, cnt(p)\sum \text{cnt}(p) 就是答案

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
const int maxn = 500000 + 10;

class Trie {
public:
int n;
int tot;
vector<vector<int> > t;
vector<int> cnt;

Trie() = default;
Trie(int _n) : n(_n) {
tot = 1;
t.resize(n), cnt.resize(n);

fill(t.begin(), t.end(), vector<int> (26, 0));
fill(cnt.begin(), cnt.end(), 0);
}

void insert(const string &str) {
int p = 1;
for (auto x : str) {
int c = x - 'a';
if (t[p][c] == 0) t[p][c] = ++tot;
p = t[p][c];
}
cnt[p]++;
}

int query(const string &str) {
int p = 1, res = 0;
for (auto x : str) {
int c = x - 'a';
if (t[p][c] == 0) break;
p = t[p][c];
res += cnt[p];
}
return res;
}
};

Trie trie(maxn);

int n, m;

int main() {
freopen("input.txt", "r", stdin);
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);

cin >> n >> m;
while (n--) {
string str;
cin >> str;
trie.insert(str);
}
while (m--) {
string str;
cin >> str;
int res = trie.query(str);
cout << res << endl;
}
}

Trie 处理位运算

最大异或对
对每个  Ai\forall \ A_i,考虑从高位到低位, i[highbitlowbit]\forall \ i \in [\text{highbit} \to \text{lowbit}]
要找到这样的一个 AjA_j,满足 AjA_j 尽可能多的高位与 AiA_i 不同,这样异或的高位就有尽可能多的 11

  • 将所有数按位插入 trie\text{trie} 中,p=1p = 1 初始化为根节点
  • 遍历每个 AiA_ii[highbitlowbit]i \in [\text{highbit} \to \text{lowbit}] 检查每一位
    • 如果 trie(p,!i)0\text{trie}(p, !i) \neq 0,那么顺着 trie(p,!i)\text{trie}(p, !i)
      并且 res+=(1<<i)res += (1 << i)
    • 否则沿着 trie(p,i)\text{trie}(p, i)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
const int maxn = 3100000 + 5;
const int N = 1e5 + 5;
int n, a[N];

class Trie {
public:
int tot;
int t[maxn][2];

Trie() {
tot = 1;
memset(t, 0, sizeof t);
}

void insert(int x) {
int p = 1;
for (int i = 30; i >= 0; i--) {
int c = (x >> i & 1);
if (t[p][c] == 0) t[p][c] = ++tot;
p = t[p][c];
}
}
int query(int x) {
int res = 0, p = 1;
for (int i = 30; i >= 0; i--) {
int c = x >> i & 1;
if (t[p][!c]) {
res += (1<<i);
p = t[p][!c];
}
else p = t[p][c];
}
return res;
}
};

Trie trie;

int main() {
freopen("input.txt", "r", stdin);
scanf("%d", &n);
for (int i = 0; i < n; i++) {
scanf("%d", &a[i]);
trie.insert(a[i]);
}

int res = 0;
for (int i = 0; i < n; i++) res = max(res, trie.query(a[i]));
printf("%d\n", res);
}

该问题也有扩展版本
最长异或值路径
注意到对于树边 (x,y)(x, y)d(y)d(y) 表示根节点到 yy 的距离,那么有
d(y)=d(x)e(x,y)d(y) = d(x) \oplus e(x, y),由于异或操作对于路径重复的部分,值为 00
所以 (xy)\oplus (x \to y) 实际上就是 d(x)d(y)d(x) \oplus d(y)
只要通过 dfs\text{dfs} 预处理出所有的 d[x]d[x],然后转换成最大异或对问题求解即可

Trie 处理字符串统计

Remember The Word

f(i)f(i) 表示字符串 S[iL]S[i\cdots L] 的拆分方案数 (从下标为 ii 开始到字符串末尾)
可以得到状态转移方程

f(i)=f(j){S[ij1]构成一个合法的单词}f(i) = \sum f(j) \quad \{ S[i \cdots j-1] \text{构成一个合法的单词} \}

从而考虑可以把所有的单词插入 trie\text{trie} 树中

f(i)=f((i+len(x)1)+1){x 是满足 S[i]trie 的单词}f(i) = \sum f((i + len(x)-1) + 1) \quad \{x \ \text{是满足} \ S[i\cdots] \in \text{trie} \ \text{的单词} \}

  • 将所有单词插入 trie\text{trie},同时在表示串末的节点上,维护一个 len[]len[\cdots] 信息
  • 具体来说,ii 从后往前遍历字符串(如果从前往后的话,可能在 trie 树中找到了相应的串,但串终点会越过 ii,会造成重复计算)
  • trie\text{trie} 树中,令 p=1p=1 为根节点,然后沿着 S[i]S[i\cdots] 查找
    如果在路径中遇到了字符串的结尾,就把相应的 len(p)len(p) 存入集合 vec\text{vec}
  • 对每一个 ii,遍历 vec\text{vec},并且令 f(i)+=f(i+vec(j))f(i) += f(i + \text{vec}(j)),最初时候,f(len)=1f(len) = 1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
const int maxn = 1200000 + 5, mod = 20071027;
const int N = 300000 + 10;
string S;
int n, len;
ll f[N];

class Trie {
public:
int t[maxn][26];
int len[maxn];
int tot = 1;

Trie() {
tot = 1;
memset(len, 0, sizeof len);
memset(t, 0, sizeof t);
}

void clear() {
tot = 1;
memset(t, 0, sizeof t);
memset(len, 0, sizeof len);
}

void insert(const string &str) {
int p = 1;
for (auto x : str) {
int c = x-'a';
if (!t[p][c]) t[p][c] = ++tot;
p = t[p][c];
}
len[p] = str.length();
}

void query(const string &str, int pos, vector<int> &vec) {
int p = 1;
for (int i = pos; i < str.length(); i++) {
int c = str[i] - 'a';
if (t[p][c] == 0) break;
p = t[p][c];
if (len[p]) vec.push_back(len[p]);
}
}
};

Trie trie;

void dp() {
memset(f, 0, sizeof f);
len = S.length();
f[len] = 1;

for (int i = len-1; i >= 0; i--) {
vector<int> vec;
trie.query(S, i, vec);
for (auto x : vec) f[i] += f[i+x], f[i] %= mod;
}
printf("%lld\n", f[0]);
}

int main() {
freopen("input.txt", "r", stdin);
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);

int T = 0;
while (cin >> S) {
trie.clear();
printf("Case %d: ", ++T);
cin >> n;
for (int i = 0; i < n; i++) {
string str;
cin >> str;
trie.insert(str);
}
dp();
}
}

Trie 与前缀处理

Trie 这种数据结构经常用于前缀比较问题
strcmp() Anyone

这里可以使用边插入边统计,对于第 ii 个串 SS ,此时 trie\text{trie} 中已经放入第 [1i1][1\cdots i-1] 个串
此时遍历 trie\text{trie} 树,对于某个节点 pp,在插入字符 SjS_j 的时候,t(p,Sj)=0t(p, S_j) = 0,此时 pp 为分叉点
首先能想到的是,如果两个字符串不完全相同,那么 cnt=2lcp+1\text{cnt} = 2 \cdot \text{lcp} + 1
推广到多个串的情况,2lcp+1(i1)2 \cdot \sum \text{lcp} + 1\cdot (i-1)
lcp\text{lcp} 部分比较 22 次,第 ii 个串失配位置与前 i1i-1 个串需比较 i1i-1 次,下面来处理边界问题

  • 因为串在 SjS_j 位置一旦失配,那么从 S[j’0’]S[j \cdots \text{'0'}] (把末尾的 ‘\0’ 也算上),这部分串信息是无用的
    所以可以假设每一个串都会在末尾失配,对于第 ii 个串,resres+(i1)res \leftarrow res + (i-1)
  • 如果两个串完全相同,那么相当于额外花费了失配补全代价,即把原先失配的部分,让它匹配上
    如果前 [1i][1\cdots i] 个串都完全相同,插入第 ii 个串的时候,resres+(i1)res \leftarrow res + (i-1) 是失配代价
    在此基础上,把前 i1i-1 个串失配的部分一一补全,代价又要增加 i1i-1resres+2(i1)res \leftarrow res + 2\cdot (i-1)

可以设计算法如下

  • 边插入边统计,对于第 ii 个串,resres+(i1)res \leftarrow res + (i-1),表示失配代价
  • 对于 trie\text{trie} 中的节点 pp,维护 cnt(p)\text{cnt}(p),表示 [1i1][1\cdots i-1] 中有几个串经过 pp 节点
    这部分对答案的贡献为 2cnt(p),res+=2cnt(p)2 \text{cnt}(p), \quad res += 2\cdot \text{cnt}(p),然后 cnt(p)++\text{cnt}(p)++
  • 到达串结尾的时候,还要维护一个 end(p)\text{end}(p),表示补全失配的代价,此时意味着串相等
    end(p)\text{end}(p) 表示 pp 这个节点,是 [1i1][1\cdots i-1] 中几个串的结尾?也意味着 [1i1][1\cdots i-1] 有几个和串 ii 相同
    相等的串即为需要失配补全的串,此时 res+=end(p)\text{res} += \text{end}(p),然后 end(p)++\text{end}(p)++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
const int maxn = 4001 * 1000 + 5;
int n;

inline int get(char x) {
if (isdigit(x)) return x - '0';
else if (isupper(x)) return x - 'A' + 10;
else return x - 'a' + 36;
}

class Trie {
public:
int t[maxn][65];
int cnt[maxn], end[maxn];
int tot;

Trie() {
tot = 1;
memset(t, 0, sizeof t);
memset(cnt, 0, sizeof cnt);
memset(end, 0, sizeof end);
}
void clear() {
tot = 1;
memset(t, 0, sizeof t);
memset(cnt, 0, sizeof cnt);
memset(end, 0, sizeof end);
}

void insert(const string &str, int val, ll &res) {
res += val;
int p = 1;
for (auto x : str) {
int c = get(x);
if (!t[p][c]) t[p][c] = ++tot;
p = t[p][c];

res += 2ll*cnt[p], cnt[p]++;
}
res += (ll)end[p], end[p]++;
}
};

Trie trie;

int main() {
freopen("input.txt", "r", stdin);
int T = 0;
while (scanf("%d", &n) == 1 && n) {
printf("Case %d: ", ++T);
// init
trie.clear();

string str;
ll res = 0;
for (int i = 1; i <= n; i++) {
cin >> str;
trie.insert(str, i-1, res);
}
printf("%lld\n", res);
}
}

Trie 计数问题

Gym10085D

考虑两个字符串 S1,S2S_1, S_2,它们所有字符都不相同
可以将所有字符串正着插入一个 trie1\text{trie}_1 中,再将所有串反着插入 trie2\text{trie}_2
trie1\text{trie}_1 中的节点数为 tot1tot_1tot11tot_1-1 表示一共有多少个不同的非空前缀
同理,tot21tot_2 - 1 表示这些串构成多少个不同的非空后缀
所以 res=(tot11)(tot21)res = (tot_1 - 1) \cdot (tot_2 - 1)

下面考虑重复的情况
形如 S=S1+c+S2S = S_1 + c + S_2 的串,正向插入 trie1\text{trie}_1 的时候被统计了一次
反向插入 trie2\text{trie}_2 的时候也被统计了一次,也就是说
每出现一个长度 3\geqslant 3 的串 S1+c+S2S_1 + c + S_2,那么就多统计了一次

问题转换为,原字典集合中,有多少个长度 3\geqslant 3 的,形如 S1+c+S2S_1 + c + S_2 的串

  • cnt1(c)\text{cnt}_1(c) 维护有多少个以字符 cc 结尾的前缀
  • cnt2(c)\text{cnt}_2(c) 维护有多少个以字符 cc 结尾的后缀

于是可以边插入边维护,根节点的深度假设为 d=0d = 0

  • 正向插入的时候,对于 i>0i > 0 并且 t(p,Si)=0t(p, S_i) = 0 时候,此时需要动态开点
    即当第一次在 trie1\text{trie}_1 树深度 d2d \geqslant 2 的节点加入字符 SiS_i 时,
    SiS_i 作为前缀结尾的子串个数 cnt1(Si)++\text{cnt}_1(S_i)++
    注意,对于深度为 dd 的节点,只在第一次出现字符 cc 的时候, 令 cnt1(c)++\text{cnt}_1(c)++
    当然,当深度 dd 变化的时候,再出现字符 cc,同样要统计 cnt1(c)++\text{cnt}_1(c)++
    这样 cnt1(c)\text{cnt}_1(c) 表示以 cc 结尾的不同前缀的个数
  • 反向插入同理
  • 最后的答案为,遍历所有的字符 c[0,26)c \in [0, 26)

    res=res(cnt1(c)1)(cnt2(c)1)res = res - \binom{\text{cnt}_1 (c)}{1} \cdot \binom{\text{cnt}_2 (c)}{1}

  • 注意当单词只有一个字符的时候,需要特判,此时当字符第一次出现的时候,令 res+=1res += 1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
const int maxn = 400000 + 10;
const int N = 10000 + 10;
string str[N];
int n, vis[27];

class Trie {
public:
int t1[maxn][27], t2[maxn][27];
int cnt1[27], cnt2[27];
int tot1, tot2;

Trie() {
tot1 = tot2 = 1;
memset(t1, 0, sizeof t1);
memset(t2, 0, sizeof t2);
memset(cnt1, 0, sizeof cnt1);
memset(cnt2, 0, sizeof cnt2);
}

void insert(const string &str) {
int p = 1;
for (int i = 0; i < str.length(); i++) {
int c = str[i] - 'a';
if (!t1[p][c]) {
if (i) cnt1[c]++;
t1[p][c] = ++tot1;
}
p = t1[p][c];
}
p = 1;
for (int i = str.length()-1; i >= 0; i--) {
int c = str[i] - 'a';
if (!t2[p][c]) {
// todo
if (i != str.length()-1) cnt2[c]++;
t2[p][c] = ++tot2;
}
p = t2[p][c];
}
}
ll query() {
ll res = 0;
res += 1ll * (tot1 - 1) * (tot2 - 1);
for (int i = 0; i < 26; i++) res -= 1ll * cnt1[i] * cnt2[i];
return res;
}
} trie;

int main() {
freopen("dictionary.in", "r", stdin);
freopen("dictionary.out", "w", stdout);
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
memset(vis, 0, sizeof vis);

cin >> n;
for (int i = 1; i <= n; i++) {
cin >> str[i];
trie.insert(str[i]);
}

ll res = 0;
for (int i = 1; i <= n; i++) {
const auto &S = str[i];
if (S.length() == 1 && vis[S[0] - 'a'] == 0) {
vis[S[0] - 'a'] = 1, res++;
}
}
res += trie.query();
cout << res << endl;
}