另一道树题
题目大意:
数据范围:
题解:
这个题第一眼能发现的是,我们的答案分成两种情况。
第一种是在非根节点汇合,第二种是在根节点汇合。
尝试枚举在第几回合结束,假设在第$i$回合结束的方案数为$f_i$,那么总答案就是$\sum\limits_{i = 1} ^ {N - 1}i\times f_i$。
显然没法求这个$f_i$....
进而,觉得这鬼东西的后缀和好像比较好求,就是$g _ i = \sum\limits_{j = i} ^ {N - 1} f _ j$。
由于我们就相当于对于深度相等的点的讨论,不难想到$bfs$序。
只考虑不在根节点汇合的情况。
发现,其实就是一段连续的区间,他们在$i$不小于一个值的时候,最多只能选取一个值。
也就是说随着我们枚举的回合数递增,这些连续的区间会存在一些合并的情况。
至于什么时候合并呢?其实就根据,相邻两个点到其$lca$的深度有关(这两个点的深度得相等),就是在这个深度差恰好等于回合数的时候,我们实施合并操作。
这样就完美的解决了不是非根汇合的情况。
考虑在根节点汇合咋办。
其实就相当于,随着回合数递增,所有深度不大于$i$的点只能选一个,就相当于和根节点合并咯。
总之通通用并查集维护就好了。
代码:
#include <bits/stdc++.h> #define N 200010 using namespace std; int head[N], to[N << 1], nxt[N << 1], tot; struct Node {
int x, y;
}; vector <Node> v[N]; queue <int> q; int f[20][N], g[N], F[N], S[N], dep[N], dic[N], n, inv[N]; const int mod = 998244353 ; typedef long long ll; char *p1, *p2, buf[100000]; #define nc() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1 ++ ) int rd() {
int x = 0, f = 1;
char c = nc();
while (c < 48) {
if (c == '-')
f = -1;
c = nc();
}
while (c > 47) {
x = (((x << 2) + x) << 1) + (c ^ 48), c = nc();
}
return x * f;
} int qpow(int x, int y) {
int ans = 1;
while (y) {
if (y & 1) {
ans = (ll)ans * x % mod;
}
y >>= 1;
x = (ll)x * x % mod;
}
return ans;
} inline void add(int x, int y) {
to[ ++ tot] = y;
nxt[tot] = head[x];
head[x] = tot;
} int lca(int x, int y) {
if (dep[x] < dep[y])
swap(x, y);
for (int i = 19; ~i; i -- ) {
if (dep[f[i][x]] >= dep[y]) {
x = f[i][x];
}
}
if (x == y)
return x;
for (int i = 19; ~i; i -- ) {
if (f[i][x] != f[i][y]) {
x = f[i][x];
y = f[i][y];
}
}
return f[0][x];
} void dfs(int p, int fa) {
v[dep[p]].push_back((Node){1, p});
f[0][p] = fa;
for (int i = 1; i <= 19; i ++ ) {
f[i][p] = f[i - 1][f[i - 1][p]];
}
for (int i = head[p]; i; i = nxt[i]) {
if (to[i] != fa) {
dep[to[i]] = dep[p] + 1;
dfs(to[i], p);
}
}
} void bfs() {
while (!q.empty())
q.pop();
q.push(1);
int cnt = 0;
while (!q.empty()) {
int x = q.front();
q.pop();
dic[ ++ cnt] = x;
for (int i = head[x]; i; i = nxt[i]) {
if (to[i] != f[0][x]) {
q.push(to[i]);
}
}
}
for (int i = 1; i < n; i ++ ) {
if (dep[dic[i]] == dep[dic[i + 1]]) {
v[dep[dic[i]] - dep[lca(dic[i], dic[i + 1])]].push_back((Node) {dic[i], dic[i + 1]});
}
}
} int find(int x) {
return F[x] == x ? x : F[x] = find(F[x]);
} int main() {
n = rd();
for (int i = 1; i <= n; i ++ ) {
F[i] = i;
S[i] = 1;
}
for (int i = 2; i <= n; i ++ ) {
int x = rd();
add(x, i);
add(i, x);
}
dfs(1, 1);
bfs();
inv[0] = 1;
for (int i = 1; i <= n; i ++ )
inv[i] = qpow(i, mod - 2); // for (int i = 0 ; i <= n; i ++ ) {
// printf("%d ", inv[i]);
// }
// puts(""); int mdl = qpow(2, n);
for (int i = 1; i < n; i ++ ) {
g[i] = (mdl - n - 1 + mod) % mod;
int len = v[i].size();
for (int j = 0; j < len; j ++ ) {
int x = v[i][j].x, y = v[i][j].y;
x = find(x), y = find(y);
if (x != y) {
mdl = (ll)mdl * inv[S[x] + 1] % mod * inv[S[y] + 1] % mod;
F[x] = y; S[y] += S[x];
mdl = (ll)mdl * (S[y] + 1) % mod;
}
}
}
int ans = 0;
for (int i = 1; i < n; i ++ ) {
ans = (ans + (ll)(g[i] - g[i + 1] + mod) % mod * i % mod) % mod;
}
cout << ans << endl ;
return 0;
}
小结:好题好题,这个题的思路行云流水。重点是能否想到把那个,一段区间只能选一个这个事情考虑清楚,从而转变成区间的合并问题,这是关键。