题目链接:HDU-6547 Tree
题意
wls 有一棵树,树上每个节点都有一个值 $a_i$,现在有 2 种操作:
1. 将一条链上的所有节点的值开根号向下取整;
2. 求一条链上值的和;
链的定义是两点之间的最短路。
思路
树链剖分裸题,区间开根号可用线段树做:$10^9$ 范围内的数至多开 5 次根号(向下取整)就会变为 1($10^9 \to 31622 \to 177 \to 13 \to 3 \to 1$),因此在线段树上维护区间最大值,若某区间最大值不超过 1,则该区间无需再往下更新。
树链剖分传送门:https://www.cnblogs.com/kangkang-/p/8486150.html
代码实现
#include <stdio.h>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <utility>

// Heavy-light decomposition over a rooted tree, backed by a segment tree that
// stores, for every range of decomposition positions, the sum and the maximum
// of the node values. Two path operations are supported:
//   opt 0: replace every value on the path u..v by floor(sqrt(value));
//   opt 1: print the sum of the values on the path u..v.

#define REP(i, a, b) for (int i = a; i <= b; i++)
using namespace std;
typedef long long LL;

const double esp = 1e-8;   // slack added to the floating-point sqrt estimate
const int MAXN = 110000;   // capacity of all per-node arrays (nodes are 1-based)

// Adjacency list stored as a forward-linked edge array; index 0 means "no edge".
struct Node {
    int to, next;
} edg[MAXN << 1];

// Segment tree node covering positions [left, right].
struct segmentTree {
    int left, right;
    LL sum, maxx;
} tree[MAXN << 2];

// head: first edge of a node; siz: subtree size; top: head of the heavy chain;
// hson: heavy child (0 = none); dep: depth; fa: parent;
// id: position of a node in the decomposition order; rnk: inverse of id.
int head[MAXN], siz[MAXN], top[MAXN], hson[MAXN], dep[MAXN], fa[MAXN], id[MAXN], rnk[MAXN];
int N, M, R, A[MAXN], idx = 0, dfs_cnt = 0;

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if the input ends before any digit is found; the original looped
// forever in that case because the result of getchar() was stored in a char.
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// Appends the directed edge u -> v to the adjacency list of u.
inline void adde(int u, int v) {
    edg[++idx].to = v;
    edg[idx].next = head[u];
    head[u] = idx;
}

// First pass: fills dep, fa, siz and picks the heavy child of every node.
void dfs1(int u, int father, int depth) {
    dep[u] = depth;
    fa[u] = father;
    siz[u] = 1;
    for (int i = head[u]; i; i = edg[i].next) {
        int v = edg[i].to;
        if (v == fa[u]) continue;
        dfs1(v, u, depth + 1);
        siz[u] += siz[v];
        // hson[] starts at 0 and siz[0] == 0, so the first child always wins
        // this comparison; no separate "unset" check is needed.
        if (siz[v] > siz[hson[u]]) hson[u] = v;
    }
}

// Second pass: assigns decomposition positions, visiting the heavy child first
// so every heavy chain occupies consecutive positions; t is the chain head.
void dfs2(int u, int t) {
    id[u] = ++dfs_cnt;
    rnk[dfs_cnt] = u;
    top[u] = t;
    if (!hson[u]) return;  // leaf: no children to visit
    dfs2(hson[u], t);
    for (int i = head[u]; i; i = edg[i].next) {
        int v = edg[i].to;
        if (v != hson[u] && v != fa[u]) dfs2(v, v);  // light child starts a new chain
    }
}

// Recomputes sum and maxx of internal node i from its two children.
static inline void pull_up(int i) {
    tree[i].sum = tree[i << 1].sum + tree[i << 1 | 1].sum;
    tree[i].maxx = max(tree[i << 1].maxx, tree[i << 1 | 1].maxx);
}

// Returns floor(sqrt(x)) for x >= 0. The floating-point estimate is corrected
// with integer arithmetic so rounding in sqrt() cannot produce an off-by-one.
static LL floor_sqrt(LL x) {
    LL r = static_cast<LL>(sqrt(static_cast<double>(x)) + esp);
    while (r * r > x) --r;
    while ((r + 1) * (r + 1) <= x) ++r;
    return r;
}

// Builds the segment tree node i over positions [l, r] from A[rnk[.]].
void buildtree(int i, int l, int r) {
    tree[i].left = l;
    tree[i].right = r;
    if (l == r) {
        tree[i].sum = tree[i].maxx = A[rnk[l]];
        return;
    }
    int mid = (l + r) >> 1;
    buildtree(i << 1, l, mid);
    buildtree(i << 1 | 1, mid + 1, r);
    pull_up(i);
}

// Replaces every value at positions [x, y] inside node i by its floored
// square root.
void update(int i, int x, int y) {
    if (tree[i].left > y || tree[i].right < x) return;
    // Values 0 and 1 are fixed points of floor(sqrt(.)), so a range whose
    // maximum is <= 1 cannot change and is skipped entirely.
    if (tree[i].maxx <= 1) return;
    if (tree[i].left == tree[i].right) {
        tree[i].sum = tree[i].maxx = floor_sqrt(tree[i].maxx);
        return;
    }
    update(i << 1, x, y);
    update(i << 1 | 1, x, y);
    // Refresh BOTH aggregates: the original only refreshed sum, which left
    // maxx stale and made the pruning test above ineffective.
    pull_up(i);
}

// Returns the sum of the values at positions [x, y] inside node i.
LL query(int i, int x, int y) {
    if (tree[i].left > y || tree[i].right < x) return 0;
    if (x <= tree[i].left && tree[i].right <= y) return tree[i].sum;
    return query(i << 1, x, y) + query(i << 1 | 1, x, y);
}

// Applies the square-root update to every node on the tree path u..v.
void update_path(int u, int v) {
    int tu = top[u], tv = top[v];
    while (tu != tv) {
        // Always lift the endpoint whose chain head is deeper.
        if (dep[tu] < dep[tv]) swap(u, v), swap(tu, tv);
        update(1, id[tu], id[u]);
        u = fa[tu], tu = top[u];
    }
    // Both endpoints are now on the same chain; v becomes the shallower one.
    if (dep[u] < dep[v]) swap(u, v);
    update(1, id[v], id[u]);
}

// Returns the sum of the values of every node on the tree path u..v.
LL query_path(int u, int v) {
    LL res = 0;
    int tu = top[u], tv = top[v];
    while (tu != tv) {
        if (dep[tu] < dep[tv]) swap(u, v), swap(tu, tv);
        res += query(1, id[tu], id[u]);
        u = fa[tu], tu = top[u];
    }
    if (dep[u] < dep[v]) swap(u, v);
    return res + query(1, id[v], id[u]);
}

int main() {
    N = read(), M = read(), R = 1;
    // Empty/truncated input would send buildtree(1, 1, 0) into unbounded
    // recursion, and N >= MAXN would overrun the fixed-size arrays.
    if (N < 1 || N >= MAXN) return 0;
    REP(i, 1, N) A[i] = read();
    REP(i, 2, N) {
        int u = read(), v = read();
        adde(u, v);
        adde(v, u);
    }
    dfs1(R, 0, 1);
    dfs2(R, R);
    buildtree(1, 1, N);
    while (M--) {
        int opt = read();
        switch (opt) {
            case 0: {
                int x = read(), y = read();
                update_path(x, y);
                break;
            }
            case 1: {
                int x = read(), y = read();
                printf("%lld\n", query_path(x, y));
                break;
            }
        }
    }
    return 0;
}