题目大意
有一棵\(n\)(\(n\leq 1666\))个点的树,有点权\(d_i\),点权最大值为\(w\)(\(w\leq 1666\))。给出\(k\)(\(k\leq n\)),定义一个选择连通块的方案的权值为该连通块第\(k\)大的点权,如果该连通块大小\(<k\),那么该方案的权值为0。求所有选择连通块的方案的权值之和。
题解
考虑暴力:
设\(f(S,k)\)表示连通块\(S\)中第\(k\)大的点权,那么答案就是\(\sum\limits_{i=1}^{w}i\times(\sum\limits_{S\in [1,n]}[f(S,k)=i])\);
前面乘的\(i\)看上去很烦,考虑把\([f(S,k)=i]\)算\(i\)遍,就有该式=\(\sum\limits_{i=1}^w\sum\limits_{j=1}^i\sum\limits_{S\in[1,n]}[f(S,k)=j]\);
把\(j\)拿到前面,就有该式=\(\sum\limits_{j=1}^w\sum\limits_{S\in[1,n]}[f(S,k)\geq j]\);
设\(c(S,i)\)表示连通块\(S\)中大于等于\(i\)的点权的个数,那么当 \([f(S,k)\geq i]\) 时一定有\([c(S,i)\geq k]\),式子变成\(\sum\limits_{i=1}^w\sum\limits_{S\in[1,n]}[c(S,i)\geq k]\)=\(\sum\limits_{i=1}^w\sum\limits_{j=k}^n\sum\limits_{S\in[1,n]}[c(S,i)==j]\);
就可以设\(h(i,j,x)\)表示 以\(i\)为深度连通块最小的点 且 大于等于\(j\)的数的个数为\(x\) 的连通块个数,就有答案=\(\sum\limits_{i=1}^{n}\sum\limits_{j=1}^w\sum\limits_{x=k}^n h(i,j,x)\);
\(h(i,j,x)\)的转移为:当\(d_i<j\)时,\(h(i,j,x)=\prod\limits_{v\in son(i),\sum y_v=x}h(v,j,y_v)\);当\(d_i\geq x\)时,\(h(i,j,x)=\prod\limits_{v\in son(i),\sum y_v=x-1}h(v,j,y_v)\)。
以上的暴力做法转移过程类似树上背包,所以它的时间复杂度是\(\Theta(n^2\times w)\)的。
考虑设多项式\(F_{i,j}(x)=\sum\limits_{a=0}^n h(i,j,a)\times x^n\),那么转移会变成卷。
设多项式\(G_{i,j}(x)=\sum\limits_{v\in i的子树}F_{v,j}(x)\),\(H_i(x)=\sum\limits_{j=1}^w G_{i,j}(x)\),那么\(H_1\)的\(k\)次项系数到第\(n\)次项系数就会是答案。
发现在进行儿子到父亲的转移时,对于\(j\in[1,w]\),\(f(i,j,a)\)的转移大同小异,考虑对于每个点\(i\)同时维护\(F_{i,j}(x)\)这些多项式。
直接维护多项式不太可行。考虑将\(x=1,2,...,n+1\)代入,用\(y=F_{i,j}(x)\)来代表这个多项式进行点值的计算,最后得到\(H_1(1),...,H_1(n+1)\),再用拉格朗日插值还原出\(H_1(x)\)。
初始值:当\(d_i\geq j\)时,有\(F_{i,j}=x\)(即\(f(i,j,1)=1\));当\(d_i< j\)时,有\(F_{i,j}=1\)(即\(f(i,j,0)=1\));
在从\(i\)的儿子\(v\)转移到\(i\)时要做这么几件事:\(F_{i,j}卷=(F_{v,j}+1),G_{i,j}+=G_{v,j}\);
考虑完所有\(i\)的儿子到\(i\)的转移后,还应有:\(G_{i,j}+=F_{i,j}\)。
发现\(F_{i,j}卷=(F_{v,j}+1)\)不太好办,可以在计算完\(F_{i,j},G_{i,j}\)后,令\(F_{i,j}+=1\)。
以上过程可以用线段树维护,设初始值相当于区间修改,其他的相当于全局加,从\(i\)的儿子转移到\(i\)相当于合并两棵线段树上的标记。
具体地,要维护标记\((a,b,c,d)\)表示\(F=a\times F+b,G=c\times F+d+G\)。
考虑标记的叠加:假设一个点先被标记了\((a_0,b_0,c_0,d_0)\),又被标记了\((a_1,b_1,c_1,d_1)\),就相当于\(F=a_1\times(a_0\times F+b_0)+b_1\)=\((a_1\times a_0)\times F+(a_1\times b_0+b_1)\),\(G=c_1\times(a_0\times F+b_0)+d_1+c_0\times F+d_0+G\)=\((c_1\times a_0+c_0)\times F+(c_1\times b_0+d_1+d_0)+G\),也就是它们的标记合并起来就是\((a_1\times a_0,a_1\times b_0+b_1,c_1\times a_0+c_0,c_1\times b_0+d_1+d_0)\)。
标记叠加的性质:观察可得标记的叠加不满足交换律;假设一个点先被标记了\((a_0,b_0,c_0,d_0)\),又被标记了\((a_1,b_1,c_1,d_1)\),又被标记了\((a_2,b_2,c_2,d_2)\),可以通过计算\((a_0,b_0,c_0,d_0)+(a_1,b_1,c_1,d_1)+(a_2,b_2,c_2,d_2)\)和\((a_0,b_0,c_0,d_0)+((a_1,b_1,c_1,d_1)+(a_2,b_2,c_2,d_2))\)来证明其满足结合律,因此可以在线段树上打标记,合并时合并标记。
整个流程是:枚举\(x\)取1到\(n+1\)(其实是随便\(n+1\)个数),从一号点(其实是随便一个点)开始DFS。每走到一个点,区间加初始值,走该点的每个儿子,走完每个儿子后线段树合并,合并后可以回收不用的点。走完所有儿子后进行一些小的处理(如\(G+=F,F++\))。把每个点走完后,计算\(H_1(x)=\sum_{j=1}^w G_{1,j}(x)\)。枚举完\(x\)后,得到\(n+1\)个\(H_1\)的点值,做一遍拉格朗日插值得到\(H_1\)每一项的系数,计算\(k\)次项至\(n\)次项系数的和。
代码
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
#define rep(i,x,y) for(register int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(register int i=(x);i>=(y);--i)
#define view(u,k) for(int k=fir[u];~k;k=nxt[k])
#define LL long long
#define maxn 1677
#define ls son[u][0]
#define rs son[u][1]
#define Ls(u) son[u][0]
#define Rs(u) son[u][1]
#define UI unsigned int
#define mi (l+r>>1)
using namespace std;
int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')f=-1,ch=getchar();
while(isdigit(ch))x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return x*f;
}
void write(UI x)
{
if(x==0){putchar('0'),putchar('\n');return;}
int f=0;char ch[20];
if(x<0)putchar('-'),x=-x;
while(x)ch[++f]=x%10+'0',x/=10;
while(f)putchar(ch[f--]);
putchar('\n');
return;
}
const UI mod=64123;
UI ans,f[maxn],qy[maxn],b[maxn];
int n,K,w,d[maxn],rt[maxn],fir[maxn],nxt[maxn<<1],v[maxn<<1],cnt,son[maxn<<6][2],rec[maxn<<6],tpr,cntnd;
struct node{UI a,b,c,d;}tr[maxn<<6];
UI mo(UI x){return x>=mod?x-mod:x;}
void ade(int u1,int v1){v[cnt]=v1,nxt[cnt]=fir[u1],fir[u1]=cnt++;}
UI mul(UI x,int y){UI res=1;while(y){if(y&1)res=res*x%mod;x=x*x%mod,y>>=1;}return res;}
void pu(node & x,node y,node z)
{
x.a=y.a*z.a%mod;
x.b=mo(z.a*y.b%mod+z.b);
x.c=mo(z.c*y.a%mod+y.c);
x.d=mo((z.c*y.b%mod+z.d)+y.d);
}
node getnd(UI a,UI b,UI c,UI d){node tmp;tmp.a=a,tmp.b=b,tmp.c=c,tmp.d=d;return tmp;}
int newnd(){int u=tpr?rec[tpr--]:++cntnd;tr[u].a=1,tr[u].b=tr[u].c=tr[u].d=ls=rs=0;return u;}
void pd(int u)
{
if(!ls)ls=newnd();
pu(tr[ls],tr[ls],tr[u]);
if(!rs)rs=newnd();
pu(tr[rs],tr[rs],tr[u]);
tr[u].a=1,tr[u].b=tr[u].c=tr[u].d=0;
}
int add(int u,int l,int r,int x,int y,node k)
{
if(!u)u=newnd();
if(x<=l&&r<=y){pu(tr[u],tr[u],k);return u;}
pd(u);
if(x<=mi)ls=add(ls,l,mi,x,y,k);
if(y>mi)rs=add(rs,mi+1,r,x,y,k);
return u;
}
int merge(int ua,int ub,int l,int r)
{
if(!ua||!ub){if(!ua)swap(ua,ub);rec[++tpr]=ub;return ua;}
if(!Ls(ua)&&!Rs(ua))swap(ua,ub);
if(!Ls(ub)&&!Rs(ub))
{
pu(tr[ua],tr[ua],getnd(tr[ub].b,0,0,0));
pu(tr[ua],tr[ua],getnd(1,0,0,tr[ub].d));
rec[++tpr]=ub;
return ua;
}
pd(ua),pd(ub);
Ls(ua)=merge(Ls(ua),Ls(ub),l,mi);
Rs(ua)=merge(Rs(ua),Rs(ub),mi+1,r);
rec[++tpr]=ub;
return ua;
}
void recy(int u){if(ls)recy(ls);rec[++tpr]=u;if(rs)recy(rs);}
UI ask(int u,int l,int r)
{
if(l==r)return tr[u].d;
pd(u);UI res=ask(ls,l,mi);
res=mo(res+ask(rs,mi+1,r));
return res;
}
void dfs(int u,int fa,UI qx)
{
rt[u]=add(rt[u],1,w,1,w,getnd(0,1,0,0));
view(u,k)if(v[k]!=fa)
{
dfs(v[k],u,qx);
rt[u]=merge(rt[u],rt[v[k]],1,w);
rt[v[k]]=0;
}
rt[u]=add(rt[u],1,w,1,d[u],getnd(qx,0,0,0));
rt[u]=add(rt[u],1,w,1,w,getnd(1,0,1,0));
rt[u]=add(rt[u],1,w,1,w,getnd(1,1,0,0));
}
void getf(int sz)
{
f[0]=1;
rep(i,1,sz)dwn(j,i,1)f[j]=mo(f[j]+(LL)f[j-1]*(mod-i)%mod);
reverse(f,f+sz+1);
rep(i,1,sz)
{
int lst=0,num=1,nyx=mul(mod-i,mod-2);
rep(j,1,sz)if(i!=j)num=(LL)num*mo(i-j+mod)%mod;
num=(LL)mul(num,mod-2)*qy[i]%mod;
rep(i,0,sz-1)lst=(LL)mo(f[i]-lst+mod)*nyx%mod,b[i]=mo(b[i]+(LL)lst*num%mod);
}
rep(i,K,sz-1)ans=mo(ans+b[i]);
}
UI ff(UI x)
{
UI res=0,now=1;
rep(i,0,n)res=mo(res+now*b[i]%mod),now=now*x%mod;
return res;
}
int main()
{
memset(fir,-1,sizeof(fir));
n=read(),K=read(),w=read();
rep(i,1,n)d[i]=read();
rep(i,1,n-1){int x=read(),y=read();ade(x,y),ade(y,x);}
rep(i,1,n+1){dfs(1,0,i);qy[i]=ask(rt[1],1,w);recy(rt[1]),rt[1]=0;}
getf(n+1);
write(ans);
return 0;
}
一些感想
打开loj“统计”,按“最快”排序,可发现前面一堆暴力。