题目描述

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。

输入格式

两行,两个字符串 \(s_1,s_2\),长度分别为\(n_1, n_2\)。字符串中只有小写字母。

输出格式

输出一个整数表示答案

样例

样例输入

aabb
bbaa

样例输出

10

数据范围与提示

\(1 \leq n_1, n_2 \leq 200000\)

内存限制: 256 MiB 时间限制: 1000 ms


传送门:

https://loj.ac/problem/2064 或 http://www.lydsy.com/JudgeOnline/problem.php?id=4566


网上大部分题解都是关于SAM的……

然而我又用SA水过了QAQ

从两个串中分别取出一个后缀,对答案的贡献是他们的\(\mathrm{LCP}\)的长度;并且枚举所有后缀就是枚举了所有情况,因为实质上我们枚举的是相同子串的开头。

这样,如果将两个串加入位置在\(\mathrm{split}\)的分隔符拼接在一起变为\(S\),答案变为了$$
\sum_{i = 1} ^ {\mathrm{split} – 1} \sum_{j = \mathrm{split} + 1}^{|S|} \mathrm{LCP}(i, j)
$$转化问题到后缀数组上,那么问题就转化为分别统计对于每个\(h_i\),当他作为最小值即\(\mathrm{LCP}\)时,对答案做贡献的次数\(\mathrm{cnt}_i\)。

如果只考虑一个点,那么也就是统计经过它的符合条件的区间数。这个点左右合法区间端点必然是一个连续的区间\([L_i, R_i]\),为了避免重复我们定义\(L_i\)为最小的使得\(\mathrm{min}(h_{[L_i, i)}) \gt h_i\)的值,\(R_i\)为最小的使得\(\mathrm{min}(h_{[i, R_i]}) \geq h_i\)的值,有$$
\mathrm{cnt}_i = \sum_{l = L_i} ^ {i – 1} \sum_{r = i} ^ {R_i} h_i \cdot [(\mathrm{sa}_l < \mathrm{split} \ \mathrm{and}\ \mathrm{sa}_r > \mathrm{split}) \ \mathrm{or}\ (\mathrm{sa}_l > \mathrm{split} \ \mathrm{and}\ \mathrm{sa}_r < \mathrm{split})]
$$两个\(\sum\)可以用乘法原理和前缀和\(O(1)\)求出,剩下的问题就是求\(L_i\)和\(R_i\)了,这是一个单调栈的经典问题。

所以我们就在求出后缀数组后,\(O(N)\)解决了这道题。

/* Never Say Die */
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <ctime>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <string>
#include <map>
#include <set>
#include <bitset>
using namespace std;
typedef long long LL;
typedef double D;
typedef pair<int, int> pii;
#define mp(a, b) make_pair(a, b)
#define fir first
#define sec second
inline int min(int a, int b) {return a < b ? a : b;}
inline int max(int a, int b) {return a > b ? a : b;}
int ch = 0;
template <class T> inline void read(T &a) {
    bool f = 0; a = 0;
    while (ch < '0' || ch > '9') {if (ch == '-') f = 1; ch = getchar();}
    while (ch >= '0' && ch <= '9') {a = a * 10 + ch - '0'; ch = getchar();}
    if (f) a = -a;
}

#define MAXN 400010

void Sort(int in[], int out[], int p[], int n, int m) {
    static int P[MAXN];
    for (int i = 1; i <= m; i++) P[i] = 0;
    for (int i = 1; i <= n; i++) P[in[i]]++;
    for (int i = 2; i <= m; i++) P[i] += P[i - 1];
    for (int i = n; i; i--) out[P[in[p[i]]]--] = p[i];
}

char s[MAXN];
int sa[MAXN], rk[MAXN], h[MAXN];

