Colorful String

 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int maxn = 1e6+5;
 5 char s[maxn];
 6 int n;
 7 int record[maxn];  // 记录i结点在原字符串的位置
 8 int sum[maxn*4];
 9 void pushup(int rt) {
10     sum[rt] |= sum[rt*2];
11     sum[rt] |= sum[rt*2+1];
12 }
13 void build(int l, int r, int rt) {
14     if (l == r) {
15         sum[rt] |= 1<<(s[l]-'a');
16         return;
17     }
18     int mid = (l+r)/2;
19     build(l,mid,rt*2);
20     build(mid+1,r,rt*2+1);
21     pushup(rt);
22 }
23 int query(int be, int ed, int l, int r, int rt) {
24     if (be <= l && r <= ed) {
25         return sum[rt];
26     }
27     int mid = (l+r)/2, res = 0;
28     if (be <= mid) res |= query(be,ed,l,mid,rt*2);
29     if (ed > mid) res |= query(be,ed,mid+1,r,rt*2+1);
30     return res;
31 }
32
33 struct PAM {
34     int last;
35     struct Node {
36         ll cnt, len, fail, son[27];  // cnt为以i为结尾的回文子串个数,len为长度
37         Node(int len, int fail) : len(len), fail(fail), cnt(0){
38             memset(son, 0, sizeof(son));
39         };
40     };
41     vector<Node> st;
42     inline int newnode(int len, int fail = 0) {
43         st.emplace_back(len, fail);
44         return st.size()-1;
45     }
46     inline int getfail(int x, int n) {
47         while (s[n-st[x].len-1] != s[n]) x = st[x].fail;
48         return x;
49     }
50     inline void extend(int c, int i) {
51         int cur = getfail(last, i);
52         if (!st[cur].son[c]) {
53             int nw = newnode(st[cur].len+2, st[getfail(st[cur].fail, i)].son[c]);
54             st[cur].son[c] = nw;
55         }
56         st[ last=st[cur].son[c] ].cnt++;
57         record[last] = i;
58     }
59     void init() {
60         scanf("%s", s+1);
61         n = strlen(s+1);
62         s[0] = 0;
63         newnode(0, 1), newnode(-1);
64         last = 0;
65         for (int i = 1; i <= n; i++)
66             extend(s[i]-'a', i);
67     }
68     ll count() {
69         for (int i = st.size()-1; i >= 0; i--)
70             st[st[i].fail].cnt += st[i].cnt;
71
72         ll ans = n;
73         for (int i = 2; i <= st.size()-1; i++) {
74             if (st[i].len <= 1) continue;
75
76             int L = record[i]-st[i].len+1, R = record[i];
77             int res = query(L,R,1,n,1);
78             int num = 0;
79             while (res) {
80                 if (res&1) num++;
81                 res >>= 1;
82             }
83             ans += st[i].cnt*num;
84         }
85         return ans;
86     }
87 }pam;
88 int main() {
89     pam.init();
90     build(1,n,1);
91     printf("%lld\n",pam.count());
92     return 0;
93 }
01-07 06:36