思路
大部分是感性理解,不保证完全正确。
不能算是神仙题,但我还是不会qwq
这题显然就是求:把每一棵树分成若干条链,然后把链拼成一个环,使得相邻的链不来自同一棵树,的方案数。(我才不告诉你们我这一行都没推出来呢)
可以发现后面那步只和每棵树被分成了几段有关,所以第一步可以先求出每棵树分成几段的方案数。
具体方法:设\(dp_{x,i,0/1/2}\)表示\(x\)子树被填满,共用\(i\)条链,\(x\)所在的链处于 {只有\(x\)一个点/有一条从下面到\(x\)的链/有从下到\(x\)又到下的链} 的状态,然后随便DP。(我的代码中1、2两种情况不算在\(i\)里面,只有确定不再改变时才加进去)
(注意一条链有两个方向,所以链长大于1时方案数乘2)
一通DP之后可以得到\(f\)数组:\(f_i\)表示将树分成\(i\)条链的方案数。
先考虑把链不是拼成一个环,而是一个排列,使得相邻链颜色不同的方案数。这似乎是个经典问题。
首先,排列的个数要用指数型生成函数来完成,但颜色不同的限制呢?
考虑容斥:设一共有\(k\)条链,也就是有\(k-1\)个空隙必须被填满。容斥有多少个空隙可能会被填满,那么这一项的值就是
\]
所以整一棵树的生成函数就是
\]
组成排列的做完了,但组成一个环又该怎么办?
考虑断环为链。我们钦定第一棵树的第一条链必须放在第一个,于是第一棵树的生成函数变为
\]
(\((k-1)!\)表示钦定第一个位置不变,那么剩下的有\((k-1)!\)种排列;\(j-1\)表示有\(k\)条链时其实只会有\(k-1\)条链参与到后面的排列中去)
然而还有第一个和最后一个不能颜色相同的限制,所以第一棵树的生成函数还要减去
\]
(\(j-2\)表示钦定\(k-1\)条链的排列中的最后一个必须放在序列末尾,所以只有\(k-2\)条链参与到后面的排列)
最后暴力把所有生成函数卷在一起即可。
代码
代码很丑,请谨慎阅读qwq
#include<bits/stdc++.h>
clock_t t=clock();
namespace my_std{
using namespace std;
#define pii pair<int,int>
#define fir first
#define sec second
#define MP make_pair
#define rep(i,x,y) for (int i=(x);i<=(y);i++)
#define drep(i,x,y) for (int i=(x);i>=(y);i--)
#define go(x) for (int _=head[x];_;_=edge[_].nxt)
#define templ template<typename T>
#define sz 5050
#define mod 998244353ll
typedef long long ll;
typedef double db;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
templ inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
templ inline bool chkmax(T &x,T y){return x<y?x=y,1:0;}
templ inline bool chkmin(T &x,T y){return x>y?x=y,1:0;}
templ inline void read(T& t)
{
t=0;char f=0,ch=getchar();double d=0.1;
while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
t=(f?-t:t);
}
template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
char __sr[1<<21],__z[20];int __C=-1,__zz=0;
inline void Ot(){fwrite(__sr,1,__C+1,stdout),__C=-1;}
inline void print(register int x)
{
if(__C>1<<20)Ot();if(x<0)__sr[++__C]='-',x=-x;
while(__z[++__zz]=x%10+48,x/=10);
while(__sr[++__C]=__z[__zz],--__zz);__sr[++__C]='\n';
}
void file()
{
#ifdef NTFOrz
freopen("a.in","r",stdin);
#endif
}
inline void chktime()
{
#ifndef ONLINE_JUDGE
cout<<(clock()-t)/1000.0<<'\n';
#endif
}
#ifdef mod
ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;return ret;}
ll inv(ll x){return ksm(x,mod-2);}
#else
ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;return ret;}
#endif
// inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std;
inline void M(ll &x){x-=((mod-x)>>31&mod);}
int n;
struct hh{int t,nxt;}edge[sz<<1];
int head[sz],ecnt;
void make_edge(int f,int t)
{
edge[++ecnt]=(hh){t,head[f]};
head[f]=ecnt;
edge[++ecnt]=(hh){f,head[t]};
head[t]=ecnt;
}
ll dp[sz][sz][3],f[sz][3];
int size[sz];
void dfs(int x,int fa)
{
dp[x][0][0]=size[x]=1;
#define v edge[_].t
go(x) if (v!=fa)
{
dfs(v,x);
#define upd(a,b,p,q,r) M(f[i+j+q][p]+=dp[x][i][a]*dp[v][j][b]%mod*r%mod)
rep(i,0,size[x])
rep(j,0,size[v])
upd(0,0,1,0,1),upd(0,0,0,1,1),
upd(0,1,0,1,2),upd(0,1,1,0,1),
upd(0,2,0,0,1),
upd(1,0,1,1,1),upd(1,0,2,1,2),
upd(1,1,1,1,2),upd(1,1,2,1,2),
upd(1,2,1,0,1),
upd(2,0,2,1,1),
upd(2,1,2,1,2),
upd(2,2,2,0,1);
size[x]+=size[v];
rep(i,0,size[x]) rep(j,0,2) dp[x][i][j]=f[i][j],f[i][j]=0;
}
#undef v
}
ll cnt[sz];
ll fac[sz],_fac[sz];
void init(){_fac[0]=fac[0]=1;rep(i,1,sz-1) _fac[i]=inv(fac[i]=fac[i-1]*i%mod);}
ll C(int n,int m){return n>=m&&m>=0?fac[n]*_fac[m]%mod*_fac[n-m]%mod:0;}
ll F[sz],lenF,G[sz],lenG,tmp[sz];
void mul()
{
rep(i,0,lenF)
rep(j,0,lenG)
M(tmp[i+j]+=F[i]*G[j]%mod);
lenF+=lenG;
rep(i,0,lenF) F[i]=tmp[i],tmp[i]=G[i]=0;
}
int main()
{
file();
init();
F[0]=1;
int m;read(m);
rep(_,1,m)
{
read(n);
ecnt=0;rep(i,1,n) head[i]=0;
int x,y;
rep(i,1,n-1) read(x,y),make_edge(x,y);
rep(i,1,n) rep(j,0,n) rep(k,0,2) dp[i][j][k]=0;
dfs(1,0);
rep(i,0,n) cnt[i]=0;
rep(i,0,n) M(cnt[i+1]+=dp[1][i][0]),M(cnt[i+1]+=dp[1][i][1]*2%mod),M(cnt[i]+=dp[1][i][2]);
lenG=n;
rep(k,1,n) rep(j,_==1,k)
{
ll val=cnt[k]*fac[k-(_==1)]%mod*(((k-j)&1)?mod-1:1ll)%mod*C(k-1,j-1)%mod;
if (_==1) M(G[j-1]+=val*_fac[j-1]%mod),j>1&&(M(G[j-2]+=mod-val*_fac[j-2]%mod),0);
else M(G[j]+=val*_fac[j]%mod);
}
mul();
}
ll ans=0;
rep(i,0,lenF) M(ans+=F[i]*fac[i]%mod);
cout<<ans;
return 0;
}