题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3160
求出关于一个位置有多少对对称字母,如果 i 位置有 f[i] 对,对答案的贡献是 2^f[i] - 1;
然后减去连续的,用 manachar 求出回文长度,每个位置作为边界都是一种不合法情况;
求对称,首先把字符串中间穿插字符 '$',于是字符串的长度变成2倍;
考虑一对字母 s[x],s[y],如果 s[x] = s[y],其对称中心是 (x+y)/2;
放在加入字符后的字符串中,对称中心就是 x+y;
所以可以看出卷积了:f[i] = ∑(0<=j<=i) (s[j]==s[i-j]),其中 i 视为新字符串中的位置,j 和 i-j 视为原字符串中的位置;
注意卷积和 manachar 算的个数都要包括自己成对,否则判断挺麻烦...
这里卷积的两个多项式其实是一样的,所以只要用 FFT 算出一个,然后自己乘起来即可;
做下一步的时候注意清空,别忘了清空 n~lim 部分的值;
处理 bin 的边界是 n 而非 n-1,因为最多可能有 n 对。
(学习了 manachar 的简洁写法)
代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef double db;
int const xn=(<<),mod=1e9+;
db const Pi=acos(-1.0);
int n,rev[xn],lim=,l,len[xn],bin[xn],c[xn];
char ch[xn];
struct com{db x,y;}a[xn],b[xn],aa[xn];
com operator + (com a,com b){return (com){a.x+b.x,a.y+b.y};}
com operator - (com a,com b){return (com){a.x-b.x,a.y-b.y};}
com operator * (com a,com b){return (com){a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y};}
int upt(int x){while(x>=mod)x-=mod; while(x<)x+=mod; return x;}
void fft(com *a,int tp)
{
for(int i=;i<lim;i++)
if(i<rev[i])swap(a[i],a[rev[i]]);
for(int mid=;mid<lim;mid<<=)
{
com wn=(com){cos(Pi/mid),tp*sin(Pi/mid)};
for(int j=,len=(mid<<);j<lim;j+=len)
{
com w=(com){,};
for(int k=;k<mid;k++,w=w*wn)
{
com x=a[j+k],y=w*a[j+mid+k];
a[j+k]=x+y; a[j+mid+k]=x-y;
}
}
}
}
void solve()
{
for(int i=;i<n;i++)a[i].x=(ch[i]=='a');
fft(a,);
for(int i=;i<lim;i++)b[i]=a[i]*a[i];
for(int i=;i<n;i++)a[i].x=(ch[i]=='b'),a[i].y=;//y=0
for(int i=n;i<lim;i++)a[i].x=,a[i].y=;//!!
fft(a,);
for(int i=;i<lim;i++)b[i]=b[i]+a[i]*a[i];
fft(b,-);
for(int i=;i<n+n;i++)c[i]=(c[i]+(int)(b[i].x/lim+0.5))%mod;
}
char s[xn];
int manachar()//+i self
{
int mx=,id=,ret=; s[]='$';
for(int i=;i<=n+n;i++)
if(i%==)s[i]='$';
else s[i]=ch[i>>];
for(int i=;i<=n+n;i++)
{
if(i<mx)len[i]=min(mx-i,len[id*-i]);
while(i-len[i]>=&&i+len[i]<=n+n&&s[i-len[i]]==s[i+len[i]])len[i]++;
if(i+len[i]>mx)mx=i+len[i],id=i;
ret=upt(ret+len[i]/);
}
return ret;
}
int main()
{
scanf("%s",ch); n=strlen(ch);
while(lim<=n+n)lim<<=,l++;//
for(int i=;i<lim;i++)
rev[i]=((rev[i>>]>>)|((i&)<<(l-)));
bin[]=;
for(int i=;i<=n;i++)bin[i]=upt(bin[i-]+bin[i-]);
solve();
int ans=;
for(int i=;i<n+n;i++)ans=upt(ans+bin[(c[i]+)>>]-);//+1 -1
printf("%d\n",upt(ans-manachar()));
return ;
}