「NOI2016」优秀的拆分

题目描述

如果一个字符串可以被拆分为 \(\text{AABB}\) 的形式,其中 \(\text{A}\)\(\text{B}\) 是任意非空字符串,则我们称该字符串的这种拆分是优秀的。
例如,对于字符串 \(\text {aabaabaa}\) ,如果令 \(\text{A}=\texttt{aab}\)\(\text{B}=\texttt{a}\),我们就找到了这个字符串拆分成 \(\text{AABB}\) 的一种方式。

一个字符串可能没有优秀的拆分,也可能存在不止一种优秀的拆分。
比如我们令 \(\text{A}=\texttt{a}\)\(\text{B}=\texttt{baa}\),也可以用 \(\text{AABB}\) 表示出上述字符串;但是,字符串 \(\texttt{abaabaa}\) 就没有优秀的拆分。

现在给出一个长度为 \(n\) 的字符串 \(S\),我们需要求出,在它所有子串的所有拆分方式中,优秀拆分的总个数。这里的子串是指字符串中连续的一段。

以下事项需要注意:

  1. 出现在不同位置的相同子串,我们认为是不同的子串,它们的优秀拆分均会被记入答案。
  2. 在一个拆分中,允许出现 \(\text{A}=\text{B}\)。例如 \(\texttt{cccc}\) 存在拆分 \(\text{A}=\text{B}=\texttt{c}\)
  3. 字符串本身也是它的一个子串。

输入格式

每个输入文件包含多组数据。
输入文件的第一行只有一个整数 \(T\),表示数据的组数。
接下来 \(T\) 行,每行包含一个仅由英文小写字母构成的字符串 \(S\),意义如题所述。

输出格式

输出 \(T\) 行,每行包含一个整数,表示字符串 \(S\) 所有子串的所有拆分中,总共有多少个是优秀的拆分。

样例

样例输入

4
aabbbb
cccccc
aabaabaabaa
bbaabaababaaba

样例输出

3
5
4
7

样例解释

我们用 \(S[i, j]\) 表示字符串 \(S\)\(i\) 个字符到第 \(j\) 个字符的子串(从 \(1\) 开始计数)。

第一组数据中,共有三个子串存在优秀的拆分:
\(S[1,4]=\text{aabb}\),优秀的拆分为 \(\text{A}=\texttt{a}\)\(\text{B}=\texttt{b}\)
\(S[3,6]=\text{bbbb}\),优秀的拆分为 \(\text{A}=\texttt{b}\)\(\text{B}=\texttt{b}\)
\(S[1,6]=\text{aabbbb}\),优秀的拆分为 \(\text{A}=\texttt{a}\)\(\text{B}=\texttt{bb}\)
而剩下的子串不存在优秀的拆分,所以第一组数据的答案是 \(3\)

第二组数据中,有两类,总共四个子串存在优秀的拆分:
对于子串 \(S[1,4]=S[2,5]=S[3,6]=\text{cccc}\),它们优秀的拆分相同,均为 \(\text{A}=\texttt{c}\)\(\text{B}=\texttt{c}\),但由于这些子串位置不同,因此要计算三次;
对于子串 \(S[1,6]=\text{cccccc}\),它优秀的拆分有两种:\(\text{A}=\texttt{c}\)\(\text{B}=\texttt{cc}\)\(\text{A}=\texttt{cc}\)\(\text{B}=\texttt{c}\),它们是相同子串的不同拆分,也都要计入答案。
所以第二组数据的答案是 \(3+2=5\)

第三组数据中,\(S[1,8]\)\(S[4,11]\) 各有两种优秀的拆分,其中 \(S[1,8]\) 是问题描述中的例子,所以答案是 \(2+2=4\)

第四组数据中,\(S[1,4]\)\(S[6,11]\)\(S[7,12]\)\(S[2,11]\)\(S[1,8]\) 各有一种优秀的拆分,\(S[3,14]\) 有两种优秀的拆分,所以答案是 \(5+2=7\)

数据范围与提示

对于全部的测试点,\(1 \leq T \leq 10, \ n \leq 30000\)

题解