int n, l1;
void getsa() {
    static int t1[MAXN], t2[MAXN], *x = t1, *y = t2;
    int m = 127;
    for (int i = 1; i <= n; i++) x[i] = s[i], y[i] = i;
    Sort(x, sa, y, n, m);
    for (int j = 1, i, k = 0; k < n; m = k, j <<= 1) {
        for (i = n - j + 1, k = 0; i <= n; i++) y[++k] = i;
        for (i = 1; i <= n; i++) if (sa[i] > j) y[++k] = sa[i] - j;
        Sort(x, sa, y, n, m);
        for (swap(x, y), i = 2, x[sa[1]] = k = 1; i <= n; i++) {
            x[sa[i]] = (y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + j] == y[sa[i] + j]) ? k : ++k;
        }
    }
    for (int i = 1; i <= n; i++) rk[sa[i]] = i;
    for (int i = 1, k = 0; i <= n; h[rk[i++]] = k) {
        k -= !!k;
        for (int j = sa[rk[i] - 1]; s[i + k] == s[j + k]; k++);
    }
}

int sta[MAXN], p[MAXN], L[MAXN], R[MAXN], s1[MAXN], s2[MAXN];

void work() {
    int top = 0; sta[0] = -1;
    for (int i = 1; i <= n; i++) {
        while (sta[top] >= h[i]) {
            R[p[top]] = i - 1;
            top--;
        }
        L[i] = p[top] + 1;
        sta[++top] = h[i], p[top] = i;
    }
    while (top) R[p[top--]] = n;
    for (int i = 1; i <= n; i++) s1[i] = s1[i - 1] + (sa[i] < l1);
    for (int i = 1; i <= n; i++) s2[i] = s2[i - 1] + (sa[i] > l1);
    LL ans = 0;
    for (int i = 2; i <= n; i++) {
        ans += (LL)h[i] * ((s2[i - 1] - s2[L[i] - 2]) * (s1[R[i]] - s1[i - 1]) + (s1[i - 1] - s1[L[i] - 2]) * (s2[R[i]] - s2[i - 1]));
    }
    printf("%lld\n", ans);
}

int main() {
    scanf("%s", s + 1);
    n = (int)strlen(s + 1);
    s[l1 = ++n] = '|';
    scanf("%s", s + n + 1);
    n += (int)strlen(s + n + 1);
    getsa();
    work();
    return 0;
}

QAQ这么搞我什么时候才能学会SAM啊……

题目描述

对于一个给定长度为 $ N $ 的字符串,求它的第 $ K $ 小子串是什么。

输入格式

第一行是一个仅由小写英文字母构成的字符串 $ S $ 。 第二行为两个整数 $ T $ 和 $ K $ , $ T $ 为 0 则表示不同位置的相同子串算作一个。 $ T $ 为 1 则表示不同位置的相同子串算作多个。 $ K $ 的意义如题所述。

输出格式

输出仅一行,为一个数字串,为第 $ K $ 小的子串。如果子串数目不足 $ K $ 个,则输出 -1。

样例输入

aabc
0 3

样例输出

aab

数据范围与提示

$ N \leq 5 \times 10 ^ 5, T \lt 2, K \leq 10 ^ 9 $

内存限制: 256 MiB

时间限制: 1000 ms


传送门: https://loj.ac/problem/2102http://www.lydsy.com/JudgeOnline/problem.php?id=3998


首先这道题是一个后缀自动机模板题……网上关于SAM的题解又多又好,而且这道题也是我学习SAM的动力。

不过直到学习了SAM重新审视我当时写错的SA代码,突然意识到这道题可以用SA完成,虽然常数大,但是复杂度( $ O(N \log N) $ )比SAM( $ O(N |\alpha|) $ )优秀,跑得很快。

关于SAM的做法不再赘述了,我们这里考虑以SA的方式来思考这道题。 以下, $ \mathrm{sa}_i $ 表示字符串 $ S $ 的后缀从小到大排序后开始位置, $ \mathrm{len}_i $ 表示排名为 $ i $ 的后缀的长度( $ |S| – \mathrm{sa}_i + 1 $ ), $ h_i $ 表示第 $ i $ 大后缀与第 $ i – 1 $ 大子串的最长公共前缀(也叫 $ \mathrm{LCP} $ 或者 $ \mathrm{height} $ ),其中 $ h_1 = 0 $ 。 当 $ T = 0 $ 时,我们需要统计的是本质不同的字符串。显然,对于第 $ i $ 大的后缀,对答案的贡献只有 $ \mathrm{len}_i – h_i $ ,因为前一个串相同的部分已经被算过答案了,然后扫一遍就好了,复杂度 $ O(N) $ 。 当 $ T = 1 $ 时,我们不仅仅需要知道本质不同的字符串的个数,还需要知道依次出现了多少次以确定答案。

