树链剖分,考虑用线段树维护左端点颜色、右端点颜色和段数,合并时判断中间两个点能否重合即可,注意在树剖时合并的顺序

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define N 100005
  4 #define L (k<<1)
  5 #define R (L+1)
  6 #define mid (l+r>>1)
  7 struct ji{
  8     int nex,to;
  9 }edge[N<<1];
 10 struct node{
 11     int ls,rs,ans;
 12 }f[N<<2];
 13 int E,n,m,x,y,z,head[N],fa[N],sz[N],sh[N],ma[N],id[N],top[N],laz[N<<2];
 14 char s[11];
 15 void upd(int k,int x){
 16     laz[k]=x;
 17     f[k]=node{x,x,1};
 18 }
 19 node up(node x,node y){
 20     if (!y.ans)return x;
 21     if (!x.ans)return y;
 22     return node{x.ls,y.rs,x.ans+y.ans-(x.rs==y.ls)};
 23 }
 24 void down(int k){
 25     if (laz[k]>=0){
 26         upd(L,laz[k]);
 27         upd(R,laz[k]);
 28         laz[k]=-1;
 29     }
 30 }
 31 void update(int k,int l,int r,int x,int y,int z){
 32     if ((l>y)||(x>r))return;
 33     if ((x<=l)&&(r<=y)){
 34         upd(k,z);
 35         return;
 36     }
 37     down(k);
 38     update(L,l,mid,x,y,z);
 39     update(R,mid+1,r,x,y,z);
 40     f[k]=up(f[L],f[R]);
 41     if (laz[k]>=0)upd(k,laz[k]);
 42 }
 43 node query(int k,int l,int r,int x,int y){
 44     if ((l>y)||(x>r))return node{0,0,0};
 45     if ((x<=l)&&(r<=y))return f[k];
 46     down(k);
 47     return up(query(L,l,mid,x,y),query(R,mid+1,r,x,y));
 48 }
 49 void add(int x,int y){
 50     edge[E].nex=head[x];
 51     edge[E].to=y;
 52     head[x]=E++;
 53 }
 54 void dfs(int k,int f,int s){
 55     sz[k]=1;
 56     fa[k]=f;
 57     sh[k]=s;
 58     for(int i=head[k];i!=-1;i=edge[i].nex)
 59         if (edge[i].to!=f){
 60             dfs(edge[i].to,k,s+1);
 61             sz[k]+=sz[edge[i].to];
 62             if (sz[ma[k]]<sz[edge[i].to])ma[k]=edge[i].to;
 63         }
 64 }
 65 void dfs2(int k,int fa,int t){
 66     top[k]=t;
 67     update(1,1,n,x+1,x+1,id[k]);
 68     id[k]=++x;
 69     if (ma[k])dfs2(ma[k],k,t);
 70     for(int i=head[k];i!=-1;i=edge[i].nex)
 71         if ((edge[i].to!=ma[k])&&(edge[i].to!=fa))dfs2(edge[i].to,k,edge[i].to);
 72 }
 73 void update(int x,int y,int z){
 74     while (top[x]!=top[y]){
 75         if (sh[top[x]]<sh[top[y]])swap(x,y);
 76         update(1,1,n,id[top[x]],id[x],z);
 77         x=fa[top[x]];
 78     }
 79     if (id[x]>id[y])swap(x,y);
 80     update(1,1,n,id[x],id[y],z);
 81 }
 82 node query(int x,int y){
 83     node ans1={0,0,0},ans2={0,0,0};
 84     while (top[x]!=top[y]){
 85         if (sh[top[x]]<sh[top[y]]){
 86             swap(x,y);
 87             swap(ans1,ans2);
 88         }
 89         ans1=up(query(1,1,n,id[top[x]],id[x]),ans1);
 90         x=fa[top[x]];
 91     }
 92     if (id[x]>id[y])swap(x,y);
 93     else swap(ans1,ans2);
 94     ans1=up(query(1,1,n,id[x],id[y]),ans1);
 95     swap(ans1.ls,ans1.rs);
 96     return up(ans1,ans2);
 97 }
 98 int main(){
 99     scanf("%d%d",&n,&m);
100     for(int i=1;i<=n;i++)scanf("%d",&id[i]);
101     memset(head,-1,sizeof(head));
102     memset(laz,-1,sizeof(laz));
103     for(int i=1;i<n;i++){
104         scanf("%d%d",&x,&y);
105         add(x,y);
106         add(y,x);
107     }
108     x=0;
109     dfs(1,0,0);
110     dfs2(1,0,1);
111     for(int i=1;i<=m;i++){
112         scanf("%s%d%d",s,&x,&y);
113         if (s[0]=='Q')printf("%d\n",query(x,y).ans);
114         else{
115             scanf("%d",&z);
116             update(x,y,z);
117         }
118     }
119 }
View Code
02-01 00:30