题目链接:https://www.luogu.org/problem/P2486
首先这是一道树链剖分+线段树的题。
线段树部分
线段树区间操作,每一个线段对应的点包含三个信息:
- \(l\):表示这个区间最左边的点的数值;
- \(r\):表示这个区间最右边的点的数值;
- \(cnt\):表示这个区间有多少个数值段。
合并的时候:
- 根节点的 \(l\) 值等于左儿子节点的 \(l\) 值;
- 根节点的 \(r\) 值等于右儿子节点的 \(r\) 值;
- 根节点的 \(cnt\) 值取决于左儿子的 \(r\) 值和右儿子的 \(l\) 值是否相等,
- 如果相等,则为:左儿子的 \(cnt\) + 右儿子的 \(cnt\) - 1
- 否则,为:左儿子的 \(cnt\) + 右儿子的 \(cnt\)
更新的时候,如果节点表示的这一段区间全在区间范围内,
则将节点的 \(l\) 和 \(r\) 都置为将要更新的值,并将节点的 \(cnt\) 置为 1。
因为涉及区间操作,需要用到延迟操作。
树链剖分部分
此部分支持两种操作:
- 更新:这部分比较好实现;
- 查询:这部分需要你记录树链查询的时候的每一条边的信息,然后将这些信息进行汇总,处理起来稍有一些繁琐。
实现代码如下:
#include <bits/stdc++.h>
using namespace std;
#define INF (1<<29)
const int maxn = 100010;
int fa[maxn],
dep[maxn],
size[maxn],
son[maxn],
top[maxn],
seg[maxn], seg_cnt,
rev[maxn];
vector<int> g[maxn];
void dfs1(int u, int p) {
size[u] = 1;
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = (*it);
if (v == p) continue;
fa[v] = u;
dep[v] = dep[u] + 1;
dfs1(v, u);
size[u] += size[v];
if (size[v] >size[son[u]]) son[u] = v;
}
}
void dfs2(int u, int tp) {
seg[u] = ++seg_cnt;
rev[seg_cnt] = u;
top[u] = tp;
if (son[u]) dfs2(son[u], tp);
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = (*it);
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
struct Node {
int l, r, cnt;
Node () {}
Node (int _l, int _r, int _cnt) { l = _l; r = _r; cnt = _cnt; }
Node reverse() { return Node(r, l, cnt); }
} tree[maxn<<2];
int n, lazy[maxn<<2], init_color[maxn];
#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1
void push_up(int rt) {
tree[rt].l = tree[rt<<1].l;
tree[rt].r = tree[rt<<1|1].r;
tree[rt].cnt = tree[rt<<1].cnt + tree[rt<<1|1].cnt - (tree[rt<<1].r == tree[rt<<1|1].l ? 1 : 0);
}
void push_down(int rt) {
if (lazy[rt]) {
lazy[rt<<1] = lazy[rt<<1|1] = lazy[rt];
tree[rt<<1].cnt = tree[rt<<1|1].cnt = 1;
tree[rt<<1].l = tree[rt<<1].r = tree[rt<<1|1].l = tree[rt<<1|1].r = lazy[rt];
lazy[rt] = 0;
}
}
void build(int l, int r, int rt) {
if (l == r) {
tree[rt] = Node(init_color[rev[l]], init_color[rev[l]], 1);
return;
}
int mid = (l + r) / 2;
build(lson);
build(rson);
push_up(rt);
}
void update(int L, int R, int v, int l, int r, int rt) {
if (L <= l && r <= R) {
tree[rt] = Node(v, v, 1);
lazy[rt] = v;
return;
}
push_down(rt);
int mid = (l + r) / 2;
if (L <= mid) update(L, R, v, lson);
if (R > mid) update(L, R, v, rson);
push_up(rt);
}
Node query(int L, int R, int l, int r, int rt) {
if (L <= l && r <= R) return tree[rt];
push_down(rt);
int mid = (l + r) / 2;
if (L > mid) return query(L, R, rson);
else if (R <= mid) return query(L, R, lson);
else {
Node a = query(L, R, lson);
Node b = query(L, R, rson);
return Node(a.l, b.r, a.cnt + b.cnt - (a.r == b.l ? 1 : 0));
}
}
void chain_update(int u, int v, int val) {
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
update(seg[top[u]], seg[u], val, 1, n, 1);
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
update(seg[v], seg[u], val, 1, n, 1);
}
vector<Node> res1, res2, res;
void chain_query(int u, int v) {
res1.clear();
res2.clear();
res.clear();
while (top[u] != top[v]) {
if (dep[top[u]] > dep[top[v]]) {
res1.push_back(query(seg[top[u]], seg[u], 1, n, 1));
u = fa[top[u]];
}
else {
res2.push_back(query(seg[top[v]], seg[v], 1, n, 1));
v = fa[top[v]];
}
}
if (dep[u] > dep[v]) res1.push_back(query(seg[v], seg[u], 1, n, 1));
else res2.push_back(query(seg[u], seg[v], 1, n, 1));
int sz = res1.size();
for (int i = 0; i < sz; i ++) res.push_back(res1[i].reverse());
sz = res2.size();
for (int i = sz-1; i >= 0; i --) res.push_back(res2[i]);
Node tmp = res[0];
sz = res.size();
for (int i = 1; i < sz; i ++) {
int delta = (tmp.r == res[i].l);
tmp.cnt += res[i].cnt - delta;
tmp.r = res[i].r;
}
cout << tmp.cnt << endl;
}
int m, a, b, c;
char op[2];
int main() {
cin >> n >> m;
for (int i = 1; i <= n; i ++) cin >> init_color[i];
for (int i = 1; i < n; i ++) {
int u, v;
cin >> u >>v;
g[u].push_back(v);
g[v].push_back(u);
}
dep[1] = fa[1] = 1;
dfs1(1, -1);
dfs2(1, 1);
build(1, n, 1);
while (m --) {
cin >> op;
if (op[0] == 'C') {
cin >> a >> b >> c;
chain_update(a, b, c);
}
else {
cin >> a >> b;
chain_query(a, b);
}
}
return 0;
}