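// 树链剖分就行了，注意线段树上颜色的合并。
// (Heavy-light decomposition is all that is needed; the subtle part is merging colors in the
//  segment tree: when two adjacent ranges meet and the left one ends with the same color the
//  right one starts with, their boundary runs form a single segment and must be counted once.
//  The same applies where a heavy chain joins its parent across a light edge on a path query.)
// Code: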
#include <cstdio>
#include <algorithm>
// Capacity bound for every node-indexed array: node ids and decomposition
// positions run from 1 and must stay below N. A typed constant (rather than a
// macro) obeys scoping rules and cannot silently rewrite other identifiers.
const int N = 100010;
// For use inside segment-tree routines with `l`, `r` and `id` in scope: declares
// `mid`, the midpoint of [l, r], and `ls` / `rs`, the indices of node id's left
// and right children.
#define MID int mid=(l+r)>>1,ls=id<<1,rs=id<<1|1
// Number of positions in [l, r]. Not referenced by any code in this file; being
// a lowercase macro it rewrites every later identifier spelled `len`.
#define len (r-l+1)
using namespace std;

// Segment-tree node summarising a contiguous range of decomposition positions.
//   lc / rc : color at the leftmost / rightmost position (-1 marks an empty node)
//   sum     : number of maximal runs of equal adjacent colors inside the range
//   tag     : pending "whole range painted this color" lazy tag (-1 = none)
// NOTE(review): -1 doubles as the "empty" and "no tag" sentinel, so this assumes
// real colors are never -1 — confirm against the input constraints.
struct tree {
    int lc, rc, sum, tag;
    tree() { lc = rc = tag = -1; sum = 0; }

    // Summary of range a immediately followed by range b. An empty operand acts as
    // the identity. When a's last color equals b's first color, the two runs that
    // meet at the boundary are really one run, hence the "- 1". The result is a
    // freshly computed summary, so it never carries a pending lazy tag — not even
    // when one operand is passed through unchanged.
    friend tree operator+(const tree &a, const tree &b) {
        tree c;  // starts empty, tag == -1
        if (a.lc == -1) {
            c = b;
        } else if (b.lc == -1) {
            c = a;
        } else {
            c.lc = a.lc;
            c.rc = b.rc;
            c.sum = a.sum + b.sum - (a.rc == b.lc ? 1 : 0);
        }
        c.tag = -1;
        return c;
    }
} T[N * 4];
// One directed adjacency entry: `to` is the target node, `nex` is the index of the
// next entry in the same source node's list (0 terminates the list).
struct info {
    int to;
    int nex;
};
// Both directions of every undirected edge; entries are used from index 1.
info e[N * 2];

// n: number of nodes; m: number of operations; tot: last used index in e;
// cnt: last assigned decomposition position.
int n, m, tot, cnt;
// head[u]: index in e of the first entry of u's list (0 = none);
// A[u]: initial color of node u.
int head[N], A[N];
int tid[N],dep[N],son[N],fa[N],sz[N],tp[N],tw[N]; inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
} inline void Link(int u,int v){
e[++tot].nex=head[u];head[u]=tot;e[tot].to=v;
} void dfs(int u,int pre){
sz[u]=1;
for(int i=head[u],mx=0;i;i=e[i].nex){
int v=e[i].to;
if(v==pre) continue;
fa[v]=u;
dep[v]=dep[u]+1;
dfs(v,u);
sz[u]+=sz[v];
if(sz[v]>mx){son[u]=v;mx=sz[v];}
}
} void dddfs(int u,int top){
tp[u]=top;
tid[u]=++cnt;
tw[cnt]=A[u];
if(!son[u]) return; dddfs(son[u],top);
for(int i=head[u];i;i=e[i].nex){
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
dddfs(v,v);
}
} void build(int l,int r,int id){
if(l==r){T[id].sum=1;T[id].lc=T[id].rc=tw[l];return;}
MID;
build(l,mid,ls);
build(mid+1,r,rs);
T[id]=T[ls]+T[rs];
} void Init(){
n=read(),m=read();
for(int i=1;i<=n;A[i++]=read());
for(int i=1;i<n;++i){
int u=read(),v=read();
Link(u,v),Link(v,u);
}
dfs(1,0);
dddfs(1,1);
build(1,n,1);
} inline void pushdown(int l,int r,int id){
int &tmp=T[id].tag;
if(tmp==-1) return;
MID;
T[ls].lc=T[ls].rc=T[rs].lc=T[rs].rc=tmp;
T[ls].sum=T[rs].sum=1;
T[ls].tag=T[rs].tag=tmp;
tmp=-1;
} int query(int l,int r,int id,int L,int R){
if(L<=l&&r<=R) return T[id].sum;
pushdown(l,r,id);
MID;
int res=0;
if(R<=mid) res+=query(l,mid,ls,L,R);
else if(L>mid) res+=query(mid+1,r,rs,L,R);
else res+=query(l,mid,ls,L,R),res+=query(mid+1,r,rs,L,R),res-=(T[ls].rc==T[rs].lc)?1:0;
return res;
} int qDot(int l,int r,int id,int x){
if(l==r&&l==x) return T[id].lc;
pushdown(l,r,id);
MID;
if(x<=mid) return qDot(l,mid,ls,x);
else return qDot(mid+1,r,rs,x);
} inline int qRange(int u,int v){
int res=0;
while(tp[u]!=tp[v]){
if(dep[tp[u]]<dep[tp[v]]) swap(u,v);
res+=query(1,n,1,tid[tp[u]],tid[u]);
int x=qDot(1,n,1,tid[tp[u]]),y=qDot(1,n,1,tid[fa[tp[u]]]);
if(x==y) --res;
u=fa[tp[u]];
}
if(dep[u]>dep[v]) swap(u,v);
res+=query(1,n,1,tid[u],tid[v]);
return res;
} void update(int l,int r,int id,int L,int R,int x){
if(L<=l&&r<=R){
T[id].sum=1;
T[id].lc=T[id].rc=T[id].tag=x;
return;
}
pushdown(l,r,id);
MID;
if(L<=mid) update(l,mid,ls,L,R,x);
if(R>mid) update(mid+1,r,rs,L,R,x);
T[id]=T[ls]+T[rs];
} void updRange(int u,int v,int x){
while(tp[u]!=tp[v]){
if(dep[tp[u]]<dep[tp[v]]) swap(u,v);
update(1,n,1,tid[tp[u]],tid[u],x);
u=fa[tp[u]];
}
if(dep[u]>dep[v]) swap(u,v);
update(1,n,1,tid[u],tid[v],x);
} inline void solve(){
char ch;
while(m--){
for(ch=getchar();ch!='C'&&ch!='Q';ch=getchar());
if(ch=='Q'){
int u=read(),v=read();
printf("%d\n",qRange(u,v));
}else{
int u=read(),v=read(),x=read();
updRange(u,v,x);
}
}
} int main(){Init();solve();}