首先,一个非常直观的思路是,对于新出现的每个本质不同的字符串,如果长度在 $ h_i \sim h_{i+1}$ 范围内,可以在 $ h $ 数组上二分 + RMQ得到出现次数;对于 $ h $ 数组外的部分只会出现一次,直接统计即可。

但是这样做显然是不优秀的。我们知道求 $ h $ 数组时我们利用了“原串相邻的两个后缀在对应位置的 $ h $ 的值,靠后位置的最多比靠前位置的 $ h $ 小 $ 1 $ ”这个结论,然而这对 $ h $ 数组是不适用的。这里, $ \sum_{i = 2} ^ n max(h_i – h_{i – 1}, 0) $ 是可以被卡到 $ O(N ^ 2) $ 级别的。卡法也很简单,我们构造一个类似于 abc…xyzabc… 的字符串就可以把 $ \sum_{i = 2} ^ n \max(h_i – h_{i – 1}, 0) $ 卡到 $ O(N |\alpha|) $ 级别,我们可以把几个字符看做一个字符,最终就能达到 $ O(N ^ 2) $ 的效果了。 于是看起来, $ T = 1 $ 时复杂度变为了 $ O(\mathrm{min}(K, N ^ 2) \log N) $ ,显然这个复杂度是无法接受的。

那么我们就弃疗了?不。我最开始想卡这个 $ O(\mathrm{min}(K, N ^ 2) \log N) $ 暴力……然而发现卡不掉,就是因为我在分析时没有意识到的情况下加了一个非常简单的优化。 考虑优化。事实上我们发现,对于一个当前一个新的长度为 $ \mathrm{Len} $ 二分的过程结束后会得到一个值 $ \mathrm{Max} $ ,表示出现了当前次数的字符串长度的最大值。这两个值表明,当前位置上字符串长度为 $ [\mathrm{Len},\mathrm{Max}] $ ,均出现了同样次数。那么我们将这一部分答案一并计算,最后将寻找下一个字符串从 $ \mathrm{Len}++ $ 变为 $ \mathrm{Len} = \mathrm{Max} + 1 $ 。 这么做,将这一步时间复杂度变为了 $ O(N \log N) $ ;将二分替换为单调栈,这一步的时间复杂度变为 $ O(N) $ 。 其实,这个复杂度证明也很简单……看起来就像最大矩形面积,如果出现次数变少,一定是之后某一个位置上的高度非常小“卡住了”。因为我们求的是本质不同的字符串数量,所以对于 $ h $ 数组的每个元素,最多会“卡”前面的连续段一次。这样如果采用二分自然就是 $ O(N \log N) $ 了。发现这个性质后其实可以用单调栈,每次弹出之前的数时,用链表/vector插入出现次数到前面所弹出的位置上;因为最多弹 $ O(N) $ 次,所以插入的数也只有 $ O(N) $ 个。 由于求后缀数组也是可以 $ O(N) $ 的,所以总复杂度也可以是 $ O(N) $ 的了。

这样,我们就在 $ O(N) $ 的时间复杂度内解决了本题;相比起后缀自动机,与字符集大小无关是这个做法更优越的一点。(由于我不会DC3我总复杂度还是只能 $ O(N \log N) $ TAT但是还是跑得速度不慢啦) 二分RMQ代码如下:(这道题在BZOJ上需要交换st数组下标来卡常TAT)

