题目链接:https://www.luogu.org/problem/P3313
这道题目就是树链剖分+线段树动态开点。
首先我们来理解一下这道题目的线段树部分。
如果这不是树链上面的操作,而是区间上面的操作,那么我们可以这么解决。
它最多有 \(n = 10^5\) 种颜色,而我们需要每种颜色动态去建树。
那么,如果按照传统方法去建一棵线段树,每一棵树都需要 \(n \times 4\) 个节点,那么总的节点数就会达到 \(n^2 \times 4 = 4 \times 10^{10}\) 数量级,是不能承受的。
那么,可以考虑动态建树,即一开始我创建 \(n\) 棵线段树,但是只创建 \(n\) 个根节点, \(root[i]\) 表示颜色为 \(i\) 的线段树的根节点,我们知道只有更新操作才会涉及对节点的更新。
而这里都是单点更新,这就意味着每次更新只会最多扩展 \(\lceil \log_2n \rceil\) 个节点,那么就算所有的操作都是更新,总共也只会扩展 \(q \times \log_2n\) 个节点,而 \(q \le 10^5\) 所以总量是可以承受的。
然后就按照这种思路来给线段树动态开点,即可解决这个问题(然而我一开始用数组处理的时候有bug但是一直没有找到bug所在 ,所以就用指针的形式来解决了这个问题)。
然后你就会发现这就是树链剖分+上述的线段树处理。
然后这道题目就变得很简单。
实现代码如下:
#include <bits/stdc++.h>
using namespace std;
#define INF (1<<29)
const int maxn = 100010;
int fa[maxn],
dep[maxn],
size[maxn],
son[maxn],
top[maxn],
seg[maxn], seg_cnt,
rev[maxn];
vector<int> g[maxn];
void dfs1(int u, int p) {
size[u] = 1;
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = (*it);
if (v == p) continue;
fa[v] = u;
dep[v] = dep[u] + 1;
dfs1(v, u);
size[u] += size[v];
if (size[v] >size[son[u]]) son[u] = v;
}
}
void dfs2(int u, int tp) {
seg[u] = ++seg_cnt;
rev[seg_cnt] = u;
top[u] = tp;
if (son[u]) dfs2(son[u], tp);
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = (*it);
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
struct Tnode {
int l, r, sumw, maxw;
Tnode *lson, *rson;
Tnode(int _l, int _r, int _sumw, int _maxw) { l = _l; r = _r; sumw = _sumw; maxw = _maxw; lson = rson = NULL; }
} *root[maxn];
int n, q, w[maxn], c[maxn];
void push_up(Tnode *rt) {
rt->sumw = rt->maxw = 0;
if (rt->lson != NULL) {
rt->sumw += rt->lson->sumw;
rt->maxw = max(rt->maxw, rt->lson->maxw);
}
if (rt->rson != NULL) {
rt->sumw += rt->rson->sumw;
rt->maxw = max(rt->maxw, rt->rson->maxw);
}
}
void update(int p, int v, Tnode *rt) {
int l = rt->l, r = rt->r, mid = (rt->l + rt->r) / 2;
if (l == r) {
rt->sumw = rt->maxw = v;
return;
}
if (p <= mid) {
if (rt->lson == NULL) rt->lson = new Tnode(l, mid, 0, 0);
update(p, v, rt->lson);
}
else {
if (rt->rson == NULL) rt->rson = new Tnode(mid+1, r, 0, 0);
update(p, v, rt->rson);
}
push_up(rt);
}
int query_sum(int L, int R, Tnode *rt) {
int l = rt->l, r = rt->r, mid = (rt->l + rt->r) / 2;
if (L <= l && r <= R) return rt->sumw;
int tmp = 0;
if (L <= mid && rt->lson != NULL) tmp += query_sum(L, R, rt->lson);
if (R > mid && rt->rson != NULL) tmp += query_sum(L, R, rt->rson);
return tmp;
}
int query_max(int L, int R, Tnode *rt) {
int l = rt->l, r = rt->r, mid = (rt->l + rt->r) / 2;
if (L <= l && r <= R) return rt->maxw;
int tmp = 0;
if (L <= mid && rt->lson != NULL) tmp = max(tmp, query_max(L, R, rt->lson));
if (R > mid && rt->rson != NULL) tmp = max(tmp, query_max(L, R, rt->rson));
return tmp;
}
void init() {
for (int i = 1; i < maxn; i ++) root[i] = new Tnode(1, n, 0, 0);
}
int ask_sum(int u, int v) {
int res = 0;
Tnode* rt = root[c[u]];
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
res += query_sum(seg[top[u]], seg[u], rt);
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
res += query_sum(seg[v], seg[u], rt);
return res;
}
int ask_max(int u, int v) {
int res = -INF;
Tnode* rt = root[c[u]];
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
res = max(res, query_max(seg[top[u]], seg[u], rt));
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
res = max(res, query_max(seg[v], seg[u], rt));
return res;
}
int x, y;
string op;
int main() {
cin >> n >> q;
for (int i = 1; i <= n; i ++) {
cin >> w[i] >> c[i];
}
for (int i = 1; i < n; i ++) {
int x, y;
cin >> x >> y;
g[x].push_back(y);
g[y].push_back(x);
}
dep[1] = fa[1] = 1;
dfs1(1, -1);
dfs2(1, 1);
init();
for (int i = 1; i <= n; i ++) {
update(seg[i], w[i], root[c[i]]);
}
while (q --) {
cin >> op >> x >> y;
if (op == "CC") {
update(seg[x], 0, root[c[x]]);
c[x] = y;
update(seg[x], w[x], root[c[x]]);
}
else if (op == "CW") {
update(seg[x], y, root[c[x]]);
w[x] = y;
}
else if (op == "QS") {
cout << ask_sum(x, y) << endl;
}
else { // QM
cout << ask_max(x, y) << endl;
}
}
return 0;
}