csu1811

题意

给定一棵树,每个节点有颜色,每次仅删掉第 \(i\) 条边 \((a_i, b_i)\) ,得到两颗树,问两颗树节点的颜色集合的交集。

分析

转化一下,即所求答案为每次删掉 \(u\) 和 \(u\) 的父亲节点所连的边后形成的两颗子树的颜色集合的交集。

那么我们要求的其实和 \(u\) 的子树有关。子树的状态(颜色数量信息)是可以复用的。

可以套用 树上启发式合并 ,固定 1 为根节点,从上往下搜,每次保留子节点中节点最多的那颗子树的状态(颜色数量信息),也就是在计算当前节点所在子树的节点颜色的时候跳过这个子节点(因为前面保留了)。复杂度优化到 \(O(nlogn)\)。

code

#include<cstdio>
#include<cstring>
#include<map>
#include<algorithm>
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 10;
int n;
int fa[MAXN], son[MAXN], dep[MAXN], siz[MAXN];
int col[MAXN];
int cnt, head[MAXN];
struct Edge {
int to, next;
} e[MAXN << 1];
void addedge(int u, int v) {
e[cnt].to = v;
e[cnt].next = head[u];
head[u] = cnt++;
}
void dfs(int u) {
siz[u] = 1;
son[u] = 0;
for(int i = head[u]; ~i; i = e[i].next) {
if(e[i].to != fa[u]) {
fa[e[i].to] = u;
dep[e[i].to] = dep[u] + 1;
dfs(e[i].to);
if(siz[e[i].to] > siz[son[u]]) son[u] = e[i].to;
siz[u] += siz[e[i].to];
}
}
}
int COL[MAXN]; // 表示所有颜色的数量
int vis[MAXN];
int same; // 当前子树的颜色和另一个子树的颜色集合的交集
int C[MAXN]; // 当前子树某个颜色的数量
int ans[MAXN]; // ans[u]: 表示删掉 u 和它父亲所连的边后形成的两颗子树的答案
void update(int u, int c) {
C[col[u]] += c;
if(c > 0 && COL[col[u]] > 1) {
if(C[col[u]] == 1) same++;
else if(C[col[u]] == 0 || C[col[u]] == COL[col[u]]) same--;
}
for(int i = head[u]; ~i; i = e[i].next) {
if(e[i].to != fa[u] && !vis[e[i].to]) update(e[i].to, c);
}
}
void dfs1(int u, int flg) {
for(int i = head[u]; ~i; i = e[i].next) {
if(e[i].to != fa[u] && e[i].to != son[u]) dfs1(e[i].to, 1);
}
if(son[u]) {
dfs1(son[u], 0);
vis[son[u]] = 1;
}
update(u, 1);
ans[u] = same;
if(son[u]) vis[son[u]] = 0;
if(flg) {
update(u, -1);
same = 0;
}
}
typedef pair<int, int> P;
map<P, int> mp;
int res[MAXN];
int main() {
while(~scanf("%d", &n)) {
mp.clear();
cnt = 0;
dep[1] = 1;
fa[1] = 1;
memset(head, -1, sizeof head);
memset(COL, 0, sizeof COL);
memset(C, 0, sizeof C);
memset(vis, 0, sizeof vis);
same = 0;
for(int i = 1; i <= n; i++) {
scanf("%d", &col[i]);
COL[col[i]]++;
}
for(int i = 1; i < n; i++) {
int x, y;
scanf("%d%d", &x, &y);
mp[P(x, y)] = mp[P(y, x)] = i;
addedge(x, y);
addedge(y, x);
}
dfs(1);
dfs1(1, -1);
for(int i = 2; i <= n; i++) {
res[mp[P(i, fa[i])]] = ans[i];
}
for(int i = 1; i < n; i++) {
printf("%d\n", res[i]);
}
}
return 0;
}
05-11 20:44