FFT:

没啥好说的吧。。

证明应该都会,写的时候记住两个点就行:

1.怎么定义复数?千万别写成

complex<double> w=(1,0);

可以自己试一下这样输出什么东西……

2.枚举len,遍历前一半,用原来的$a_{i},a_{i+len/2}$值更新新的$a_{i},a_{i+len/2}$值。如果你想写递归版的就当没看见这句话。

强烈推荐伟大的石神给出的里程碑式的证明

代码:

#include<bits/stdc++.h>
#define maxn 3000005
#define maxm 500005
#define inf 0x7fffffff
#define ll long long
#define cp complex<double>
#define pi acos(-1)

using namespace std;
cp a[maxn],b[maxn],c[maxn];
int ind[maxn];

inline int read(){
    int x=0,f=1; char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    return x*f;
}

inline void fft(cp *t,int n,int op){
    for(int i=0;i<n;i++) ind[i]=(i&1)?((ind[i>>1]>>1)|(n>>1)):(ind[i>>1]>>1);
    for(int i=0;i<n;i++) if(ind[i]<i) swap(t[i],t[ind[i]]);
    for(int len=2;len<=n;len<<=1)
        for(int j=0;j<n;j+=len){
            cp w(1,0),p(cos(op*2*pi/len),sin(op*2*pi/len));
            for(int k=0;k<(len>>1);k++,w*=p){
                cp t1=t[j+k+(len>>1)],t2=t[j+k];
                t[j+k+(len>>1)]=t2-t1*w,t[j+k]=t2+t1*w;
            }
        }
    return;
}

int main(){
    int n=read(),m=read();
    for(int i=0;i<=n;i++) a[i]=read();
    for(int i=0;i<=m;i++) b[i]=read();
    int len=1; while(len<=n+m) len<<=1;
    fft(a,len,1),fft(b,len,1);
    for(int i=0;i<len;i++) c[i]=a[i]*b[i];
    fft(c,len,-1);
    for(int i=0;i<=n+m;i++) cout<<(int)(c[i].real()/len+0.5)<<" ";
    cout<<endl;
    return 0;
}
FFT

NTT:

原根的性质与单位根一模一样。只是$IDFT$的时候得求个逆元。

题里有模数就用$NTT$,没有模数也尽量用$NTT$,避免出现时间精度双爆炸的惨剧。

注意$w_{n}^{k}=g^{\frac{mod-1}{n}\cdot k}$,证明时将k分离出来考虑即可。

伟大的石神在刚才那篇里也更了$NTT$!

代码:

#include<bits/stdc++.h>
#define maxn 3000005
#define maxm 500005
#define inf 0x7fffffff
#define ll long long
#define mod 998244353
#define g 3

using namespace std;
ll a[maxn],b[maxn],res[maxn],ind[maxn];

inline ll read(){
    ll x=0,f=1; char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c=='-') f=-1;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    return x*f;
}

inline ll power(ll a,ll b){ll ans=1;while(b)ans=(b&1)?ans*a%mod:ans,a=a*a%mod,b>>=1;return ans;}
inline void ntt(ll *t,ll n,ll op){
    for(ll i=0;i<n;i++) ind[i]=(i&1)?((ind[i>>1]>>1)|(n>>1)):(ind[i>>1]>>1);
    for(ll i=0;i<n;i++) if(ind[i]<i) swap(t[i],t[ind[i]]);
    for(ll len=1;len<=n;len<<=1){
        ll p=power(g,(mod-1)/len);
        if(op==-1) p=power(p,mod-2);
        for(ll i=0;i<n;i+=len)
            for(ll j=i,w=1,tp;j<i+(len>>1);j++,w=w*p%mod)
                tp=t[j+(len>>1)]*w%mod,t[j+(len>>1)]=(t[j]-tp+mod)%mod,t[j]=(t[j]+tp)%mod;
    }
    return;
}

int main(){
    ll n=read(),m=read();
    for(ll i=0;i<=n;i++) a[i]=read();
    for(ll i=0;i<=m;i++) b[i]=read();
    ll len=1; while(len<=n+m) len<<=1;
    ntt(a,len,1),ntt(b,len,1);
    for(ll i=0;i<len;i++) res[i]=a[i]*b[i];
    ntt(res,len,-1);
    ll mo=power(len,mod-2);
    for(ll i=0;i<=n+m;i++) cout<<res[i]*mo%mod<<" ";
    cout<<endl;
    return 0;
}
NTT
01-08 11:19