FFT:
没啥好说的吧。。
证明应该都会,写的时候记住两个点就行:
1.怎么定义复数?千万别写成
complex<double> w=(1,0);
可以自己试一下这样输出什么东西……
2.枚举len,遍历前一半,用原来的$a_{i},a_{i+len/2}$值更新新的$a_{i},a_{i+len/2}$值。如果你想写递归版的就当没看见这句话。
强烈推荐伟大的石神给出的里程碑式的证明!
代码:
#include<bits/stdc++.h> #define maxn 3000005 #define maxm 500005 #define inf 0x7fffffff #define ll long long #define cp complex<double> #define pi acos(-1) using namespace std; cp a[maxn],b[maxn],c[maxn]; int ind[maxn]; inline int read(){ int x=0,f=1; char c=getchar(); for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; } inline void fft(cp *t,int n,int op){ for(int i=0;i<n;i++) ind[i]=(i&1)?((ind[i>>1]>>1)|(n>>1)):(ind[i>>1]>>1); for(int i=0;i<n;i++) if(ind[i]<i) swap(t[i],t[ind[i]]); for(int len=2;len<=n;len<<=1) for(int j=0;j<n;j+=len){ cp w(1,0),p(cos(op*2*pi/len),sin(op*2*pi/len)); for(int k=0;k<(len>>1);k++,w*=p){ cp t1=t[j+k+(len>>1)],t2=t[j+k]; t[j+k+(len>>1)]=t2-t1*w,t[j+k]=t2+t1*w; } } return; } int main(){ int n=read(),m=read(); for(int i=0;i<=n;i++) a[i]=read(); for(int i=0;i<=m;i++) b[i]=read(); int len=1; while(len<=n+m) len<<=1; fft(a,len,1),fft(b,len,1); for(int i=0;i<len;i++) c[i]=a[i]*b[i]; fft(c,len,-1); for(int i=0;i<=n+m;i++) cout<<(int)(c[i].real()/len+0.5)<<" "; cout<<endl; return 0; }
NTT:
原根的性质与单位根一模一样。只是$IDFT$的时候得求个逆元。
题里有模数就用$NTT$,没有模数也尽量用$NTT$,避免出现时间精度双爆炸的惨剧。
注意$w_{n}^{k}=g^{\frac{mod-1}{n}\cdot k}$,证明时将k分离出来考虑即可。
伟大的石神在刚才那篇里也更了$NTT$!
代码:
#include<bits/stdc++.h> #define maxn 3000005 #define maxm 500005 #define inf 0x7fffffff #define ll long long #define mod 998244353 #define g 3 using namespace std; ll a[maxn],b[maxn],res[maxn],ind[maxn]; inline ll read(){ ll x=0,f=1; char c=getchar(); for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; } inline ll power(ll a,ll b){ll ans=1;while(b)ans=(b&1)?ans*a%mod:ans,a=a*a%mod,b>>=1;return ans;} inline void ntt(ll *t,ll n,ll op){ for(ll i=0;i<n;i++) ind[i]=(i&1)?((ind[i>>1]>>1)|(n>>1)):(ind[i>>1]>>1); for(ll i=0;i<n;i++) if(ind[i]<i) swap(t[i],t[ind[i]]); for(ll len=1;len<=n;len<<=1){ ll p=power(g,(mod-1)/len); if(op==-1) p=power(p,mod-2); for(ll i=0;i<n;i+=len) for(ll j=i,w=1,tp;j<i+(len>>1);j++,w=w*p%mod) tp=t[j+(len>>1)]*w%mod,t[j+(len>>1)]=(t[j]-tp+mod)%mod,t[j]=(t[j]+tp)%mod; } return; } int main(){ ll n=read(),m=read(); for(ll i=0;i<=n;i++) a[i]=read(); for(ll i=0;i<=m;i++) b[i]=read(); ll len=1; while(len<=n+m) len<<=1; ntt(a,len,1),ntt(b,len,1); for(ll i=0;i<len;i++) res[i]=a[i]*b[i]; ntt(res,len,-1); ll mo=power(len,mod-2); for(ll i=0;i<=n+m;i++) cout<<res[i]*mo%mod<<" "; cout<<endl; return 0; }