这个题的思路非常好啊.
我们可以把 $k$ 个点拿出来,那么就是求将 $k$ 个点划分成不大于 $m$ 个集合的方案数.
令 $f[i][j]$ 表示将前 $i$ 个点划分到 $j$ 个集合中的方案数.
那么有 $f[i][j]=f[i-1][j-1]+f[i-1][j]*(j-fail[i])$,其中 $fail[i]$ 代表 $i$ 到根这条路径上祖先数量.
而 $fail[i]$ 的求解方式有:虚数统计/树上数据结构维护路径和,这里选择了用 LCT 来维护.
code:
#include <cstdio> #include <cstring> #include <vector> #include <algorithm> #define N 100007 #define ll long long #define mod 1000000007 #define setIO(s) freopen(s".in","r",stdin) using namespace std; namespace LCT { #define lson t[x].ch[0] #define rson t[x].ch[1] struct node { int ch[2],f,rev,sum,val; }t[N]; int sta[N]; int get(int x) { return t[t[x].f].ch[1]==x; } int isrt(int x) { return !(t[t[x].f].ch[0]==x||t[t[x].f].ch[1]==x); } void pushup(int x) { t[x].sum=t[lson].sum+t[rson].sum+t[x].val; } void mark(int x) { t[x].rev^=1; swap(lson,rson); } void pushdown(int x) { if(t[x].rev) { if(lson) mark(lson); if(rson) mark(rson); t[x].rev=0; } } void rotate(int x) { int old=t[x].f,fold=t[old].f,which=get(x); if(!isrt(old)) t[fold].ch[t[fold].ch[1]==old]=x; t[old].ch[which]=t[x].ch[which^1],t[t[old].ch[which]].f=old; t[x].ch[which^1]=old,t[old].f=x,t[x].f=fold; pushup(old),pushup(x); } void splay(int x) { int v=0,u=x,fa; for(sta[++v]=u;!isrt(u);u=t[u].f) sta[++v]=t[u].f; for(;v;--v) pushdown(sta[v]); for(u=t[u].f;(fa=t[x].f)!=u;rotate(x)) { if(t[fa].f!=u) rotate(get(fa)==get(x)?fa:x); } } void Access(int x) { for(int y=0;x;y=x,x=t[x].f) { splay(x); rson=y; pushup(x); } } void makeroot(int x) { Access(x),splay(x),mark(x); } void split(int x,int y) { makeroot(x),Access(y),splay(y); } void add(int x,int v) { Access(x),splay(x); t[x].val+=v,pushup(x); } int query(int x) { Access(x),splay(x); return t[x].sum; } #undef lson #undef rson }; int n,edges; int hd[N],to[N<<1],nex[N<<1],f[N],A[N],dp[N][302]; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } void dfs(int u,int ff) { LCT::t[u].f=ff; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff) continue; dfs(v,u); } } int main() { // setIO("input"); int i,j,q; scanf("%d%d",&n,&q); for(i=1;i<n;++i) { int x,y; scanf("%d%d",&x,&y); add(x,y),add(y,x); } dfs(1,0); for(i=1;i<=q;++i) { int k,m,r,flag=0; scanf("%d%d%d",&k,&m,&r); LCT::makeroot(r); for(j=1;j<=k;++j) { scanf("%d",&A[j]); LCT::add(A[j],1); } for(j=1;j<=k;++j) { f[j]=LCT::query(A[j])-1; if(f[j]>m) flag=1; } for(j=1;j<=k;++j) LCT::add(A[j],-1); if(flag) printf("0\n"); else { sort(f+1,f+1+k); dp[1][1]=1; for(j=2;j<=k;++j) { for(int p=1;p<=min(j,m);++p) { dp[j][p]=0; if(p<f[j]) dp[j][p]=dp[j-1][p-1]; else dp[j][p]=(ll)(dp[j-1][p-1]+1ll*(p-f[j])*dp[j-1][p]%mod)%mod; } } int ans=0; for(j=1;j<=m;++j) ans=(ans+dp[k][j])%mod; printf("%d\n",ans); } } return 0; }