void insert(int pre, int p, int ver) { dfn[p] = ver; for (int k = 25; k >= 0; k--) { int c = s[ver] >> k & 1; t[p][c^1] = t[pre][c^1]; t[p][c] = ++tot; p = t[p][c], pre = t[pre][c]; dfn[p] = ver; } }
int ask(int p, int val, int lim) { for (int k = 25; k >= 0; k--) { int c = val >> k & 1; if (dfn[ t[p][c^1] ] >= lim) p = t[p][c^1]; else p = t[p][c]; } return s[dfn[p]] ^ val; } } trie;
int main() { freopen("input.txt", "r", stdin); cin >> n >> m;
// init for (int i = 1; i <= n; i++) { int x; scanf("%d", &x); s[i] = s[i-1] ^ x; root[i] = ++trie.tot; trie.insert(root[i-1], root[i], i); } while (m--) { char cmd[2]; scanf("%s", cmd); if (cmd[0] == 'A') { int x; scanf("%d", &x); root[++n] = ++trie.tot; s[n] = s[n-1] ^ x; trie.insert(root[n-1], root[n], n); } else { int l, r, x; scanf("%d%d%d", &l, &r, &x); int res = trie.ask(root[r-1], s[n]^x, l-1); printf("%d\n", res); } } }
void insert(int p, int q, int val) { ver[q] = ver[p] + 1; for (int k = 16; k >= 0; k--) { int c = val >> k & 1; t[q][c^1] = t[p][c^1]; t[q][c] = ++tot;
int query(int fa, int p, int val) { int res = 0; for (int k = 16; k >= 0; k--) { int c = val >> k & 1; if (ver[ t[fa][c^1] ] < ver[ t[p][c^1] ]) { res += (1<<k); fa = t[fa][c^1], p = t[p][c^1]; } else fa = t[fa][c], p = t[p][c]; } return res; } } trie; int fa[N][H], dep[N]; void dfs(int u, int pa) { fa[u][0] = pa, dep[u] = dep[pa] + 1; for (int i = 1; i < H; i++) fa[u][i] = fa[fa[u][i-1]][i-1]; root[u] = ++trie.tot, trie.insert(root[pa], root[u], a[u]); for (int i = G.head[u]; i; i = G.ne[i]) { int v = G.ver[i]; if (v == pa) continue; dfs(v, u); } } int lca(int x, int y) { if (dep[y] < dep[x]) swap(x, y); for (int i = H-1; i >= 0; i--) { if (dep[ fa[y][i] ] >= dep[x]) y = fa[y][i]; } if (y == x) return y; for (int i = H-1; i >= 0; i--) { if (fa[y][i] != fa[x][i]) y = fa[y][i], x = fa[x][i]; } return fa[x][0]; } int main() { freopen("input.txt", "r", stdin); while (~scanf("%d%d", &n, &m)) { // init trie.clear(), G.clear(); memset(root, 0, sizeof root); memset(fa, 0, sizeof fa); memset(dep, 0, sizeof dep); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); for (int i = 0; i < n-1; i++) { int u, v; scanf("%d%d", &u, &v); G.add(u, v), G.add(v, u); } // dfs dfs(1, 0); // query while (m--) { int x, y, val; scanf("%d%d%d", &x, &y, &val); int f = fa[lca(x, y)][0]; int res = max(trie.query(root[f], root[x], val), trie.query(root[f], root[y], val)); printf("%d\n", res); } } }
typedef pair<ll, int> PII; const int maxn = 500000 + 10, N = maxn * 35; const int H = 33; int n, k, rk[maxn], root[maxn]; ll s[maxn]; priority_queue<PII> heap;
void insert(int pre, int p, int H, ll val) { if (H < 0) { sz[p] = sz[pre] + 1; return; } int c = val >> H & 1; if (pre) t[p][c^1] = t[pre][c^1]; t[p][c] = ++tot; insert(t[pre][c], t[p][c], H-1, val); sz[p] = sz[t[p][c]] + sz[t[p][c^1]]; }
void ask(int p, int rk, int H, ll val, ll &res) { if (H < 0) return; int c = val >> H & 1; if (sz[ t[p][c^1] ] >= rk) { res = (res << 1 | 1); ask(t[p][c^1], rk, H-1, val, res); } else { res <<= 1; ask(t[p][c], rk - sz[t[p][c^1]], H-1, val, res); } } } trie;
void solve() { for (int i = 1; i <= n; i++) { ll res = 0; trie.ask(root[i-1], rk[i], H, s[i], res); heap.push({res, i}); } ll ans = 0; while (k--) { auto x = heap.top(); heap.pop(); ans += x.first; int r = x.second; ll res = 0; trie.ask(root[r-1], ++rk[r], H, s[r], res); heap.push({res, r}); } printf("%lld\n", ans); }
void find(const string &str) { int p = 0; for (auto x : str) { int c = x-'a'; p = t[p][c]; if (val[p]) dfs(p); elseif (last[p]) dfs(last[p]); } } } ac;
void solve() { int res = -1; for (int i = 1; i <= n; i++) res = max(res, cnt[i]); printf("%d\n", res); for (int i = 1; i <= n; i++) if (cnt[M[P[i]]] == res) { printf("%s\n", P[i].c_str()); } }
typedef unsigned long long ull; const int P1 = 13331, P2 = 131; int n1, m1, n2, m2; const int maxn = 1000 + 10; char s1[maxn][maxn], s2[maxn][maxn]; ull h1[maxn][maxn], h2[maxn][maxn], p1[maxn*maxn], p2[maxn*maxn];
void pre() { p1[0] = p2[0] = 1; for (int i = 1; i <= maxn; i++) { p1[i] = p1[i-1] * P1; p2[i] = p2[i-1] * P2; } }
void getHash(const char s1[][maxn], ull h[][maxn], int n, int m) { for (int i = 1; i <= n; i++) { for (int j = 1; j <= m; j++) { h[i][j] = h[i-1][j] * P1 + h[i][j-1] * P2 - h[i-1][j-1] * P1 * P2 + (s1[i][j]-'a'); } } }
ull Hash(const ull h[][maxn], int x1, int y1, int x2, int y2) { return h[x2][y2] - h[x1-1][y2] * p1[x2-x1+1] - h[x2][y1-1] * p2[y2-y1+1] + h[x1-1][y1-1] * p1[x2-x1+1] * p2[y2-y1+1]; }
void solve() { ll ans = 0; for (int i = 1; i + n2 - 1 <= n1; i++) { for (int j = 1; j + m2 - 1 <= m1; j++) { if (Hash(h1, i, j, i+n2-1, j+m2-1) == h2[n2][m2]) ans++; } } printf("%lld\n", ans); }
int main() { freopen("input.txt", "r", stdin); // get mi pre();
int T; cin >> T; while (T--) { // init scanf("%d%d", &n1, &m1); for (int i = 1; i <= n1; i++) scanf("%s", s1[i]+1); getHash(s1, h1, n1, m1);
scanf("%d%d", &n2, &m2); for (int i = 1; i <= n2; i++) scanf("%s", s2[i]+1); getHash(s2, h2, n2, m2);
// then solve ull res = Hash(h1, 0, 0, 1, 1); solve(); } }
AC 自动机实现二维匹配
矩阵 P(x×y),T(n×m)
很容易想到,将 P 的第 i 行,∀i∈[1,x],P[i] 插入 AC 自动机中
执行 insert(P[i],i),这里需要在行结尾(即字符串末尾节点)u,维护一个 vector vec[u]={⋯},存储 u这个点是哪些行的结尾?
如果遍历 AC 自动机走到了某个行尾节点,vec[u] 表示 P 中 能和 T匹配上的行有哪些
查询的时候,对 T 的每一行执行 find(T[u])
接着根据字符串 T[u] 执行标准的 AC 自动机查找 len(T[u])=m,∀i∈[1⋯len(T[u])],遍历 AC 自动机
从 p=0 开始沿着 c=T[u][i] 走
对于某个 i,此时走到了 AC 自动机的 p 节点
如果 p 为串尾节点,val(p)=0,那么遍历 vec[p],∀r∈vec[p]
此时 P 中第 r 行能够匹配上,也就是说
以 T(u−r+1,i−y+1) 为左上角的 x×y 矩形,能匹配上的行数 +1
即 cnt(u−r+1,i−y+1)+=1
const int maxn = 1000 + 10; const int N = 10000 + 10, SZ = 27; char P[maxn][maxn], T[maxn][maxn]; int n, m, x, y, cnt[maxn][maxn];
class AC { public: int t[N][SZ], fail[N], val[N], last[N]; vector<int> vec[N]; int tot = 0;
void clear() { tot = 0; for (int i = 0; i < N; i++) vec[i].clear(); memset(t, 0, sizeof t); memset(fail, 0, sizeof fail); memset(val, 0, sizeof val); memset(last, 0, sizeof last); }
void insert(const char *str, int idx) { int p = 0; assert(strlen(str) == y); for (int i = 0; i < strlen(str); i++) { int c = str[i] - 'a'; if (!t[p][c]) t[p][c] = ++tot; p = t[p][c]; } val[p] = idx, vec[p].push_back(idx); }
void build() { queue<int> que; for (int i = 0; i < SZ; i++) { if (t[0][i]) que.push(t[0][i]); }
while (que.size()) { auto u = que.front(); que.pop(); for (int i = 0; i < SZ; i++) { if (t[u][i]) { fail[t[u][i]] = t[fail[u]][i]; que.push(t[u][i]); } else t[u][i] = t[fail[u]][i]; } last[u] = val[fail[u]] ? fail[u] : last[fail[u]]; } }
void find(int u) { int p = 0; for (int i = 1; i <= m; i++) { int c = T[u][i] - 'a'; p = t[p][c]; if (val[p]) { for (auto r : vec[p]) { if (u-r+1 >= 1) ++cnt[u-r+1][i-y+1]; } } elseif (last[p]) { for (auto r : vec[last[p]]) { if (u-r+1 >= 1) ++cnt[u-r+1][i-y+1]; } } } } } ac;
void solve() { for (int u = 1; u <= n; u++) ac.find(u); int ans = 0; for (int i = 1; i + x - 1 <= n; i++) { for (int j = 1; j + y - 1 <= m; j++) { if (cnt[i][j] == x) ans++; } } printf("%d\n", ans); }
int main() { freopen("input.txt", "r", stdin); int kase; cin >> kase; while (kase--) { // init ac.clear(); memset(cnt, 0, sizeof cnt);
scanf("%d%d", &n, &m); for (int i = 1; i <= n; i++) scanf("%s", T[i]+1), assert(strlen(T[i]+1) == m);
scanf("%d%d", &x, &y); for (int i = 1; i <= x; i++) { scanf("%s", P[i]+1); ac.insert(P[i]+1, i); }
if (n < x || m < y) { puts("0"); continue; }
// build ac.build();
// solve solve(); } }
AC 自动机和 dp
AC 自动机和 dp 有关的算法,经常需要借助 last 指针写状态转移方程 NVWLS
给出一个字典,每个单词去掉元音字母 A, E, I, O, U 之后形成了一个新字典
先给出一个只有辅音字母的串,用原字典(包含元音字母)的单词还原该串
如果存在多种还原方式,输出还原后元音字母数量最多的结果
很容易想到,将初始单词插入 AC 自动机中,不插入元音字母
for∀i 遍历模版串 str(i),并且找到 AC 自动机中表示 str(i) 的节点 p for∀k∈p→last(p),如果 k 为单词节点
AC 自动机维护单词节点编号 id,End(k)=id
同时还需要维护 cnt(id),表示 id 这个单词有多少个元音字母 len(id) 表示 id 这个单词除掉元音字母的长度,由此可以写出状态转移方程 f(i) 表示 str(i) 这个位置还原补上元音字母后,str[0⋯i] 的最长长度,那么 f(i)=max(f(i−len[id]))+cnt(id) dp 更新的时候记录状态 used(i)=id,表示 i 这个位置用了 id 这个单词
输出结果的时候递归输出 print(p−len(used(p)))
另外,值得注意的是,有可能单词去掉元音字母后,得到一样的单词
但是,cnt(A)>cnt(B),A 还原成的单词元音字母大于 B 的
在 AC 自动机中维护 End(p)=id 以及 val(p) 的最大值
其中 val(p) 表示 p 这个节点能够还原成的最多的元音字母数
void insert(const string &str, int id) { int p = 0; for (auto x : str) { int c = x - 'A'; if (!t[p][c]) t[p][c] = ++tot; p = t[p][c]; } if (cnt[id] >= cnt[End[p]]) End[p] = id; }
void build() { queue<int> que; for (int i = 0; i < SZ; i++) { if (t[0][i]) que.push(t[0][i]); } while (que.size()) { auto u = que.front(); que.pop(); for (int i = 0; i < SZ; i++) { if (t[u][i]) { fail[t[u][i]] = t[fail[u]][i]; que.push(t[u][i]); } else t[u][i] = t[fail[u]][i]; } last[u] = End[fail[u]] ? fail[u] : last[fail[u]]; } } } ac;
void dp(const char *str) { int N = strlen(str+1); for (int i = 1; i <= N; i++) f[i] = -1; f[0] = 0;
int p = 0; for (int i = 1; i <= N; i++) { int c = str[i] - 'A'; p = ac.t[p][c];
for (int j = p; j; j = ac.last[j]) { if (!ac.End[j]) j = ac.last[j]; int id = ac.End[j]; if (id && f[i - len[id]] != -1 && f[i - len[id]] + cnt[id] > f[i]) { f[i] = f[i-len[id]] + cnt[id], used[i] = id; } } } }
void print(const int p) { if (p < 1) return; print(p - len[used[p]]); printf("%s ", S[used[p]].c_str()); }