题目描述
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
输入格式
两行,两个字符串 \(s_1,s_2\),长度分别为\(n_1, n_2\)。字符串中只有小写字母。
输出格式
输出一个整数表示答案
样例
样例输入
aabb
bbaa
样例输出
10
数据范围与提示
\(1 \leq n_1, n_2 \leq 200000\)
内存限制: 256 MiB 时间限制: 1000 ms
传送门:
https://loj.ac/problem/2064 或 http://www.lydsy.com/JudgeOnline/problem.php?id=4566
网上大部分题解都是关于SAM的……
然而我又用SA水过了QAQ
从两个串中分别取出一个后缀,对答案的贡献是他们的\(\mathrm{LCP}\)的长度;并且枚举所有后缀就是枚举了所有情况,因为实质上我们枚举的是相同子串的开头。
这样,如果将两个串加入位置在\(\mathrm{split}\)的分隔符拼接在一起变为\(S\),答案变为了$$
\sum_{i = 1} ^ {\mathrm{split} – 1} \sum_{j = \mathrm{split} + 1}^{|S|} \mathrm{LCP}(i, j)
$$转化问题到后缀数组上,那么问题就转化为分别统计对于每个\(h_i\),当他作为最小值即\(\mathrm{LCP}\)时,对答案做贡献的次数\(\mathrm{cnt}_i\)。
如果只考虑一个点,那么也就是统计经过它的符合条件的区间数。这个点左右合法区间端点必然是一个连续的区间\([L_i, R_i]\),为了避免重复我们定义\(L_i\)为最小的使得\(\mathrm{min}(h_{[L_i, i)}) \gt h_i\)的值,\(R_i\)为最小的使得\(\mathrm{min}(h_{[i, R_i]}) \geq h_i\)的值,有$$
\mathrm{cnt}_i = \sum_{l = L_i} ^ {i – 1} \sum_{r = i} ^ {R_i} h_i \cdot [(\mathrm{sa}_l < \mathrm{split} \ \mathrm{and}\ \mathrm{sa}_r > \mathrm{split}) \ \mathrm{or}\ (\mathrm{sa}_l > \mathrm{split} \ \mathrm{and}\ \mathrm{sa}_r < \mathrm{split})]
$$两个\(\sum\)可以用乘法原理和前缀和\(O(1)\)求出,剩下的问题就是求\(L_i\)和\(R_i\)了,这是一个单调栈的经典问题。
所以我们就在求出后缀数组后,\(O(N)\)解决了这道题。
/* Never Say Die */ #include <iostream> #include <cstdio> #include <cstring> #include <cmath> #include <ctime> #include <cstdlib> #include <algorithm> #include <queue> #include <string> #include <map> #include <set> #include <bitset> using namespace std; typedef long long LL; typedef double D; typedef pair<int, int> pii; #define mp(a, b) make_pair(a, b) #define fir first #define sec second inline int min(int a, int b) {return a < b ? a : b;} inline int max(int a, int b) {return a > b ? a : b;} int ch = 0; template <class T> inline void read(T &a) { bool f = 0; a = 0; while (ch < '0' || ch > '9') {if (ch == '-') f = 1; ch = getchar();} while (ch >= '0' && ch <= '9') {a = a * 10 + ch - '0'; ch = getchar();} if (f) a = -a; } #define MAXN 400010 void Sort(int in[], int out[], int p[], int n, int m) { static int P[MAXN]; for (int i = 1; i <= m; i++) P[i] = 0; for (int i = 1; i <= n; i++) P[in[i]]++; for (int i = 2; i <= m; i++) P[i] += P[i - 1]; for (int i = n; i; i--) out[P[in[p[i]]]--] = p[i]; } char s[MAXN]; int sa[MAXN], rk[MAXN], h[MAXN]; int n, l1; void getsa() { static int t1[MAXN], t2[MAXN], *x = t1, *y = t2; int m = 127; for (int i = 1; i <= n; i++) x[i] = s[i], y[i] = i; Sort(x, sa, y, n, m); for (int j = 1, i, k = 0; k < n; m = k, j <<= 1) { for (i = n - j + 1, k = 0; i <= n; i++) y[++k] = i; for (i = 1; i <= n; i++) if (sa[i] > j) y[++k] = sa[i] - j; Sort(x, sa, y, n, m); for (swap(x, y), i = 2, x[sa[1]] = k = 1; i <= n; i++) { x[sa[i]] = (y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + j] == y[sa[i] + j]) ? k : ++k; } } for (int i = 1; i <= n; i++) rk[sa[i]] = i; for (int i = 1, k = 0; i <= n; h[rk[i++]] = k) { k -= !!k; for (int j = sa[rk[i] - 1]; s[i + k] == s[j + k]; k++); } } int sta[MAXN], p[MAXN], L[MAXN], R[MAXN], s1[MAXN], s2[MAXN]; void work() { int top = 0; sta[0] = -1; for (int i = 1; i <= n; i++) { while (sta[top] >= h[i]) { R[p[top]] = i - 1; top--; } L[i] = p[top] + 1; sta[++top] = h[i], p[top] = i; } while (top) R[p[top--]] = n; for (int i = 1; i <= n; i++) s1[i] = s1[i - 1] + (sa[i] < l1); for (int i = 1; i <= n; i++) s2[i] = s2[i - 1] + (sa[i] > l1); LL ans = 0; for (int i = 2; i <= n; i++) { ans += (LL)h[i] * ((s2[i - 1] - s2[L[i] - 2]) * (s1[R[i]] - s1[i - 1]) + (s1[i - 1] - s1[L[i] - 2]) * (s2[R[i]] - s2[i - 1])); } printf("%lld\n", ans); } int main() { scanf("%s", s + 1); n = (int)strlen(s + 1); s[l1 = ++n] = '|'; scanf("%s", s + n + 1); n += (int)strlen(s + n + 1); getsa(); work(); return 0; }
QAQ这么搞我什么时候才能学会SAM啊……