题目链接:https://www.luogu.org/problemnew/show/P2590

我想学树剖QAQ

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
// Upper bound on the number of tree nodes.
const int maxn = 31000;
// Per-node data filled in by dfs1 / dfs2 (nodes are 1-based, index 0 is a sentinel):
//   fa[u]   parent of u
//   dep[u]  depth of u (dep of the sentinel 0 stays 0)
//   size[u] number of nodes in u's subtree
//   son[u]  child of u with the largest subtree (0 if u has no child)
//   top[u]  topmost node of the heavy chain containing u
//   seg[u]  segment-tree position of u; seg[0] is used as the position counter
//   rev[p]  tree node stored at segment-tree position p
// NOTE: with `using namespace std;` active, the unqualified name `size` can clash
// with std::size when compiling as C++17 or later; `::size` always names this array.
int fa[maxn], dep[maxn], size[maxn], son[maxn], top[maxn], seg[maxn], rev[maxn<<2];
// sum[k] / mx[k]: sum / maximum over the range owned by segment-tree node k.
// Both are indexed by segment-tree node id, so both need 4 * maxn entries
// (mx was previously sized maxn, which the tree overruns for large n).
// num[u]: initial weight of tree node u.
int sum[maxn<<2], num[maxn], mx[maxn<<2];
// Edge record of the linked adjacency list: `to` is the neighbour,
// `next` is the index of the next edge leaving the same node (0 ends the list).
struct edge{
    int to, next;
}e[maxn<<2];
// head[u]: index of the first edge leaving u (0 if none); cnt: number of edges stored.
int head[maxn<<2], cnt;
// summ / maxx: accumulators filled by query(); n: node count; m: operation count.
int summ, maxx, n, m;
// Range query on the segment tree.
// Adds the sum of every position in [L, R] to the global `summ` and folds the
// maximum over [L, R] into the global `maxx`. The accumulators are not reset
// here; whoever starts a query has to initialise them first.
//   k     id of the current segment-tree node
//   l, r  range of positions covered by node k
//   L, R  queried range of positions
void query(int k, int l, int r, int L, int R)
{
    // Node range and query range are disjoint: nothing to contribute.
    if (r < L || l > R) return;

    const bool fullyCovered = (L <= l) && (r <= R);
    if (!fullyCovered)
    {
        // Partial overlap: descend into whichever halves touch [L, R].
        const int mid = (l + r) >> 1;
        if (L <= mid) query(k << 1, l, mid, L, R);
        if (R > mid) query((k << 1) + 1, mid + 1, r, L, R);
        return;
    }

    // Node lies entirely inside the query: use its aggregated values.
    summ += sum[k];
    maxx = std::max(maxx, mx[k]);
}
// Point update on the segment tree: stores `val` at position `pos` and
// refreshes sum / mx on every node along the path back to the root.
//   k     id of the current segment-tree node
//   l, r  range of positions covered by node k
//   val   new value
//   pos   position to overwrite
void change(int k, int l, int r, int val, int pos)
{
    // pos is not inside this node's range: leave the subtree untouched.
    if (pos < l || pos > r) return;

    // A leaf that passed the range check above is exactly position pos.
    if (l == r)
    {
        sum[k] = val;
        mx[k] = val;
        return;
    }

    const int mid = (l + r) >> 1;
    const int left = k << 1;
    const int right = left + 1;

    // pos lies in exactly one of the two halves.
    if (pos <= mid)
        change(left, l, mid, val, pos);
    else
        change(right, mid + 1, r, val, pos);

    // Recompute this node's aggregates from its children.
    sum[k] = sum[left] + sum[right];
    mx[k] = std::max(mx[left], mx[right]);
}
void dfs1(int u, int f)
{
int v;
size[u] = 1;
fa[u] = f;
dep[u] = dep[f]+1;
for(int i = head[u]; v = e[i].to, i; i = e[i].next)
{
if(v != f)
{
dfs1(v,u);
size[u] += size[v];
if(size[v] > size[son[u]])
son[u] = v;
}
}
}
// Second DFS pass of the heavy-light decomposition.
// Assigns segment-tree positions (seg / rev) and chain tops (top) to all
// descendants of u. The heavy child is numbered first, so every heavy chain
// occupies consecutive positions. seg[0] is the running position counter.
// u itself must already have seg[u] and top[u] set; the parameter f is unused.
void dfs2(int u, int f)
{
    const int heavy = son[u];
    if (heavy)
    {
        // The heavy child continues u's chain.
        top[heavy] = top[u];
        seg[heavy] = ++seg[0];
        rev[seg[0]] = heavy;
        dfs2(heavy, u);
    }

    for (int i = head[u]; i; i = e[i].next)
    {
        const int v = e[i].to;
        // A non-zero top marks nodes that were already placed
        // (the parent and the heavy child).
        if (top[v]) continue;

        // Each light child starts a new chain with itself as the top.
        seg[v] = ++seg[0];
        rev[seg[0]] = v;
        top[v] = v;
        dfs2(v, u);
    }
}
// Builds the segment tree for positions [l, r] under node k.
// A leaf at position p takes the weight of the tree node mapped to p
// (num[rev[p]]); inner nodes store the sum and maximum of their children.
void build(int k, int l, int r)
{
    if (l == r)
    {
        const int weight = num[rev[l]];
        sum[k] = weight;
        mx[k] = weight;
        return;
    }

    const int mid = (l + r) >> 1;
    const int left = k << 1;
    const int right = left + 1;

    build(left, l, mid);
    build(right, mid + 1, r);

    sum[k] = sum[left] + sum[right];
    mx[k] = std::max(mx[left], mx[right]);
}
// Reads the next decimal integer from stdin.
// Skips every character up to the first digit; if a '-' was seen while
// skipping, the result is negated. The first non-digit character after the
// number is consumed as well.
// Returns 0 if the input ends before any digit is found.
inline int read()
{
    // getchar() returns an int so that EOF can be told apart from real
    // characters; storing it in a char made the skip loop spin forever at
    // end of input.
    int c = getchar();
    int k = 1;
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-') k = -1;
        c = getchar();
    }
    if (c == EOF) return 0;

    int res = c - '0';
    while ((c = getchar()) >= '0' && c <= '9')
        res = res * 10 + (c - '0');
    return res * k;
}
// Appends the directed edge u -> v to the adjacency list.
// The new edge gets the next free index (cnt is incremented first, so edge
// ids start at 1) and is linked in front of u's current list.
inline void add(int u, int v)
{
    const int id = ++cnt;
    e[id].to = v;
    e[id].next = head[u];
    head[u] = id;
}
// Adds an undirected edge between u and v by storing it once per direction.
inline void insert(int u, int v)
{
// u -> v
add(u, v);
// v -> u
add(v, u);
}
inline void ask(int u, int v)
{
int fu = top[u], fv = top[v];
while(fu != fv)
{
if(dep[fu]<dep[fv]) swap(u,v), swap(fu,fv);
query(1,1,seg[0],seg[fu],seg[u]);
u = fa[fu], fu = top[u];
}
if(dep[u]>dep[v]) swap(u,v);
query(1,1,seg[0],seg[u],seg[v]);
}
// Program entry point.
// Input: n, then n-1 edges, then n node weights, then m, then m operations.
// An operation is a command word followed by two integers u and v:
//   a word starting with 'C'  sets the weight of node u to v;
//   any other word            queries the path u..v and prints the maximum if
//                             the word's second letter is 'M', else the sum.
int main()
{
    n = read();
    for (int i = 1; i < n; i++)
    {
        // Read the endpoints in separate statements: the evaluation order of
        // function arguments is unspecified, so insert(read(), read()) does
        // not guarantee which endpoint is read first.
        const int a = read();
        const int b = read();
        insert(a, b);
    }
    for (int i = 1; i <= n; i++)
        num[i] = read();

    // Decompose the tree rooted at node 1; the root takes position 1 and is
    // the top of its own chain.
    dfs1(1, 0);
    seg[0] = seg[1] = rev[1] = top[1] = 1;
    dfs2(1, 0);
    build(1, 1, seg[0]);

    m = read();
    // The command word is stored from index 1 on; "%8s" keeps it (plus the
    // terminating NUL) inside the 9 bytes available after opt[0].
    char opt[10] = {0};
    for (int i = 1; i <= m; i++)
    {
        // Stop cleanly if the input ends early instead of acting on a stale
        // or uninitialised command.
        if (scanf("%8s", opt + 1) != 1)
            break;
        const int u = read();
        const int v = read();

        if (opt[1] == 'C')
        {
            change(1, 1, seg[0], v, seg[u]);
        }
        else
        {
            summ = 0;
            maxx = -0x7fffffff;
            ask(u, v);
            if (opt[2] == 'M')
                printf("%d\n", maxx);
            else
                printf("%d\n", summ);
        }
    }
    return 0;
}
05-11 19:45