原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ351.html

题目传送门 - UOJ351

题意

  有一个 n 个节点的树,每次涂黑一个叶子节点(度为 1 的节点),可以重复涂黑。

  问使得白色部分的直径发生变化的期望涂黑次数。

  $n\leq 5\times 10^5$

题解

  首先考虑什么情况下直径长度会发生改变。

  考虑找到直径的中点,可能在边上。

  对于这个直径相连的每一个子树,分别算出在这个子树中的距离直径中点距离为直径长度的一半的节点个数。

  于是我们就得到了一些集合。那么,直径长度将在 只剩下一个集合有白色节点 的时候发生改变。

  于是,很容易得出菊花图的做法:

$$ ans = \sum_{i=1}^{n-1} \frac {n-1} {n-i}$$

  但是不是菊花图的时候,仍然很棘手。

  设总叶子节点个数为 $l$ ,所有集合的元素个数总和为 $tot$ 。

官方题解的算法三应该比较好理解吧

算法四比较神仙

  对于算法四,我再补充一句:

  这里,如果 x 不是最后一个被染黑的集合,那么必然全部染黑了。

代码

#include <bits/stdc++.h>
using namespace std;
int read(){
int x=0;
char ch=getchar();
while (!isdigit(ch))
ch=getchar();
while (isdigit(ch))
x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return x;
}
const int N=500005,mod=998244353;
int n;
vector <int> e[N],S;
int in[N],depth[N],fa[N];
int vis[N];
void dfs(int x,int pre,int d){
depth[x]=d,fa[x]=pre;
for (auto y : e[x])
if (y!=pre)
dfs(y,x,d+1);
}
int LCA(int x,int y){
if (depth[x]<depth[y])
swap(x,y);
while (depth[x]>depth[y])
x=fa[x];
while (x!=y)
x=fa[x],y=fa[y];
return x;
}
int FindFar(int x,int d){
vis[x]=d;
int res=x;
for (auto y : e[x])
if (!~vis[y]){
int tmp=FindFar(y,d+1);
if (vis[tmp]>vis[res])
res=tmp;
}
return res;
}
void Get(int x,int d,int md){
if (d==md)
S.back()++;
vis[x]=1;
for (auto y : e[x])
if (!vis[y])
Get(y,d+1,md);
}
void Get_S(){
memset(vis,-1,sizeof vis);
int s=FindFar(1,0);
memset(vis,-1,sizeof vis);
int t=FindFar(s,0);
if (depth[s]<depth[t])
swap(s,t);
int d=depth[s]+depth[t]-2*depth[LCA(s,t)];
memset(vis,0,sizeof vis);
if (d&1){
for (int i=d/2;i--;)
s=fa[s];
t=fa[s];
vis[s]=vis[t]=1;
S.push_back(0),Get(s,0,d/2);
S.push_back(0),Get(t,0,d/2);
}
else {
for (int i=d/2;i--;)
s=fa[s];
vis[s]=1;
for (auto y : e[s])
S.push_back(0),Get(y,1,d/2);
}
}
int Pow(int x,int y){
int ans=1;
for (;y;y>>=1,x=1LL*x*x%mod)
if (y&1)
ans=1LL*ans*x%mod;
return ans;
}
int Fac[N],Inv[N],Iv[N],h[N];
void Math_Prework(){
for (int i=Fac[0]=1;i<=n;i++)
Fac[i]=1LL*Fac[i-1]*i%mod;
Inv[n]=Pow(Fac[n],mod-2);
for (int i=n;i>=1;i--)
Inv[i-1]=1LL*Inv[i]*i%mod;
for (int i=1;i<=n;i++)
Iv[i]=1LL*Inv[i]*Fac[i-1]%mod;
for (int i=1;i<=n;i++)
h[i]=(h[i-1]+Iv[i])%mod;
}
int C(int n,int m){
if (m<0||m>n)
return 0;
return 1LL*Fac[n]*Inv[m]%mod*Inv[n-m]%mod;
}
int main(){
n=read();
for (int i=1;i<n;i++){
int a=read(),b=read();
e[a].push_back(b);
e[b].push_back(a);
in[a]++,in[b]++;
}
dfs(1,0,0);
Get_S();
int l=0;
for (int i=1;i<=n;i++)
if (in[i]==1)
l++;
Math_Prework();
int tot=0;
for (auto y : S)
tot+=y;
// 设当前剩余 k 个,总共有 l 个,选出 x 个数:
// l/k + l/(k-1) + ... + l/(k-x+1)
// l * (h[k]-h[k-x])
int ans=0;
for (auto y : S)
ans=(1LL*h[tot-y]+ans)%mod;
ans=(-1LL*((int)S.size()-1)*h[tot]%mod+ans+mod)%mod;
ans=1LL*ans*l%mod;
printf("%d",ans);
return 0;
}
05-11 22:21