Description
Input
Output
Sample Input
Sample Output
HINT
答案 = 回文子序列（位置关于某一对称轴对称且对应字符相同）的总数 − 回文子串的个数
对于每个对称轴，只需统计关于它对称且字符相同的位置对数 $k$（若轴落在某个位置上，该中心位置也计入 $k$），该轴贡献 $2^k-1$ 个非空选取方案，对所有对称轴求和即得总数
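先把 $a$ 置为 $1$、$b$ 置为 $0$ 做一次自卷积，再把 $a$ 置为 $0$、$b$ 置为 $1$ 做一次自卷积
两次结果相加后，第 $s$ 项等于满足 $i+j=s$ 且字符相同的有序对 $(i,j)$ 个数：$i\ne j$ 的对被算了两次，$i=j$ 只算一次；因此在每个偶数下标 $s=2i$ 处加 $1$ 再除以 $2$，即得到该对称轴的 $k$
回文子串个数用 Manacher（马拉车）算法求出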
#include <bits/stdc++.h>
using namespace std;

// Counts the non-empty position sets of a string over {a, b} that are
// symmetric about one axis with equal characters at mirrored positions, but
// are NOT contiguous substrings. The result is taken modulo 1e9+7.
//
//   answer = sum over axes s of (2^{k_s} - 1)  -  (#palindromic substrings)
//
// k_s is the number of same-character pairs {i, j} with i + j = s (i <= j,
// so a centre position counts once). k_s is obtained with an FFT
// self-convolution. The palindromic substrings are counted with Manacher.

const int mod = 1e9 + 7;
const double PI = acos(-1.0);

// Minimal complex number used by the FFT below.
struct comp {
    double x, y;
    comp(double x_ = 0, double y_ = 0) : x(x_), y(y_) {}
    friend comp operator*(const comp &a, const comp &b) {
        return comp(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
    }
    friend comp operator+(const comp &a, const comp &b) {
        return comp(a.x + b.x, a.y + b.y);
    }
    friend comp operator-(const comp &a, const comp &b) {
        return comp(a.x - b.x, a.y - b.y);
    }
};

// Bit-reversal permutation for the current transform length (filled by init).
vector<int> rev;

// In-place iterative radix-2 FFT over A[0, len).
// kind = 1 is the forward transform; kind = -1 is the inverse. After the
// inverse, only the real parts are divided by len, because only .x is read
// afterwards.
void dft(vector<comp> &A, int len, int kind) {
    for (int i = 0; i < len; i++)
        if (i < rev[i]) swap(A[i], A[rev[i]]);
    for (int i = 1; i < len; i <<= 1) {
        const comp wn(cos(PI / i), kind * sin(PI / i));
        for (int j = 0; j < len; j += (i << 1)) {
            comp w(1, 0);
            for (int k = 0; k < i; k++) {
                const comp s = A[j + k];
                const comp t = w * A[i + j + k];
                A[j + k] = s + t;
                A[i + j + k] = s - t;
                w = w * wn;
            }
        }
    }
    if (kind == -1)
        for (int i = 0; i < len; i++) A[i].x /= len;
}

// Sets len to the smallest power of two >= 2n - 1, the length of a
// self-convolution of n terms, and fills rev accordingly.
// The loop starts at i = 1, so L - 1 >= 0 whenever it runs. The old code
// evaluated `<< L - 1` with L == 0 for n <= 1, which is undefined behavior.
void init(int &len, int n) {
    int L = 0;
    for (len = 1; len < n + n - 1; len <<= 1, L++);
    rev.assign(len, 0);
    for (int i = 1; i < len; i++)
        rev[i] = rev[i >> 1] >> 1 | (i & 1) << (L - 1);
}

int main() {
    // Read into std::string so no fixed-size buffer can overflow.
    string s;
    if (!(cin >> s)) {
        printf("0\n");
        return 0;
    }
    const int n = s.size();

    int len;
    init(len, n);

    // f: indicator of 'a'; g: indicator of 'b'.
    vector<comp> f(len), g(len);
    for (int i = 0; i < n; i++) {
        if (s[i] == 'a') f[i].x = 1;
        if (s[i] == 'b') g[i].x = 1;
    }
    dft(f, len, 1);
    dft(g, len, 1);
    // After the inverse, coefficient t of f*f + g*g is the number of ordered
    // pairs (i, j) with i + j = t and s[i] == s[j].
    for (int i = 0; i < len; i++) g[i] = g[i] * g[i] + f[i] * f[i];
    dft(g, len, -1);

    // Pairs with i != j were counted twice, but i == j only once. Adding 1 on
    // every even axis makes (count >> 1) equal to the unordered pairs plus
    // the centre.
    for (int i = 0; i < n; i++) g[i << 1].x += 1;

    // Powers of two mod p. k never exceeds (n + 1) / 2, so n + 2 entries suffice.
    vector<int> pw(n + 2);
    pw[0] = 1;
    for (size_t i = 1; i < pw.size(); i++) pw[i] = pw[i - 1] * 2 % mod;

    long long ans = 0;
    for (int i = 0; i < 2 * n - 1; i++) {  // higher coefficients are zero
        const int k = (int)(g[i].x + 0.5) >> 1;
        ans = (ans + pw[k] - 1) % mod;
    }

    // Manacher on t = "~$c0$c1$...$c{n-1}$" followed by a '\0' sentinel.
    // '~' and '\0' never match each other or the interior characters, so the
    // expansion loop stops without bounds checks.
    string t = "~$";
    t.reserve(2 * n + 3);
    for (char c : s) {
        t += c;
        t += '$';
    }
    const int tlen = t.size();  // 2n + 2
    t += '\0';

    // p[i] is the radius (centre included) of the longest palindrome in t
    // centred at i. Its original-string palindromes number p[i] >> 1.
    vector<int> p(tlen, 0);
    int mr = 0, mid = 0;  // rightmost covered index and the centre that reached it
    for (int i = 1; i < tlen; i++) {
        if (i <= mr) p[i] = min(mr - i + 1, p[2 * mid - i]);
        while (t[i + p[i]] == t[i - p[i]]) p[i]++;
        if (i + p[i] > mr) {
            mr = i + p[i] - 1;
            mid = i;
        }
        ans = (ans - (p[i] >> 1) + mod) % mod;
    }
    printf("%lld\n", ans);
    return 0;
}