题目大意
小Q发明了一种进位制,每一位的变化范围是\(0\)~\(b_i-1\),给你一个这种进位制下的整数\(a\),问你有多少非负整数小于\(a\)。结果以十进制表示。
\(n\leq 120000,0\leq a_i<b_i\leq 1000000\)
题解
就是求这个数。
那没什么好说的,直接分治FFT
处理左半边(低位)的\(c_1=\prod b_i\)和答案\(d_1\),右半边的\(c2,d2\)
那么\(c=c_1\times c_2,d=d_2\times c_1+d_1\)
时间复杂度:\(O(n\log^2 n)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
//typedef long double ld;
typedef double ld;
//const ld pi=3.1415926535897932384626433832L;
const ld pi=acos(ld(-1));
struct cp
{
ld x,y;
cp(ld _x=0,ld _y=0)
{
x=_x;
y=_y;
}
};
cp conj(cp &a){return cp(a.x,-a.y);}
cp operator +(cp &a,cp &b){return cp(a.x+b.x,a.y+b.y);}
cp operator -(cp &a,cp &b){return cp(a.x-b.x,a.y-b.y);}
cp operator *(cp &a,cp &b){return cp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
cp operator /(cp &a,ld b){return cp(a.x/b,a.y/b);}
cp a1[500010];
cp a2[500010];
cp a3[500010];
cp w1[500010];
cp w2[500010];
int rev[500010];
int N;
namespace fft
{
void get(int n)
{
N=1;
while(N<n)
N<<=1;
int i;
for(i=2;i<=N;i<<=1)
{
w1[i]=cp(cos(ld(2*pi/i)),sin(ld(2*pi/i)));
w2[i]=conj(w1[i]);
}
for(i=0;i<N;i++)
rev[i]=(rev[i>>1]>>1)|(i&1?(N>>1):0);
}
void fft(cp *a,int t)
{
int i,j,k;
cp w,wn,u,v;
for(i=0;i<N;i++)
if(rev[i]<i)
swap(a[i],a[rev[i]]);
for(i=2;i<=N;i<<=1)
{
wn=t?w1[i]:w2[i];
for(j=0;j<N;j+=i)
{
w=cp(1);
for(k=j;k<j+i/2;k++)
{
u=a[k];
v=a[k+i/2]*w;
a[k]=u+v;
a[k+i/2]=u-v;
w=w*wn;
}
}
}
if(!t)
for(i=0;i<N;i++)
a[i]=a[i]/N;
}
}
ll a[500010];
ll b[500010];
ll c[500010];
ll d[500010];//答案
ll e[500010];
ll f[500010];
const ll A=1000;
const ll B=1000000;
void cheng(ll *a1,int n1,ll *a2,int n2,ll *a3,int n3)
{
int i,j;
for(i=0;i<n3;i++)
a3[i]=0;
for(i=0;i<n1;i++)
for(j=0;j<n2;j++)
a3[i+j]+=a1[i]*a2[j];
}
void clear(ll *a,int n)
{
int i;
for(i=0;i<n;i++)
a[i]=0;
}
void cheng(ll *a1,int n1,ll a2)
{
int i;
for(i=0;i<n1;i++)
a1[i]*=a2;
}
void jia(ll *a1,int n1,ll *a2,int n2,ll *a3,int n3)
{
int i;
for(i=0;i<n3;i++)
a3[i]=0;
for(i=0;i<n1||i<n2;i++)
{
ll s1=(i<n1?a1[i]:0);
ll s2=(i<n2?a2[i]:0);
a3[i]+=s1+s2;
}
}
void jia(ll *a1,ll *a2,int n2)
{
int i;
for(i=0;i<n2;i++)
a1[i]+=a2[i];
}
void solve(int l,int r)
{
if(l==r)
{
// d[l]=b[l];
// c[l]=a[l];
d[l*2]=b[l]%A;
d[l*2+1]=b[l]/A;
c[l*2]=a[l]%A;
c[l*2+1]=a[l]/A;
return;
}
if(r-l+1<=20)
{
int i,j;
int len=(r-l+1);
clear(c+l*2,len*2);
clear(d+l*2,len*2);
c[l*2]=1;
for(i=0;i<len;i++)
{
memcpy(e+l*2,c+l*2,len*sizeof(ll)*2);
cheng(e+l*2,len*2,b[l+i]);
for(j=0;j<len*2-1;j++)
{
e[l*2+j+1]+=e[l*2+j]/A;
e[l*2+j]%=A;
}
jia(d+l*2,e+l*2,len*2);
for(j=0;j<len*2-1;j++)
{
d[l*2+j+1]+=d[l*2+j]/A;
d[l*2+j]%=A;
}
cheng(c+l*2,len*2,a[l+i]);
for(j=0;j<len*2-1;j++)
{
c[l*2+j+1]+=c[l*2+j]/A;
c[l*2+j]%=A;
}
}
return;
}
int mid=(l+r)>>1;
solve(l,mid);
solve(mid+1,r);
int llen=mid-l+1;
int rlen=r-mid;
int len=r-l+1;
llen*=2;
rlen*=2;
len*=2;
// if(len>50000&&r==69999)
// int xfz=1;
// if(l==0)
// int xfz=1;
int i;
fft::get(len);
for(i=0;i<llen;i++)
a1[i]=cp(c[l*2+i]);
for(i=llen;i<N;i++)
a1[i]=cp();
for(i=0;i<rlen;i++)
{
a2[i]=cp(c[mid*2+2+i]);
a3[i]=cp(d[mid*2+2+i]);
}
for(i=rlen;i<N;i++)
a2[i]=a3[i]=cp();
fft::fft(a1,1);
fft::fft(a2,1);
fft::fft(a3,1);
for(i=0;i<N;i++)
{
a2[i]=a2[i]*a1[i];
a3[i]=a3[i]*a1[i];
}
fft::fft(a2,0);
fft::fft(a3,0);
// if(len>50000&&r==69999)
// int xfz=1;
for(i=0;i<len;i++)
{
c[l*2+i]=ll(a2[i].x+0.4);
e[l*2+i]=ll(a3[i].x+0.4);
}
for(i=0;i<llen;i++)
e[l*2+i]+=d[l*2+i];
for(i=0;i<len;i++)
d[l*2+i]=e[l*2+i];
for(i=0;i<len-1;i++)
{
c[l*2+i+1]+=c[l*2+i]/A;
c[l*2+i]%=A;
d[l*2+i+1]+=d[l*2+i]/A;
d[l*2+i]%=A;
}
// cheng(c+l,llen,d+mid+1,rlen,e,len);
// jia(e,len,d+l,llen,f,len);
// int i;
// for(i=0;i<len;i++)
// d[l+i]=f[i];
// for(i=0;i<len-1;i++)
// {
// d[l+i+1]+=d[l+i]/A;
// d[l+i]%=A;
// }
// cheng(c+l,llen,c+mid+1,rlen,e,len);
// for(i=0;i<len;i++)
// c[l+i]=e[i];
// for(i=0;i<len-1;i++)
// {
// c[l+i+1]+=c[l+i]/A;
// c[l+i]%=A;
// }
}
int main()
{
// freopen("conv.in","r",stdin);
// freopen("conv-2.out","w",stdout);
int n;
scanf("%d",&n);
int i;
for(i=0;i<n;i++)
scanf("%d",&a[i]);
for(i=0;i<n;i++)
scanf("%d",&b[i]);
solve(0,n-1);
for(i=2*n-1;!d[i];i--);
printf("%d",d[i]);
for(i--;i>=0;i--)
// output(d[i]);
printf("%03d",d[i]);
putchar('\n');
// int n=4;
// fft::get(n);
// a1[0]=cp(1);
// a1[1]=cp(2);
// a2[0]=cp(1);
// a2[1]=cp(2);
// fft::fft(a1,1);
// fft::fft(a2,1);
// int i;
// for(i=0;i<N;i++)
// a1[i]=a1[i]*a2[i];
// fft::fft(a1,0);
return 0;
}