/* Never Say Die */
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <ctime>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <string>
#include <map>
#include <set>
#include <bitset>
using namespace std;
typedef long long LL;
typedef double D;
typedef pair<int, int> pii;
#define mp(a, b) make_pair(a, b)
#define fir first
#define sec second
inline int min(int a, int b) {return a < b ? a : b;}
inline int max(int a, int b) {return a > b ? a : b;}
int ch = 0;
template <class T> inline void read(T &a) {
    bool f = 0; a = 0;
    while (ch < '0' || ch > '9') {if (ch == '-') f = 1; ch = getchar();}
    while (ch >= '0' && ch <= '9') {a = a * 10 + ch - '0'; ch = getchar();}
    if (f) a = -a;
}
#define MAXN 500319
char s[MAXN];
int n, T, K;
void Sort(int in[], int out[], int p[], int m) {
    static int P[MAXN]; register int i;
    for (i = 1; i <= m; i++) P[i] = 0;
    for (i = 1; i <= n; i++) P[in[i]]++;
    for (i = 2; i <= m; i++) P[i] += P[i - 1];
    for (i = n; i; i--) out[P[in[p[i]]]--] = p[i];
}
int sa[MAXN], rk[MAXN], h[MAXN];
int Min[20][MAXN], Log[MAXN];
void getsa() {
    int m = 26;
    static int t1[MAXN], t2[MAXN], *x = t1, *y = t2;
    for (int i = 1; i <= n; i++) x[i] = s[i] - 'a' + 1, y[i] = i;
    Sort(x, sa, y, m);
    for (int j = 1, k = 0, i; k < n; m = k, j <<= 1) {
        for (k = 0, i = n - j + 1; i <= n; i++) y[++k] = i;
        for (i = 1; i <= n; i++) if (sa[i] > j) y[++k] = sa[i] - j;
        Sort(x, sa, y, m);
        for (swap(x, y), i = 2, k = x[sa[1]] = 1; i <= n; i++) {
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + j] == y[sa[i - 1] + j]) ? k : ++k;
        }
    }
    for (int i = 1; i <= n; i++) rk[sa[i]] = i;
    for (int i = 1, k = 0; i <= n; h[rk[i++]] = k) {
        k -= !!k;
        for (int j = sa[rk[i] - 1]; s[i + k] == s[j + k]; k++);
    }
    Log[0] = -1;
    for (int i = 1; i <= n; i++) Log[i] = Log[i >> 1] + 1;
    for (int i = 1; i <= n; i++) Min[0][i] = h[i];
    for (int k = 1; k <= Log[n]; k++) {
        for (int i = 1, d = 1 << (k - 1), j = n - (d << 1) + 1; i <= j; i++) {
            Min[k][i] = min(Min[k - 1][i], Min[k - 1][i + d]);
        }
    }
}
void work1() {
    int tot = 0;
    for (int i = 1, j; i <= n; i++) {
        for (j = h[i] + 1; j <= h[i + 1]; j++) {
            int x = i, mi = h[i + 1];
            for (int k = Log[n - i + 1]; ~k; k--) if (Min[k][x + 1] >= j) mi = min(mi, Min[k][x + 1]), x += 1 << k;
            int d = x - i + 1, cnt = mi - j + 1;
            if (K - tot - 1 <= (LL)cnt * d) {
                j += (K - tot - 1) / d;
                for (int l = 0; l < j; l++) putchar(s[sa[i] + l]);
                putchar('\n');
                return;
            }
            tot += d * cnt;
            j = mi;
        }
        int dev = n - sa[i] - j + 2;
        if (tot + dev >= K) {
            int d = int(K - tot + j - 1);
            for (j = 0; j < d; j++) putchar(s[sa[i] + j]);
            putchar('\n');
            return;
        }
        tot += dev;
    }
    puts(-1);
}
void work0() {
    int tot = 0;
    for (int i = 1; i <= n; i++) {
        int dev = n - sa[i] - h[i] + 1;
        if (tot + dev >= K) {
            int d = int(K - tot + h[i]);
            for (int j = 0; j < d; j++) putchar(s[sa[i] + j]);
            putchar('\n');
            return;
        }
        tot += dev;
    }
    puts(-1);
}
int main() {
    scanf(%s, s + 1);
    read(T); read(K);
    n = (int)strlen(s + 1);
    getsa();
    if (T) work1();
    else work0();
    return 0;
}

单调栈代码如下:

