题目地址:P4070 [SDOI2016]生成魔咒

相信看到题目之后很多人跟我的思路是一样的——

肯定要用 SAP3809 【模板】后缀排序

肯定要会求本质不同的子串个数P2408 不同子串个数

然后?就不会了......

瓶颈在哪儿?

你会发现每往后添加一个字符,整个 sa 数组只会插入一个数,要维护不难

但是 height无规律变化,这就导致无法高效维护

怎么办呢?

倒置字符串

我们将整个字符串倒置过来

显然本质不同的子串个数不会变化

而每往前添加一个字符串, height 的变化是 \(O(1)\) 的

那么,问题就变得简单很多了

具体实现请看代码注释

#include <bits/stdc++.h>
#define ll long long
#define si set<int>::iterator
using namespace std;
const int N = 1e5 + 6;
int n, m, a[N], b[N];
int sa[N], rk[N], tp[N], tx[N], he[N], st[N][20];
ll ans;
set<int> s;

inline void tsort() {//基数排序
    for (int i = 1; i <= m; i++) tx[i] = 0;
    for (int i = 1; i <= n; i++) ++tx[rk[i]];
    for (int i = 1; i <= m; i++) tx[i] += tx[i-1];
    for (int i = n; i; i--) sa[tx[rk[tp[i]]]--] = tp[i];
}

inline bool pd(int i, int w) {
    return tp[sa[i-1]] == tp[sa[i]] && tp[sa[i-1]+w] == tp[sa[i]+w];
}

inline void SA() {//后缀数组板子
    for (int i = 1; i <= n; i++) {
        rk[i] = a[i] = lower_bound(b + 1, b + m + 1, a[i]) - b;
        tp[i] = i;
    }
    tsort();
    for (int w = 1, p = 0; p < n; m = p, w <<= 1) {
        p = 0;
        for (int i = 1; i <= w; i++) tp[++p] = n - w + i;
        for (int i = 1; i <= n; i++)
            if (sa[i] > w) tp[++p] = sa[i] - w;
        tsort();
        swap(rk, tp);
        rk[sa[1]] = p = 1;
        for (int i = 2; i <= n; i++)
            rk[sa[i]] = pd(i, w) ? p : ++p;
    }
    int p = 0;
    for (int i = 1; i <= n; i++) {
        if (p) --p;
        int j = sa[rk[i]-1];
        while (a[i+p] == a[j+p]) ++p;
        he[rk[i]] = p;
    }
}

inline void ST() {//构造ST表
    for (int i = 1; i <= n; i++) st[i][0] = he[i];
    int w = log(n) / log(2);
    for (int k = 1; k <= w; k++)
        for (int i = 1; i <= n; i++) {
            if (i + (1 << k) > n + 1) break;
            st[i][k] = min(st[i][k-1], st[i+(1<<(k-1))][k-1]);
        }
}

inline int get(int l, int r) {//求l~r之间的最小值(即l-1与r的lcp)
    int k = log(r - l + 1) / log(2);
    return min(st[l][k], st[r-(1<<k)+1][k]);
}

int main() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        b[i] = a[i];
    }
    //离散化
    sort(b + 1, b + n + 1);
    m = unique(b + 1, b + n + 1) - (b + 1);
    reverse(a + 1, a + n + 1);//倒置字符串
    SA();//求sa,rk,height数组
    ST();//ST表
    for (int i = n; i; i--) {//倒序考虑
        s.insert(rk[i]);//以rk为关键字插入set
        si it = s.find(rk[i]);//找到插入的位置
        int k = 0;//存最长lcp
        if (it != s.begin()) {//找前驱,注意特判
            int p = *(--it);
            k = get(p + 1, rk[i]);
            ++it;
        }
        ++it;
        if (it != s.end()) {//找后继,注意特判
            int p = *it;
            k = max(k, get(rk[i] + 1, p));
        }
        ans += n + 1 - i - k;//加上新生成的子串
        printf("%lld\n", ans);
    }
    return 0;
}
04-30 21:29