题目链接:https://www.luogu.org/problem/P2486
首先这是一道树链剖分+线段树的题。

线段树部分

线段树区间操作,每一个线段对应的点包含三个信息:

  • \(l\):表示这个区间最左边的点的数值;
  • \(r\):表示这个区间最右边的点的数值;
  • \(cnt\):表示这个区间有多少个数值段。

合并的时候:

  • 根节点的 \(l\) 值等于左儿子节点的 \(l\) 值;
  • 根节点的 \(r\) 值等于右儿子节点的 \(r\) 值;
  • 根节点的 \(cnt\) 值取决于左儿子的 \(r\) 值和右儿子的 \(l\) 值是否相等,
    1. 如果相等,则为:左儿子的 \(cnt\) + 右儿子的 \(cnt\) - 1
    2. 否则,为:左儿子的 \(cnt\) + 右儿子的 \(cnt\)

更新的时候,如果节点表示的这一段区间全在区间范围内,
则将节点的 \(l\)\(r\) 都置为将要更新的值,并将节点的 \(cnt\) 置为 1。

因为涉及区间操作,需要用到延迟操作。

树链剖分部分

此部分支持两种操作:

  • 更新:这部分比较好实现;
  • 查询:这部分需要你记录树链查询的时候的每一条边的信息,然后将这些信息进行汇总,处理起来稍有一些繁琐。

实现代码如下:

#include <bits/stdc++.h>
using namespace std;
#define INF (1<<29)
const int maxn = 100010;
int fa[maxn],
    dep[maxn],
    size[maxn],
    son[maxn],
    top[maxn],
    seg[maxn], seg_cnt,
    rev[maxn];
vector<int> g[maxn];
void dfs1(int u, int p) {
    size[u] = 1;
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = (*it);
        if (v == p) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v, u);
        size[u] += size[v];
        if (size[v] >size[son[u]]) son[u] = v;
    }
}
void dfs2(int u, int tp) {
    seg[u] = ++seg_cnt;
    rev[seg_cnt] = u;
    top[u] = tp;
    if (son[u]) dfs2(son[u], tp);
    for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
        int v = (*it);
        if (v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}
struct Node {
    int l, r, cnt;
    Node () {}
    Node (int _l, int _r, int _cnt) { l = _l; r = _r; cnt = _cnt; }
    Node reverse() { return Node(r, l, cnt); }
} tree[maxn<<2];
int n, lazy[maxn<<2], init_color[maxn];
#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1
void push_up(int rt) {
    tree[rt].l = tree[rt<<1].l;
    tree[rt].r = tree[rt<<1|1].r;
    tree[rt].cnt = tree[rt<<1].cnt + tree[rt<<1|1].cnt - (tree[rt<<1].r == tree[rt<<1|1].l ? 1 : 0);
}
void push_down(int rt) {
    if (lazy[rt]) {
        lazy[rt<<1] = lazy[rt<<1|1] = lazy[rt];
        tree[rt<<1].cnt = tree[rt<<1|1].cnt = 1;
        tree[rt<<1].l = tree[rt<<1].r = tree[rt<<1|1].l = tree[rt<<1|1].r = lazy[rt];
        lazy[rt] = 0;
    }
}
void build(int l, int r, int rt) {
    if (l == r) {
        tree[rt] = Node(init_color[rev[l]], init_color[rev[l]], 1);
        return;
    }
    int mid = (l + r) / 2;
    build(lson);
    build(rson);
    push_up(rt);
}
void update(int L, int R, int v, int l, int r, int rt) {
    if (L <= l && r <= R) {
        tree[rt] = Node(v, v, 1);
        lazy[rt] = v;
        return;
    }
    push_down(rt);
    int mid = (l + r) / 2;
    if (L <= mid) update(L, R, v, lson);
    if (R > mid) update(L, R, v, rson);
    push_up(rt);
}
Node query(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R) return tree[rt];
    push_down(rt);
    int mid = (l + r) / 2;
    if (L > mid) return query(L, R, rson);
    else if (R <= mid) return query(L, R, lson);
    else {
        Node a = query(L, R, lson);
        Node b = query(L, R, rson);
        return Node(a.l, b.r, a.cnt + b.cnt - (a.r == b.l ? 1 : 0));
    }
}

void chain_update(int u, int v, int val) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        update(seg[top[u]], seg[u], val, 1, n, 1);
        u = fa[top[u]];
    }
    if (dep[u] < dep[v]) swap(u, v);
    update(seg[v], seg[u], val, 1, n, 1);
}
vector<Node> res1, res2, res;
void chain_query(int u, int v) {
    res1.clear();
    res2.clear();
    res.clear();
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]]) {
            res1.push_back(query(seg[top[u]], seg[u], 1, n, 1));
            u = fa[top[u]];
        }
        else {
            res2.push_back(query(seg[top[v]], seg[v], 1, n, 1));
            v = fa[top[v]];
        }
    }
    if (dep[u] > dep[v]) res1.push_back(query(seg[v], seg[u], 1, n, 1));
    else res2.push_back(query(seg[u], seg[v], 1, n, 1));
    int sz = res1.size();
    for (int i = 0; i < sz; i ++) res.push_back(res1[i].reverse());
    sz = res2.size();
    for (int i = sz-1; i >= 0; i --) res.push_back(res2[i]);
    Node tmp = res[0];
    sz = res.size();
    for (int i = 1; i < sz; i ++) {
        int delta = (tmp.r == res[i].l);
        tmp.cnt += res[i].cnt - delta;
        tmp.r = res[i].r;
    }
    cout << tmp.cnt << endl;
}

int m, a, b, c;
char op[2];
int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i ++) cin >> init_color[i];
    for (int i = 1; i < n; i ++) {
        int u, v;
        cin >> u >>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dep[1] = fa[1] = 1;
    dfs1(1, -1);
    dfs2(1, 1);
    build(1, n, 1);
    while (m --) {
        cin >> op;
        if (op[0] == 'C') {
            cin >> a >> b >> c;
            chain_update(a, b, c);
        }
        else {
            cin >> a >> b;
            chain_query(a, b);
        }
    }
    return 0;
}
12-23 11:44