我终于码了树链剖分,别人半年前学会的东西我终于入门辣!!!
先放全代码:
1 #include<cstdio> 2 #include<iostream> 3 #include<cstring> 4 #include<cmath> 5 #define MAXN 30010 6 using namespace std; 7 const int INF=(1<<30); 8 struct rr{ 9 int nt,to; 10 }bl[MAXN<<1];int hd[MAXN],itot; 11 void addedge(int x,int y){ 12 bl[++itot]=(rr){hd[x],y}; 13 hd[x]=itot; 14 return ; 15 } 16 int n; 17 int w[MAXN]; 18 int son[MAXN],L[MAXN],top[MAXN],sz[MAXN]; 19 int fat[MAXN]; 20 int sd[MAXN]; 21 int num=0; 22 int line[MAXN];//dfs序上的某一个位置具体是哪个节点 23 void dfs1(int u,int fa){ 24 sz[u]=1;fat[u]=fa; 25 for(int i=hd[u],y;i;i=bl[i].nt) 26 if(bl[i].to!=fa){ 27 y=bl[i].to; 28 dfs1(y,u); 29 sz[u]+=sz[y]; 30 if(sz[son[u]]<sz[y])son[u]=y; 31 } 32 return ; 33 } 34 void dfs2(int u,int fa,int dep){ 35 sd[u]=dep; 36 L[u]=++num;line[num]=u; 37 top[u]=(son[fa]==u)?top[fa]:u; 38 if(son[u])dfs2(son[u],u,dep+1); 39 for(int i=hd[u],y;i;i=bl[i].nt) 40 if(bl[i].to!=fa&&bl[i].to!=son[u])dfs2(bl[i].to,u,dep+1); 41 return ; 42 } 43 struct SEGTREE{ 44 int mx[MAXN*4],sm[MAXN*4]; 45 #define ls (u<<1) 46 #define rs (u<<1|1) 47 void up(int u){ 48 mx[u]=max(mx[ls],mx[rs]); 49 sm[u]=sm[ls]+sm[rs]; 50 return ; 51 } 52 void build(int u,int l,int r){ 53 if(l==r){sm[u]=mx[u]=w[line[l]];return ; } 54 int mid=l+r>>1; 55 build(ls,l,mid),build(rs,mid+1,r); 56 up(u); 57 return ; 58 } 59 void upd(int u,int l,int r,int pos,int val){ 60 if(l==r){mx[u]=sm[u]=val;return ; } 61 int mid=l+r>>1; 62 if(pos<=mid)upd(ls,l,mid,pos,val); 63 else upd(rs,mid+1,r,pos,val); 64 up(u); 65 return ; 66 } 67 int qwsm(int u,int l,int r,int x,int y){ 68 if(l>=x&&r<=y)return sm[u]; 69 int mid=l+r>>1; 70 int res=0; 71 if(mid>=x)res+=qwsm(ls,l,mid,x,y); 72 if(mid+1<=y)res+=qwsm(rs,mid+1,r,x,y); 73 return res; 74 } 75 int qwmx(int u,int l,int r,int x,int y){ 76 if(l>=x&&r<=y)return mx[u]; 77 int mid=l+r>>1; 78 int res=-INF; 79 if(mid>=x)res=max(res,qwmx(ls,l,mid,x,y)); 80 if(mid+1<=y)res=max(res,qwmx(rs,mid+1,r,x,y)); 81 return res; 82 } 83 }T; 84 int q; 85 int gtmx(int x,int y){ 86 int res=-INF; 87 while(top[x]!=top[y]){ 88 if(sd[top[x]]<sd[top[y]])swap(x,y); 89 res=max(res,T.qwmx(1,1,num,L[top[x]],L[x])); 90 x=fat[top[x]]; 91 } 92 if(sd[x]<sd[y])swap(x,y); 93 res=max(res,T.qwmx(1,1,num,L[y],L[x])); 94 return res; 95 } 96 int gtsm(int x,int y){ 97 int res=0; 98 while(top[x]!=top[y]){ 99 if(sd[top[x]]<sd[top[y]])swap(x,y); 100 res+=T.qwsm(1,1,num,L[top[x]],L[x]); 101 x=fat[top[x]]; 102 } 103 if(sd[x]<sd[y])swap(x,y); 104 res+=T.qwsm(1,1,num,L[y],L[x]); 105 return res; 106 } 107 int main(){ 108 //freopen("count1.in","r",stdin); 109 //freopen("hh.out","w",stdout); 110 scanf("%d",&n); 111 for(int i=1,a,b;i<n;++i){ 112 scanf("%d%d",&a,&b); 113 addedge(a,b);addedge(b,a); 114 } 115 for(int i=1;i<=n;++i)scanf("%d",&w[i]); 116 dfs1(1,0);dfs2(1,0,1); 117 T.build(1,1,num); 118 scanf("%d",&q); 119 char opt[17]; 120 int u,v; 121 while(q--){ 122 scanf("%s%d%d",opt,&u,&v); 123 if(opt[0]=='C')T.upd(1,1,num,L[u],v); 124 else if(opt[1]=='M')printf("%d\n",gtmx(u,v)); 125 else printf("%d\n",gtsm(u,v)); 126 } 127 return 0; 128 }
着重解释几个代码片:
需要维护的变量:
1 int son[MAXN];//重儿子 2 int L[MAXN];//dfs序中的位置 3 int top[MAXN];//所在重链的链顶 4 int sz[MAXN];//子树的大小 5 int fat[MAXN]; 6 int sd[MAXN];//深度 7 int num=0; 8 int line[MAXN];//dfs序上的某一个位置具体是哪个节点
$dfs1$:
1 void dfs1(int u,int fa){ 2 sz[u]=1;fat[u]=fa; 3 for(int i=hd[u],y;i;i=bl[i].nt) 4 if(bl[i].to!=fa){ 5 y=bl[i].to; 6 dfs1(y,u); 7 sz[u]+=sz[y]; 8 if(sz[son[u]]<sz[y])son[u]=y;//更新重儿子 9 } 10 return ; 11 }
$dfs2$:
1 void dfs2(int u,int fa,int dep){ 2 sd[u]=dep; 3 L[u]=++num;line[num]=u; 4 top[u]=(son[fa]==u)?top[fa]:u; 5 //判断是否为链断 6 if(son[u])dfs2(son[u],u,dep+1); 7 //先dfs重儿子,确保一条重链的点在dfs序上是连续的 8 for(int i=hd[u],y;i;i=bl[i].nt) 9 if(bl[i].to!=fa&&bl[i].to!=son[u])dfs2(bl[i].to,u,dep+1); 10 return ; 11 }
建线段树就不说什么了
以询问两点路径最大值为例
$qwmx$:
1 int gtmx(int x,int y){ 2 int res=-INF; 3 while(top[x]!=top[y]){ 4 if(sd[top[x]]<sd[top[y]])swap(x,y); 5 res=max(res,T.qwmx(1,1,num,L[top[x]],L[x])); 6 x=fat[top[x]]; 7 } 8 if(sd[x]<sd[y])swap(x,y); 9 res=max(res,T.qwmx(1,1,num,L[y],L[x])); 10 return res; 11 }
例题:
先咕着,并没有做过几道题。。。