/* Never Say Die */
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <ctime>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <string>
#include <map>
#include <set>
#include <bitset>
#include <list>
using namespace std;
typedef long long LL;
typedef double D;
typedef pair<int, int> pii;
#define mp(a, b) make_pair(a, b)
#define fir first
#define sec second
inline int min(int a, int b) {return a < b ? a : b;}
inline int max(int a, int b) {return a > b ? a : b;}
int ch = 0;
template <class T> inline void read(T &a) {
    bool f = 0; a = 0;
    while (ch < '0' || ch > '9') {if (ch == '-') f = 1; ch = getchar();}
    while (ch >= '0' && ch <= '9') {a = a * 10 + ch - '0'; ch = getchar();}
    if (f) a = -a;
}
#define MAXN 500319
char s[MAXN];
int n, T, K;
void Sort(int in[], int out[], int p[], int m) {
    static int P[MAXN]; register int i;
    for (i = 1; i <= m; i++) P[i] = 0;
    for (i = 1; i <= n; i++) P[in[i]]++;
    for (i = 2; i <= m; i++) P[i] += P[i - 1];
    for (i = n; i; i--) out[P[in[p[i]]]--] = p[i];
}
int sa[MAXN], rk[MAXN], h[MAXN];
void getsa() {
    int m = 26;
    static int t1[MAXN], t2[MAXN], *x = t1, *y = t2;
    for (int i = 1; i <= n; i++) x[i] = s[i] - 'a' + 1, y[i] = i;
    Sort(x, sa, y, m);
    for (int j = 1, k = 0, i; k < n; m = k, j <<= 1) {
        for (k = 0, i = n - j + 1; i <= n; i++) y[++k] = i;
        for (i = 1; i <= n; i++) if (sa[i] > j) y[++k] = sa[i] - j;
        Sort(x, sa, y, m);
        for (swap(x, y), i = 2, k = x[sa[1]] = 1; i <= n; i++) {
            x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + j] == y[sa[i - 1] + j]) ? k : ++k;
        }
    }
    for (int i = 1; i <= n; i++) rk[sa[i]] = i;
    for (int i = 1, k = 0; i <= n; h[rk[i++]] = k) {
        k -= !!k;
        for (int j = sa[rk[i] - 1]; s[i + k] == s[j + k]; k++);
    }
}
struct Node {
    int fir, sec, nxt;
} a[MAXN];
int head[MAXN], acnt = 1, sta[MAXN], p[MAXN];
void work1() {
    int top = 0; sta[0] = -1;
    for (int i = 1; i <= n; i++) {
        int r = i;
        while (sta[top] >= h[i]) {
            r = min(r, p[top]);
            if (sta[top] > h[i]) a[++acnt] = (Node){sta[top], i - p[top], head[p[top]]}, head[p[top]] = acnt;
            top--;
        }
        sta[++top] = h[i], p[top] = r;
    }
    while (top && sta[top] > 0) {
        a[++acnt] = (Node){sta[top], n + 1 - p[top], head[p[top]]}, head[p[top]] = acnt;
        top--;
    }
    int tot = 0;
    for (int i = 1; i <= n; i++) {
        int j = h[i] + 1;
        for (int now = head[i + 1]; now; now = a[now].nxt) {
            int d = a[now].sec + 1, cnt = a[now].fir - j + 1;
            if (K - tot - 1 <= (LL) cnt * d) {
                j += (K - tot - 1) / d;
                for (int l = 0; l < j; l++) putchar(s[sa[i] + l]);
                putchar('\n');
                return;
            }
            tot += d * cnt;
            j += cnt;
        }
        int dev = n - sa[i] - j + 2;
        if (tot + dev >= K) {
            int d = int(K - tot + j - 1);
            for (j = 0; j < d; j++) putchar(s[sa[i] + j]);
            putchar('\n');
            return;
        }
        tot += dev;
    }
    puts(-1);
}
void work0() {
    int tot = 0;
    for (int i = 1; i <= n; i++) {
        int dev = n - sa[i] - h[i] + 1;
        if (tot + dev >= K) {
            int d = int(K - tot + h[i]);
            for (int j = 0; j < d; j++) putchar(s[sa[i] + j]);
            putchar('\n');
            return;
        }
        tot += dev;
    }
    puts(-1);
}
int main() {
    scanf(%s, s + 1);
    read(T); read(K);
    n = (int)strlen(s + 1);
    getsa();
    if (T) work1();
    else work0();
    return 0;
}

跑得比最快的大爷慢,在LOJ上交了一发发现预处理后缀数组用了80%的时间…TAT。