题目地址:P4070 [SDOI2016]生成魔咒
相信看到题目之后很多人跟我的思路是一样的——
肯定要用 SA(P3809 【模板】后缀排序)
肯定要会求本质不同的子串个数(P2408 不同子串个数)
然后?就不会了......
瓶颈在哪儿?
你会发现每往后添加一个字符,整个 sa 数组只会插入一个数,要维护不难
但是 height 会无规律变化,这就导致无法高效维护
怎么办呢?
倒置字符串
我们将整个字符串倒置过来
显然本质不同的子串个数不会变化
而每往前添加一个字符串, height 的变化是 \(O(1)\) 的
那么,问题就变得简单很多了
具体实现请看代码注释
#include <bits/stdc++.h>
#define ll long long
#define si set<int>::iterator
using namespace std;
const int N = 1e5 + 6;
int n, m, a[N], b[N];
int sa[N], rk[N], tp[N], tx[N], he[N], st[N][20];
ll ans;
set<int> s;
inline void tsort() {//基数排序
for (int i = 1; i <= m; i++) tx[i] = 0;
for (int i = 1; i <= n; i++) ++tx[rk[i]];
for (int i = 1; i <= m; i++) tx[i] += tx[i-1];
for (int i = n; i; i--) sa[tx[rk[tp[i]]]--] = tp[i];
}
inline bool pd(int i, int w) {
return tp[sa[i-1]] == tp[sa[i]] && tp[sa[i-1]+w] == tp[sa[i]+w];
}
inline void SA() {//后缀数组板子
for (int i = 1; i <= n; i++) {
rk[i] = a[i] = lower_bound(b + 1, b + m + 1, a[i]) - b;
tp[i] = i;
}
tsort();
for (int w = 1, p = 0; p < n; m = p, w <<= 1) {
p = 0;
for (int i = 1; i <= w; i++) tp[++p] = n - w + i;
for (int i = 1; i <= n; i++)
if (sa[i] > w) tp[++p] = sa[i] - w;
tsort();
swap(rk, tp);
rk[sa[1]] = p = 1;
for (int i = 2; i <= n; i++)
rk[sa[i]] = pd(i, w) ? p : ++p;
}
int p = 0;
for (int i = 1; i <= n; i++) {
if (p) --p;
int j = sa[rk[i]-1];
while (a[i+p] == a[j+p]) ++p;
he[rk[i]] = p;
}
}
inline void ST() {//构造ST表
for (int i = 1; i <= n; i++) st[i][0] = he[i];
int w = log(n) / log(2);
for (int k = 1; k <= w; k++)
for (int i = 1; i <= n; i++) {
if (i + (1 << k) > n + 1) break;
st[i][k] = min(st[i][k-1], st[i+(1<<(k-1))][k-1]);
}
}
inline int get(int l, int r) {//求l~r之间的最小值(即l-1与r的lcp)
int k = log(r - l + 1) / log(2);
return min(st[l][k], st[r-(1<<k)+1][k]);
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
b[i] = a[i];
}
//离散化
sort(b + 1, b + n + 1);
m = unique(b + 1, b + n + 1) - (b + 1);
reverse(a + 1, a + n + 1);//倒置字符串
SA();//求sa,rk,height数组
ST();//ST表
for (int i = n; i; i--) {//倒序考虑
s.insert(rk[i]);//以rk为关键字插入set
si it = s.find(rk[i]);//找到插入的位置
int k = 0;//存最长lcp
if (it != s.begin()) {//找前驱,注意特判
int p = *(--it);
k = get(p + 1, rk[i]);
++it;
}
++it;
if (it != s.end()) {//找后继,注意特判
int p = *it;
k = max(k, get(rk[i] + 1, p));
}
ans += n + 1 - i - k;//加上新生成的子串
printf("%lld\n", ans);
}
return 0;
}