题目大意:
给你一个数组 \(a_{1 \sim n}\),对于 \(k = 0 \sim n\),求出有多少个数组上的区间满足:区间内恰好有 \(k\) 个数比 \(x\) 小。\(x\) 为一个给定的数。
\(n \le 2 \times 10^5\)。
没见过大概想不到飞飞兔吧
考虑\(n\)为\(10w\)级别,且要求求\(0\sim n\)的每个\(k\),考虑\(fft\)
由于\(x\)是一定的,所以我们可以把小于\(x\)的数字变为\(1\),其他变为\(0\)
设\(s_i\)为变化后序列前缀和,\(f_i=\sum\limits_{j=0}^{n}[s_j==i]\)
那么我们要求
\[\sum\limits_{i=k}^{n}f_i*f_{i-k}\]
这个东西是不能\(fft\)的,对后面变化一下
\[\sum\limits_{i=k}^{n}f_i*f_{n-(n+k-i)}\]
令\(j=n+k-i\)
\[\sum\limits_{i+j=n+k}f_i*f_{n-j}\]
设\(g_i=f_{n-i}\)
\[\sum\limits_{i+j=n+k}f_i*g_j\]
它已经散发出毒品的香气了
不过注意当\(k=0\)时会重复计算左右端点,需要特判
一些细节证明先留锅,不太明白
#include<bits/stdc++.h>
using namespace std;
namespace red{
#define int long long
inline int read()
{
int x=0;char ch,f=1;
for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
if(ch=='-') f=0,ch=getchar();
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return f?x:-x;
}
const int N=1e6+10;
double pi=acos(-1.0);
int n,tmp,limit,len;
int a[N],pos[N],ret[N];
int sum[N];
struct complex
{
double x,y;
complex(double tx=0,double ty=0){x=tx,y=ty;}
inline complex operator + (const complex &t) const
{
return complex(x+t.x,y+t.y);
}
inline complex operator - (const complex &t) const
{
return complex(x-t.x,y-t.y);
}
inline complex operator * (const complex &t) const
{
return complex(x*t.x-y*t.y,x*t.y+y*t.x);
}
}f[N],g[N];
inline void fft(int limit,complex *a,int inv)
{
for(int i=0;i<limit;++i)
if(i<pos[i]) swap(a[i],a[pos[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
complex Wn(cos(pi/mid),inv*sin(pi/mid));
for(int r=mid<<1,j=0;j<limit;j+=r)
{
complex w(1,0);
for(int k=0;k<mid;++k,w=w*Wn)
{
complex x=a[j+k],y=w*a[j+k+mid];
a[j+k]=x+y;
a[j+k+mid]=x-y;
}
}
}
}
inline void main()
{
n=read(),tmp=read();
for(int i=1;i<=n;++i) a[i]=(read()<tmp),sum[i]=sum[i-1]+a[i];
f[0].x=1;
for(int tmp=0,i=1;i<=n;++i)
{
tmp+=a[i];
++f[tmp].x;
}
for(int i=0;i<=n;++i) g[i].x=f[n-i].x;
for(limit=1;limit<=n+n;limit<<=1) ++len;
for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
fft(limit,f,1);
fft(limit,g,1);
for(int i=0;i<limit;++i) f[i]=f[i]*g[i];
fft(limit,f,-1);
for(int tmp=0,sum=0,i=1;i<=n;++i)
{
if(!a[i]) ++sum;
else
{
tmp+=((sum+1)*sum)>>1;
sum=0;
}
if(i==n)
{
tmp+=((sum+1)*sum)>>1;
printf("%lld ",tmp);
}
}
for(int i=n+1;i<=n+n;++i) ret[i]=f[i].x/limit+0.5;
for(int i=n+1;i<=n+n;++i) printf("%lld ",ret[i]);
}
}
signed main()
{
red::main();
return 0;
}