题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=2243

树链剖分的点剖分+线段树。漏了一个小地方,调了一下午...... 还是要细心啊!

结构体里lc表示这个区间的最左端的颜色,rc表示这个区间的最右端的颜色,sum表示这个区间的颜色段数目。回溯合并的时候要注意,左孩子的右端颜色要是等于右孩子左端颜色 sum就要-1。

代码如下:

 #include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int MAXN = 2e5 + ;
struct data {
int next , to;
}edge[MAXN << ];
struct segtree {
int l , r , sum , lazy; //sum表示颜色段数量,lazy表示颜色延迟标记
int lc , rc; //lc表示区间的最左端颜色,rc表示最右端颜色
}T[MAXN << ];
int head[MAXN] , tot;
int son[MAXN] , size[MAXN] , par[MAXN] , dep[MAXN] , cnt;
int top[MAXN] , id[MAXN] , fid[MAXN];
int a[MAXN]; void init() {
memset(head , - , sizeof(head));
tot = cnt = ;
} inline void add(int u , int v) {
edge[tot].next = head[u];
edge[tot].to = v;
head[u] = tot++;
} void dfs1(int u , int p , int d) {
dep[u] = d , size[u] = , son[u] = u , par[u] = p;
for(int i = head[u] ; ~i ; i = edge[i].next) {
int v = edge[i].to;
if(v == p)
continue;
dfs1(v , u , d + );
if(size[v] >= size[son[u]])
son[u] = v;
size[u] += size[v];
}
} void dfs2(int u , int p , int t) {
top[u] = t , id[u] = ++cnt;
fid[cnt] = u;
if(son[u] != u)
dfs2(son[u] , u , t);
for(int i = head[u] ; ~i ; i = edge[i].next) {
int v = edge[i].to;
if(v == p || v == son[u])
continue;
dfs2(v , u , v);
}
} void pushdown(int p) {
if(T[p].lazy != -) {
int ls = p << , rs = (p << )|;
T[ls].rc = T[ls].lc = T[rs].lc = T[rs].rc = T[p].lazy;
T[ls].lazy = T[rs].lazy = T[p].lazy;
T[ls].sum = T[rs].sum = ; //变成同一个颜色 sum就为1了
T[p].lazy = -;
}
} void pushup(int p) {
T[p].lc = T[p << ].lc , T[p].rc = T[(p << )|].rc; //这里注意要回溯上来,父节点的左右端颜色要更新
T[p].sum = T[p << ].sum + T[(p << )|].sum - (T[p << ].rc == T[(p << )|].lc); //合并操作:要是左孩子的最右端颜色等于右孩子最左端颜色,那就需要-1
} void build(int p , int l , int r) {
int mid = (l + r) >> ;
T[p].r = r , T[p].l = l , T[p].lc = a[fid[l]] , T[p].rc = a[fid[r]] , T[p].lazy = -;
if(l == r) {
T[p].sum = ;
return ;
}
build(p << , l , mid);
build((p << )| , mid + , r);
pushup(p);
} void update(int p , int l , int r , int color) {
int mid = (T[p].l + T[p].r) >> ;
if(T[p].l == l && T[p].r == r) {
T[p].sum = , T[p].lazy = T[p].rc = T[p].lc = color;
return ;
}
pushdown(p);
if(r <= mid) {
update(p << , l , r , color);
}
else if(l > mid) {
update((p << )| , l , r , color);
}
else {
update(p << , l , mid , color);
update((p << )| , mid + , r , color);
}
pushup(p);
} int query(int p , int l , int r) {
int mid = (T[p].l + T[p].r) >> ;
if(T[p].l == l && T[p].r == r) {
return T[p].sum;
}
pushdown(p);
if(r <= mid) {
return query(p << , l , r);
}
else if(l > mid) {
return query((p << )| , l , r);
}
else {
return query(p << , l , mid) + query((p << )| , mid + , r) - (T[p << ].rc == T[(p << )|].lc);
}
} int query_pos_color(int p , int pos) {
int mid = (T[p].l + T[p].r) >> ;
if(T[p].l == T[p].r && pos == T[p].r) {
return T[p].lc;
}
pushdown(p);
if(pos <= mid) {
query_pos_color(p << , pos);
}
else {
query_pos_color((p << )| , pos);
}
} void find_update(int u , int v , int val) {
int fu = top[u] , fv = top[v];
while(fu != fv) {
if(dep[fu] >= dep[fv]) {
update( , id[fu] , id[u] , val);
u = par[fu];
fu = top[u];
}
else {
update( , id[fv] , id[v] , val);
v = par[fv];
fv = top[v];
}
}
if(dep[u] > dep[v])
update( , id[v] , id[u] , val);
else
update( , id[u] , id[v] , val);
} int find_ans(int u , int v) {
int fu = top[u] , fv = top[v] , res = ;
while(fu != fv) {
if(dep[fu] >= dep[fv]) {
res += query( , id[fu] , id[u]);
if(query_pos_color( , id[fu]) == query_pos_color( , id[par[fu]])) //要是fu节点和其父节点颜色相同就-1
res--;
u = par[fu];
fu = top[u];
}
else {
res += query( , id[fv] , id[v]);
if(query_pos_color( , id[fv]) == query_pos_color( , id[par[fv]])) //上同
res--;
v = par[fv];
fv = top[v];
}
}
if(dep[u] > dep[v]) {
res += query( , id[v] , id[u]);
return res;
}
else {
res += query( , id[u] , id[v]);
return res;
}
} int main()
{
int n , m , u , v , val;
char q[];
while(~scanf("%d %d" , &n , &m)) {
init();
for(int i = ; i <= n ; ++i)
scanf("%d" , a + i);
for(int i = ; i < n ; ++i) {
scanf("%d %d" , &u , &v);
add(u , v);
add(v , u);
}
dfs1( , , );
dfs2( , , );
build( , , cnt);
while(m--) {
scanf("%s" , q);
if(q[] == 'Q') {
scanf("%d %d" , &u , &v);
printf("%d\n" , find_ans(u , v));
}
else {
scanf("%d %d %d" , &u , &v , &val);
find_update(u , v , val);
}
}
}
return ;
}
05-11 14:04