「HAOI2016」找相同字符

题目描述

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。

输入格式

两行,两个字符串 \(s_1,s_2\),长度分别为\(n_1, n_2\)。字符串中只有小写字母。

输出格式

输出一个整数表示答案

样例

样例输入

aabb
bbaa

样例输出

10

数据范围与提示

\(1 \leq n_1, n_2 \leq 200000\)

内存限制: 256 MiB 时间限制: 1000 ms


传送门:

https://loj.ac/problem/2064 或 http://www.lydsy.com/JudgeOnline/problem.php?id=4566


网上大部分题解都是关于SAM的……

然而我又用SA水过了QAQ

从两个串中分别取出一个后缀,对答案的贡献是他们的\(\mathrm{LCP}\)的长度;并且枚举所有后缀就是枚举了所有情况,因为实质上我们枚举的是相同子串的开头。

这样,如果将两个串加入位置在\(\mathrm{split}\)的分隔符拼接在一起变为\(S\),答案变为了$$
\sum_{i = 1} ^ {\mathrm{split} – 1} \sum_{j = \mathrm{split} + 1}^{|S|} \mathrm{LCP}(i, j)
$$转化问题到后缀数组上,那么问题就转化为分别统计对于每个\(h_i\),当他作为最小值即\(\mathrm{LCP}\)时,对答案做贡献的次数\(\mathrm{cnt}_i\)。

如果只考虑一个点,那么也就是统计经过它的符合条件的区间数。这个点左右合法区间端点必然是一个连续的区间\([L_i, R_i]\),为了避免重复我们定义\(L_i\)为最小的使得\(\mathrm{min}(h_{[L_i, i)}) \gt h_i\)的值,\(R_i\)为最小的使得\(\mathrm{min}(h_{[i, R_i]}) \geq h_i\)的值,有$$
\mathrm{cnt}_i = \sum_{l = L_i} ^ {i – 1} \sum_{r = i} ^ {R_i} h_i \cdot [(\mathrm{sa}_l < \mathrm{split} \ \mathrm{and}\ \mathrm{sa}_r > \mathrm{split}) \ \mathrm{or}\ (\mathrm{sa}_l > \mathrm{split} \ \mathrm{and}\ \mathrm{sa}_r < \mathrm{split})]
$$两个\(\sum\)可以用乘法原理和前缀和\(O(1)\)求出,剩下的问题就是求\(L_i\)和\(R_i\)了,这是一个单调栈的经典问题。

所以我们就在求出后缀数组后,\(O(N)\)解决了这道题。

/* Never Say Die */
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <ctime>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <string>
#include <map>
#include <set>
#include <bitset>
using namespace std;
typedef long long LL;
typedef double D;
typedef pair<int, int> pii;
#define mp(a, b) make_pair(a, b)
#define fir first
#define sec second
inline int min(int a, int b) {return a < b ? a : b;}
inline int max(int a, int b) {return a > b ? a : b;}
int ch = 0;
template <class T> inline void read(T &a) {
	bool f = 0; a = 0;
	while (ch < '0' || ch > '9') {if (ch == '-') f = 1; ch = getchar();}
	while (ch >= '0' && ch <= '9') {a = a * 10 + ch - '0'; ch = getchar();}
	if (f) a = -a;
}

#define MAXN 400010

void Sort(int in[], int out[], int p[], int n, int m) {
	static int P[MAXN];
	for (int i = 1; i <= m; i++) P[i] = 0;
	for (int i = 1; i <= n; i++) P[in[i]]++;
	for (int i = 2; i <= m; i++) P[i] += P[i - 1];
	for (int i = n; i; i--) out[P[in[p[i]]]--] = p[i];
}

char s[MAXN];
int sa[MAXN], rk[MAXN], h[MAXN];

int n, l1;
void getsa() {
	static int t1[MAXN], t2[MAXN], *x = t1, *y = t2;
	int m = 127;
	for (int i = 1; i <= n; i++) x[i] = s[i], y[i] = i;
	Sort(x, sa, y, n, m);
	for (int j = 1, i, k = 0; k < n; m = k, j <<= 1) {
		for (i = n - j + 1, k = 0; i <= n; i++) y[++k] = i;
		for (i = 1; i <= n; i++) if (sa[i] > j) y[++k] = sa[i] - j;
		Sort(x, sa, y, n, m);
		for (swap(x, y), i = 2, x[sa[1]] = k = 1; i <= n; i++) {
			x[sa[i]] = (y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + j] == y[sa[i] + j]) ? k : ++k;
		}
	}
	for (int i = 1; i <= n; i++) rk[sa[i]] = i;
	for (int i = 1, k = 0; i <= n; h[rk[i++]] = k) {
		k -= !!k;
		for (int j = sa[rk[i] - 1]; s[i + k] == s[j + k]; k++);
	}
}

int sta[MAXN], p[MAXN], L[MAXN], R[MAXN], s1[MAXN], s2[MAXN];

void work() {
	int top = 0; sta[0] = -1;
	for (int i = 1; i <= n; i++) {
		while (sta[top] >= h[i]) {
			R[p[top]] = i - 1;
			top--;
		}
		L[i] = p[top] + 1;
		sta[++top] = h[i], p[top] = i;
	}
	while (top) R[p[top--]] = n;
	for (int i = 1; i <= n; i++) s1[i] = s1[i - 1] + (sa[i] < l1);
	for (int i = 1; i <= n; i++) s2[i] = s2[i - 1] + (sa[i] > l1);
	LL ans = 0;
	for (int i = 2; i <= n; i++) {
		ans += (LL)h[i] * ((s2[i - 1] - s2[L[i] - 2]) * (s1[R[i]] - s1[i - 1]) + (s1[i - 1] - s1[L[i] - 2]) * (s2[R[i]] - s2[i - 1]));
	}
	printf("%lld\n", ans);
}

int main() {
	scanf("%s", s + 1);
	n = (int)strlen(s + 1);
	s[l1 = ++n] = '|';
	scanf("%s", s + n + 1);
	n += (int)strlen(s + n + 1);
	getsa();
	work();
	return 0;
}

QAQ这么搞我什么时候才能学会SAM啊……

发表评论

电子邮件地址不会被公开。 必填项已用*标注