[LuoguP3384]

对于这个树剖吧。。。一开始不打算学的。。。但是貌似挺有用的,于是乎下狠心学了一哈。本蒟蒻初学树剖可能理解不太透彻,如有不对之处还请各位巨佬指正


先来一个树剖的了解[IvanovCraft的博客]

首先我们需要了解一下树链剖分可以用来干什么:

1.可以求树上差分,LCA(时间复杂度会小一些)

2.最重要的是可以修改树上的边权或者点权(比如本题)

剩下的目前本蒟蒻还too vegetable,所以对其理解不够深,许多用法可能以后会更新

Code:

#include <bits/stdc++.h>
#define ls x << 1
#define rs x << 1 | 1
using namespace std;
const int N = 200001;
int n, m, r, p;
int tot, head[N];//邻接链表存边 
struct edge{
    int net, v;
}e[N];
int a[N * 4], v[N * 4];
int cnt, w[N], wt[N], son[N], d[N], siz[N], fa[N], id[N], top[N];
//w[]存读入的点,wt[]存dfs遍历后的点,son[]存每个节点的重儿子,d[]存每个点的深度,size[]存以当前节点为根节点的子树的大小
//fa[]存每个节点的父亲节点,id[]存每个点的dfs序,top[]存当前节点所在重链的根节点
void add(int u, int v) {
    e[++tot].net = head[u];
    e[tot].v = v;
    head[u] = tot;
}
void push(int x, int l) {//更新懒标 
    v[ls] += v[x];
    v[rs] += v[x];
    a[ls] += v[x] * (l - l / 2);//这里上下互换也可以,只是为了将这条链分成两部分。
                                //又考虑到长度不可能全是偶数,所以一个上取整,另一个下取整。 
    a[rs] += v[x] * (l / 2);
    a[ls] %= p;
    a[rs] %= p;
    v[x] = 0;
}
void build(int x, int l, int r) {
    if (l == r) {
        a[x] = wt[l];//wt[]数组存的是dfs遍历后的点 
        a[x] %= p;
        return;
    }
    int mid = (l + r) / 2;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    a[x] = (a[ls] + a[rs]) % p;
}
int query(int x, int l, int r, int L, int R) {//线段树维护和 
    int res = 0;
    if (L <= l && r <= R) {
        res += a[x];
        res %= p;
        return res;
    }
    if (v[x]) {
        push(x, r - l + 1);
    }
    int mid = (l + r) / 2;
    if (L <= mid) {
        res += query(ls, l, mid, L, R);
    }
    if (R > mid) {
        res += query(rs, mid + 1, r, L, R);
    }
    return res % p;
}
void change(int x, int l, int r, int L, int R, int k) {//十分正常的修改 
    if (L <= l && r <= R) {
        v[x] += k;
        a[x] += k * (r - l + 1);
        a[x] %= p;
        return;
    }
    if (v[x]) {
        push(x, r - l + 1);
    }
    int mid = (l + r) / 2;
    if (L <= mid) {
        change(ls, l, mid, L, R, k);
    }
    if(R > mid) {
        change(rs, mid + 1, r, L, R, k);
    }
    a[x] = (a[ls] + a[rs]) % p;
}
void Qchange(int x, int y, int k) {
    k %= p;
    while (top[x] != top[y]) {//当这两个点不在同一条重链上 
        if (d[top[x]] < d[top[y]]) {//调整深度,让x变成更深的那一个
            swap(x, y);
        }
        change(1, 1, n, id[top[x]], id[x], k);//修改 
        x = fa[top[x]];//每次结束后都要跳到父亲那里 
    }
    if (d[x] > d[y]) {//当两节点在一条重链上但不是同一节点时
                      //调整深度 
        swap(x, y);
    }
    change(1, 1, n, id[x], id[y], k);//在最后的距离上进行修改 
}
int Qquery(int x, int y) {
    int ans = 0;
    while (top[x] != top[y]) {//当这两个点不在同一条重链上 
        if (d[top[x]] < d[top[y]]) {//调整深度,让x变成更深的那一个
            swap(x, y);
        }
        ans += query(1, 1, n, id[top[x]], id[x]);//累加答案 
        ans %= p;
        x = fa[top[x]];//每次结束后都要跳到父亲那里 
    }
    if (d[x] > d[y]) {
        swap(x, y);
    }
    ans += query(1, 1, n, id[x], id[y]);
    return ans % p;
}
void Schange(int x, int k) {
    change(1, 1, n, id[x], id[x] + siz[x] - 1, k);
    //因为要修改以x为根节点的子树,所以边界就是 id[x] 到 id[x] + siz[x] - 1
}
int Squery(int x) {
    return query(1, 1, n, id[x], id[x] + siz[x] - 1);//同上 
}
void dfs1(int x, int f, int deep) {//核心函数1
    d[x] = deep;//深度 
    fa[x] = f;//父亲节点
    siz[x] = 1;//当前节点大小为1 
    int maxson = -1;
    for (int i = head[x]; i; i = e[i].net) {//枚举所有出边 
        int to = e[i].v;
        if (to == f) {//肯定不能回到自己父亲吧~ 
            continue;
        }
        dfs1(to, x, deep + 1);//继续dfs 
        siz[x] += siz[to];//大小累加 
        if (siz[to] > maxson) {//跟新重儿子 
            son[x] = to;
            maxson = siz[to];
        }
    }
}
void dfs2(int x, int topf) {//核心函数2 
    id[x] = ++cnt;//按照重链dfs来储存新的dfs序 
    wt[cnt] = w[x];//储存下来 
    top[x] = topf;//因为在同一重链上,直接赋值 
    if (!son[x]) {//如果当前节点没有重儿子就停止 
        return;
    }
    dfs2(son[x], topf);//优先重儿子进行深搜 
    for (int i = head[x]; i; i = e[i].net) {
        int to = e[i].v;
        if (to == fa[x] || to == son[x]) {//如果是自己的父亲或者是重儿子就跳过 
            continue;
        }
        dfs2(to, to);//因为轻儿子不在重链上,所以它的top就是自己 
    }
}
int main() {
    scanf("%d%d%d%d", &n, &m, &r, &p);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &w[i]);
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v);
        add(v, u);
    }
    dfs1(r, 0, 1);
    dfs2(r, r);
    build(1, 1, n);
    for (int i = 1; i <= m; i++) {
        int o, x, y, z;
        scanf("%d", &o);
        if (o == 1) {
            scanf("%d%d%d", &x, &y, &z);
            Qchange(x, y, z);
        } else if (o == 2) {
            scanf("%d%d", &x, &y);
            printf("%d\n", Qquery(x, y));
        } else if (o == 3) {
            scanf("%d%d", &x, &z);
            Schange(x, z);
        } else {
            scanf("%d", &x);
            printf("%d\n", Squery(x));
        }
    }
    return 0;
}
01-04 22:45