题目描述
Description
Input
第一行两个个整数 n,k。
之后 n -1 行,第 i 行两个整数 ui, vi, 表示一条树边。
保证输入的数据构成一棵树。
Output
一行一个数表示答案。
Sample Input
Sample Input1
3 2
1 2
1 3
Sample Input2
10 367305945
1 2
2 3
2 4
3 5
2 6
5 7
1 8
4 9
1 10
Sample Output
Sample Output1
28
Explanation
1, 2, 3 : c = 3
1, 3, 2 : c = 3
2, 1, 3 : c = 2
2, 3, 1 : c = 2
3, 1, 2 : c = 1
3, 2, 1 : c = 1
Sample Output2
628657647
Data Constraint
题解
另一种做法:https://blog.csdn.net/qq_43649416/article/details/102925575
orz cold_chair
如果能先确定那些点是黑点(好点),那么就有若干约束条件:
①白点要比第一个黑点祖先小
②黑点要比第一个黑点祖先大
两点间有边即两点间存在大小限制
20%可以暴力枚举黑点,然后把白点的限制容斥,变成比黑点祖先大减与黑点祖先无大小限制,然后就可以变成一棵外向树+若干无关点的情况
具体来说,如果是比黑点祖先大的情况,那么相当于把白点变成黑点(不计算K的贡献)并单独提出来变为叶节点,并与原黑点祖先连边
因为若干白点与一个黑点有边时,白点之间没有大小限制,所以提出来作为叶节点
而且由于考虑的实际上是白点的贡献,所以不需要乘K
对于与黑点祖先无大小限制的情况,相当于直接把白点删掉
因为白点只与第一个黑点祖先有大小关系
外向树上的每个点都要比儿子小,所以一棵外向树的期望出现概率为∏1/size
这样就不用考虑具体的大小了
100%考虑用dp来实现上面的做法
设f[i][j],表示以i为根的子树,外向树大小为j的 各种情况的贡献、外向树概率、容斥系数的积 的和
显然j<=i,所以合并相当于O(n^2)树上背包
并且假设在i的祖先上有一个虚拟的黑点(因为具体的位置不重要)
,最终答案=∑f[root][j]
子树合并就直接对应相乘,考虑i的黑白情况
①i是黑点
那么i相当于把原来的虚拟黑点,i子树内的白点的边连向i
i则向i的祖先中的虚拟黑点连边,把K和外向树的概率算上
因为i是黑点,所以不需要容斥
f[i][j]*(1/(j+1))*K-->F[i][j+1]
②i是白点(i不为根)
考虑和虚拟黑点的连边,容斥一下变为 没有限制-比黑点祖先大 的情况
1、比黑点祖先大
等于把i变为叶子黑点,不用*K,子树大小为1所以不用乘外向树概率,容斥系数为-1
-f[i][j]-->F[i][j+1]
2、没有限制
等于把i删掉,容斥系数为1
f[i][j]-->F[i][j]
预处理1/x,O(n^2)dp即可
code
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define add(a,b) a=((a)+(b))%998244353
#define mod 998244353
using namespace std;
int a[10001][2];
int ls[5001];
long long f[5001][5001];
long long F[5001];
int size[5001];
long long w[5001];
int n,i,j,k,l,len;
long long K,ans;
void New(int x,int y)
{
++len;
a[len][0]=y;
a[len][1]=ls[x];
ls[x]=len;
}
void dfs(int Fa,int t)
{
int i,j,k;
f[t][0]=1;
size[t]=0;
for (i=ls[t]; i; i=a[i][1])
if (a[i][0]!=Fa)
{
dfs(t,a[i][0]);
fo(j,0,size[t])
{
fo(k,0,size[a[i][0]])
add(F[j+k],f[t][j]*f[a[i][0]][k]%mod);
}
size[t]+=size[a[i][0]];
fo(j,0,size[t])
f[t][j]=F[j],F[j]=0;
}
fo(j,0,size[t])
{
add(F[j+1],f[t][j]*K%mod*w[j+1]); //black
if (t>1) //white
{
add(F[j],f[t][j]);
add(F[j+1],-f[t][j]);
}
}
++size[t];
fo(j,0,size[t])
f[t][j]=F[j],F[j]=0;
}
int main()
{
freopen("random.in","r",stdin);
freopen("random.out","w",stdout);
scanf("%d%lld",&n,&K);
w[1]=1;
fo(i,2,n)
{
w[i]=mod-w[mod%i]*(mod/i)%mod;
scanf("%d%d",&j,&k);
New(j,k);
New(k,j);
}
dfs(0,1);
fo(i,1,n) add(ans,f[1][i]);
fo(i,1,n) ans=ans*i%mod;
printf("%lld\n",(ans+mod)%mod);
fclose(stdin);
fclose(stdout);
return 0;
}