挺有趣的一个数据结构题,不必用 LCT 维护(只需要根到某点的链加与链求和,树链剖分 + 线段树同样可行),只不过 LCT 比较好写。
code:
// Tree with weighted vertices, processed online:
//   1 x y : set the weight of vertex x to y
//   2 x   : print  ans + (k-1)*tot^2 - 2*tot*(P - tot)
// where, with the tree rooted at vertex 1 and S(u) the weight sum of u's subtree,
//   ans = sum of S(u)^2 over all u,   tot = sum of all weights,
//   k   = number of vertices on the path 1..x,   P = sum of S(u) over that path.
//
// Why that formula: re-rooting at x only changes S on the path 1=a1,...,ak=x.
// S'(a_i) = tot - S(a_{i+1}) for i<k and S'(a_k) = tot, hence
//   ans' = ans - sum_{i=1..k} S(a_i)^2 + sum_{i=2..k} (tot - S(a_i))^2 + tot^2
//        = ans + (k-1)*tot^2 - 2*tot*(P - tot)            (using S(a_1) = tot),
// i.e. the printed value is the sum of squared subtree sums with x as root.
//
// A weight change by d adds d to S(u) for every u on the path 1..x, so only
// "chain add" and "chain sum" on root-to-vertex paths are needed; they are done
// with Access + splay on a link-cut-tree-style structure (no link/cut/makeroot).
#include <cstdio>
#include <string>
#include <cstring>
#include <algorithm>

#define N 200008
#define ll long long
#define lson s[x].ch[0]
#define rson s[x].ch[1]

using namespace std;

// Sum of S(u)^2 over all vertices, for the tree rooted at vertex 1.
ll ans;

// Redirects stdin to "<s>.in" (stdout redirection is left disabled).
void setIO(string s) {
    freopen((s+".in").c_str(),"r",stdin);
    // freopen((s+".out").c_str(),"w",stdout);
}

// Forward-star adjacency list: hd[u] is the first edge id of u, nex[] chains
// the edge ids, to[] holds the endpoint. sta[] is shared scratch space (BFS queue
// in dfs, ancestor stack in splay). val[] stores the current vertex weights.
int edges;
int sta[N],hd[N],to[N<<1],nex[N<<1],val[N];

// Appends the directed edge u -> v to u's adjacency list.
void Add(int u,int v) {
    nex[++edges]=hd[u];
    hd[u]=edges;
    to[edges]=v;
}

// One splay node per tree vertex.
//   ch   : splay children          f    : splay parent or path-parent
//   size : nodes in splay subtree  v    : S(u) of this vertex (root = 1)
//   sum  : sum of v over the splay subtree
//   tag  : pending increment for the children's subtrees (lazy)
//   rev  : not used anywhere in this file
struct LCT {
    int ch[2],rev,f,size;
    ll tag,sum,v;
}s[N];

// Returns 1 if x is the right child of its parent, 0 otherwise.
int get(int x) {
    return s[s[x].f].ch[1]==x;
}

// Returns non-zero if x is the root of its splay tree (its f is a path-parent or 0).
int Isr(int x) {
    return s[s[x].f].ch[0]!=x&&s[s[x].f].ch[1]!=x;
}

// Recomputes sum and size of x from its children; node 0 acts as the empty tree.
void pushup(int x) {
    s[x].sum=s[lson].sum+s[rson].sum+s[x].v;
    s[x].size=s[lson].size+s[rson].size+1;
}

// Adds v to every node in the splay subtree of x (lazily for the descendants).
void add(int x,ll v) {
    s[x].v+=v;
    s[x].sum+=v*1ll*s[x].size;
    s[x].tag+=v;
}

// Rotates x above its splay parent, keeping the path-parent link of the old root.
void rotate(int x) {
    int old=s[x].f,fold=s[old].f,which=get(x);
    if(!Isr(old)) s[fold].ch[s[fold].ch[1]==old]=x;
    s[old].ch[which]=s[x].ch[which^1];
    if(s[old].ch[which]) s[s[old].ch[which]].f=old;
    s[x].ch[which^1]=old;
    s[old].f=x;
    s[x].f=fold;
    pushup(old),pushup(x);
}

// Pushes the pending increment of x down to its existing children.
void pushdown(int x) {
    if(s[x].tag) {
        if(lson) add(lson,s[x].tag);
        if(rson) add(rson,s[x].tag);
        s[x].tag=0;
    }
}

// Splays x to the root of its splay tree.
void splay(int x) {
    int u=x,v=0,fa;
    // Collect x and its splay ancestors, then release lazy tags top-down.
    for(sta[++v]=u;!Isr(u);u=s[u].f) sta[++v]=s[u].f;
    for(;v;--v) pushdown(sta[v]);
    // u becomes the path-parent of the splay tree; rotate until x hangs under it.
    for(u=s[u].f;(fa=s[x].f)!=u;rotate(x))
        if(s[fa].f!=u) rotate(get(fa)==get(x)?fa:x);
}

// Makes the path from the tree root to x a single preferred path.
void Access(int x) {
    for(int y=0;x;y=x,x=s[x].f)
        splay(x),rson=y,pushup(x);
}

// Initialises the structure for the subtree of u whose parent is ff:
// sets s[].f to the tree parent, s[].v to the subtree weight sum, sum/size via
// pushup, and adds every squared subtree sum to ans.
// Implemented iteratively (BFS order, then a reverse sweep) so that a
// path-shaped tree with ~2e5 vertices cannot overflow the call stack.
void dfs(int u,int ff) {
    int head=1,tail=0;
    s[u].f=ff;
    sta[++tail]=u;
    for(;head<=tail;++head) {
        int cur=sta[head];
        s[cur].v=val[cur];
        for(int i=hd[cur];i;i=nex[i]) {
            int nxt=to[i];
            if(nxt==s[cur].f) continue;
            s[nxt].f=cur;
            sta[++tail]=nxt;
        }
    }
    // Children always come after their parent in BFS order, so walking the
    // queue backwards finishes every child before its parent is finalised.
    for(int k=tail;k>=1;--k) {
        int cur=sta[k];
        ans+=s[cur].v*s[cur].v;
        pushup(cur);
        if(k>1) s[s[cur].f].v+=s[cur].v;
    }
}

int main() {
    // setIO("input");
    ll tot=0ll;
    int n,Q,x,y;
    if(scanf("%d%d",&n,&Q)!=2) return 0;
    for(int i=1;i<n;++i) {
        if(scanf("%d%d",&x,&y)!=2) return 0;
        Add(x,y),Add(y,x);
    }
    for(int i=1;i<=n;++i) {
        if(scanf("%d",&val[i])!=1) return 0;
        tot+=val[i];
    }
    dfs(1,0);
    for(int i=1;i<=Q;++i) {
        int op;
        if(scanf("%d",&op)!=1) break;
        if(op==1) {
            if(scanf("%d%d",&x,&y)!=2) break;
            // Widened to ll: y - val[x] may not fit in int for extreme inputs.
            ll d=(ll)y-val[x];
            Access(x),splay(x);
            // sum((S+d)^2 - S^2) over the path = 2*d*sum(S) + size*d^2.
            ans+=2ll*s[x].sum*d+1ll*s[x].size*d*d;
            add(x,d);
            tot+=d;
            val[x]=y;
        }
        else {
            if(scanf("%d",&x)!=1) break;
            Access(x),splay(x);
            int siz=s[x].size-1;
            ll sum=s[x].sum-tot;
            printf("%lld\n",ans+1ll*siz*tot*tot-2ll*tot*sum);
        }
    }
    return 0;
}