题目:http://codeforces.com/contest/438/problem/E
https://www.lydsy.com/JudgeOnline/problem.php?id=3625
多项式开方...
注意传进 sqt 中的模数应该是2的整数次幂,所以先补到 >=m ;
还要注意每次一定要先递归或进入别的子函数,再算 rev 数组,否则会被覆盖!
最重要的是 lim < n+n 而不是 <= ,否则会把数组撑大一倍(于是 (1<<18) 会RE),而如果真的把数组开到 (1<<19),又会因为进行 NTT 的长度变成两倍而(在bzoj上) TLE ...
想想,因为一开始已经是 lim <= m,所以 lim 一定是偏大的,也就是传进去的 n 并不是顶到的上界,也就不用必须 <= ;
而传进去的 n 本身是一个2的整数次幂,所以 <= 会纯粹的增大一倍;(所以直接写成 n>>1 就好了)
注意细节啊...
代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const xn=(<<)+,mod=,g=;
int n,m,c[xn],t[xn],tt[xn],rev[xn],sc[xn],inv2,f[xn];
int rd()
{
int ret=,f=; char ch=getchar();
while(ch<''||ch>''){if(ch=='-')f=; ch=getchar();}
while(ch>=''&&ch<='')ret=(ret<<)+(ret<<)+ch-'',ch=getchar();
return f?ret:-ret;
}
ll pw(ll a,int b)
{
ll ret=;
for(;b;b>>=,a=(a*a)%mod)if(b&)ret=(ret*a)%mod;
return ret;
}
int upt(int x){while(x>=mod)x-=mod; while(x<)x+=mod; return x;}
void ntt(int *a,int tp,int lim)
{
for(int i=;i<lim;i++)
if(i<rev[i])swap(a[i],a[rev[i]]);
for(int mid=;mid<lim;mid<<=)
{
int wn=pw(g,(mod-)/(mid<<));
if(tp==-)wn=pw(wn,mod-);
for(int j=,len=(mid<<);j<lim;j+=len)
{
int w=;
for(int k=;k<mid;k++,w=(ll)w*wn%mod)
{
int x=a[j+k],y=(ll)w*a[j+mid+k]%mod;
a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y);
}
}
}
if(tp==)return; int inv=pw(lim,mod-);
for(int i=;i<lim;i++)a[i]=(ll)a[i]*inv%mod;
}
void inv(int *a,int *b,int n)
{
if(n==){b[]=pw(a[],mod-); return;}
inv(a,b,n>>);
int lim=,l=;
while(lim<n+n)lim<<=,l++;//<= (1<<19) TLE
for(int i=;i<lim;i++)rev[i]=((rev[i>>]>>)|((i&)<<(l-)));//after inv!!!
for(int i=;i<n;i++)tt[i]=a[i];
for(int i=n;i<lim;i++)tt[i]=;
ntt(tt,,lim); ntt(b,,lim);
for(int i=;i<lim;i++)b[i]=upt(((ll)-(ll)tt[i]*b[i])%mod*b[i]%mod);
ntt(b,-,lim);
for(int i=n;i<lim;i++)b[i]=;
}
void sqt(int *a,int *b,int n)
{
if(n==){b[]=; return;}
sqt(a,b,n>>);
int lim=,l=;
while(lim<n+n)lim<<=,l++;//<= (1<<19) TLE
for(int i=;i<lim;i++)t[i]=;
inv(b,t,n);
for(int i=;i<lim;i++)rev[i]=((rev[i>>]>>)|((i&)<<(l-)));//after inv!!!
for(int i=;i<n;i++)tt[i]=a[i];
for(int i=n;i<lim;i++)tt[i]=;
ntt(b,,lim); ntt(tt,,lim); ntt(t,,lim);
for(int i=;i<lim;i++)b[i]=((ll)b[i]+(ll)tt[i]*t[i])%mod*inv2%mod;
ntt(b,-,lim);
for(int i=n;i<lim;i++)b[i]=;
}
int main()
{
n=rd(); m=rd(); inv2=pw(,mod-);
for(int i=,x;i<=n;i++)x=rd(),c[x]++;
int lim=; while(lim<=m)lim<<=;//m
for(int i=;i<lim;i++)c[i]=((-(ll)*c[i])%mod+mod)%mod;//!1-4*c[i]! //(ll)!!
c[]=;//1+!!
sqt(c,sc,lim);//lim
sc[]++; sc[]=upt(sc[]);
inv(sc,f,lim);
for(int i=;i<=m;i++)printf("%d\n",upt(f[i]<<));
return ;
}