\(95\)分hash暴力真的就是随便写...
我们处理出\(a[i]\)\(b[i]\)表示以\(i\)为终点和起点的\(AA\)串的个数。那么答案即为\(\sum_{i=1}^{n-1}a[i]\times b[i + 1]\)。hash优化一下判定过程就是\(O(n^2)\)的。
\(100\)分不看题解真的没有什么思路(即使知道了这是一道后缀数组题...)
我们可以思考一下如何优化处理\(AA\)串的过程。
枚举\(A\)串的长度\(len\),然后对于相邻的两个长度间隔为\(len\)的点,如果他们的\(lcp(x,y)+lcs(x,y)\geq len\),那么中间则有一段长度为\(lcp+lcs-len+1\)的合法的\(AA\)串终点的区间。
为什么呢?可以通过把这句话画出来,比如这样:

那么中间那段红色的区域就是合法的终点区间。
\(lcp(x,y)\)\(lcs(x,y)\)可以直接用后缀数组来求。总复杂度为\(O(n \log n)\)
当然也可以用hash实现这个过程,复杂度就是\(O(n \log^2 n)\)的。

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const int N = 50010;

int n, a[N], b[N];
char s[N];

struct SA {
int sa[N], height[N], tong[N], rnk[N], tp[N], f[N][16], LG[N];
int m;
void radix_sort() {
    for(int i = 1; i <= m; ++i) tong[i] = 0;
    for(int i = 1; i <= n; ++i) tong[rnk[i]]++;
    for(int i = 1; i <= m; ++i) tong[i] += tong[i - 1];
    for(int i = n; i; --i) sa[tong[rnk[tp[i]]]--] = tp[i];
}
int query(int l, int r) {
    l = rnk[l], r = rnk[r];
    if(l > r) swap(l, r);  ++l;
    int k = LG[r - l + 1];
    return min(f[l][k], f[r - (1 << k) + 1][k]);
}
void init() {
    memset(sa, 0, sizeof(sa));
    memset(height, 0, sizeof(height));
    memset(tong, 0, sizeof(tong));
    memset(rnk, 0, sizeof(rnk));
    memset(tp, 0, sizeof(tp));
    memset(f, 0, sizeof(f));
    memset(LG, 0, sizeof(LG));
}
void build(char *A) {
    init();
    for(int i = 1; i <= n; ++i) rnk[i] = A[i], tp[i] = i;
    m = 200; radix_sort();
    for(int w = 1, p = 0; w <= n && p < n; m = p, w <<= 1) {
        p = 0;
        for(int i = 1; i <= w; ++i) tp[++p] = n - w + i;
        for(int i = 1; i <= n; ++i) if(sa[i] > w) tp[++p] = sa[i] - w;
        radix_sort(); swap(tp, rnk); rnk[sa[1]] = p = 1;
        for(int i = 2; i <= n; ++i)
            rnk[sa[i]] = (tp[sa[i]] == tp[sa[i - 1]] && tp[sa[i] + w] == tp[sa[i - 1] + w]) ? p : ++p;
    }
    for(int i = 1, k = 0; i <= n; ++i) {
        if(k) --k; int j = sa[rnk[i] - 1];
        while(A[i + k] == A[j + k] && i + k <= n && j + k <= n) ++k;
        height[rnk[i]] = k;
    }
    for(int i = 2; i <= n; ++i) LG[i] = LG[i >> 1] + 1;
    for(int i = 1; i <= n; ++i) f[i][0] = height[i];
    for(int j = 1; j <= 15; ++j)
        for(int i = 1; i + (1 << j) - 1 <= n; ++i) {
            f[i][j] = min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
        }
}
}A, B;

int main() {
    int T = 0; scanf("%d", &T); while(T--) {
        memset(a, 0, sizeof(a));
        memset(b, 0, sizeof(b));
        scanf("%s", s + 1); n = strlen(s + 1);
        A.build(s); reverse(s + 1, s + n + 1); B.build(s);
        for(int len = 1; len <= (n >> 1); ++len) {
            for(int i = len, j = i + len; j <= n; i += len, j += len) {
                int LCS = min(len - 1, B.query(n - i + 2, n - j + 2)), LCP = min(len, A.query(i, j));
                if(LCS + LCP >= len) {
                    int t = LCP + LCS - len + 1;
                    a[i - LCS]++; a[i - LCS + t]--;
                    b[j + LCP - t]++; b[j + LCP]--;
                }
            }
        }
        for(int i = 1; i <= n; ++i) a[i] += a[i - 1], b[i] += b[i - 1];
        ll ans = 0;
        for(int i = 1; i < n; ++i) ans += 1LL * b[i] * a[i + 1];
        printf("%lld\n", ans);
    }
    return 0;
}
12-14 12:26