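挺有趣的一个数据结构题，不必用 LCT 维护，只不过 LCT 比较好写～
思路：以 1 为根，维护每个点的子树和 $s_i$ 以及 $ans=\sum_i s_i^2$。把点 $x$ 的权值改变 $d$，只会让 $1\to x$ 路径上的每个 $s_i$ 都加 $d$。换根到 $x$ 时，只有路径 $1=p_0,p_1,\dots,p_k=x$ 上的子树和会改变，答案为 $ans+kW^2-2W\sum_{i=1}^{k}s_{p_i}$（$W$ 为总权值）。两种操作都只需要"路径加、路径求和"，用 LCT 的 Access 即可完成。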

code:

#include <cstdio>
#include <string>
#include <cstring>
#include <algorithm>
#define N 200008
#define ll long long
#define lson s[x].ch[0]
#define rson s[x].ch[1]
using namespace std;
// Running value of sum over all nodes u of (subtree weight sum of u)^2,
// with the tree rooted at node 1. Seeded by dfs() and kept current by the
// weight-update queries in main().
ll ans;
// Local-debugging helper: redirect stdin to the file "<s>.in".
// Output redirection to "<s>.out" is intentionally left disabled.
void setIO(std::string s)
{
    const std::string inputPath = s + ".in";
    freopen(inputPath.c_str(), "r", stdin);
    // const std::string outputPath = s + ".out";
    // freopen(outputPath.c_str(), "w", stdout);
}
// Forward-star adjacency list:
//   hd[u] is the most recently added edge leaving u,
//   nex[e] is the next edge in the same list,
//   to[e] is the edge's target.
// Edge ids start at 1, so 0 terminates a list.
// sta is scratch space for splay(); val holds the current node weights.
int edges;
int sta[N],hd[N],to[N<<1],nex[N<<1],val[N];
void Add(int u,int v)
{
    const int e=++edges;
    to[e]=v;
    nex[e]=hd[u];
    hd[u]=e;
}
// Splay-tree node of the link-cut tree. s[i] is node i; s[0] is a null
// sentinel that must stay all-zero, so absent children contribute nothing.
// The represented tree is never re-rooted (its root stays node 1), so no
// subtree-reversal tag is needed; the former unused `rev` field was dropped.
struct LCT
{
    int ch[2];  // splay children: ch[0] = shallower part of the path, ch[1] = deeper part
    int f;      // splay parent, or the path-parent link when this node is a splay root
    int size;   // number of nodes in this splay subtree
    ll tag;     // lazy amount still to be added to v of every node below this one
    ll sum;     // sum of v over this splay subtree
    ll v;       // subtree weight sum of this node in the tree rooted at node 1
}s[N];
// Returns 1 if x is the right (deeper-side) child of its splay parent, 0 otherwise.
int get(int x)
{
    const int p=s[x].f;
    return x==s[p].ch[1];
}
// Returns nonzero if x is the root of its splay tree.
// That holds when x is neither child of the node its f points to, so f is a
// path-parent link (or 0) rather than a real splay edge.
int Isr(int x)
{
    const int p=s[x].f;
    return !(s[p].ch[0]==x || s[p].ch[1]==x);
}
// Recompute x's aggregates (sum of v and node count) from its own v and
// its two splay children. A missing child is index 0, whose zeroed
// sentinel contributes nothing.
void pushup(int x)
{
    const int l=s[x].ch[0];
    const int r=s[x].ch[1];
    s[x].size=1+s[l].size+s[r].size;
    s[x].sum=s[x].v+s[l].sum+s[r].sum;
}
// Apply "+v to the v field of every node in x's splay subtree".
// x's own v and aggregate sum are updated immediately; the children are
// handled later via x's lazy tag (see pushdown).
void add(int x,ll v)
{
    s[x].tag+=v;
    s[x].v+=v;
    s[x].sum+=s[x].size*v;
}
// Rotate x above its splay parent while preserving in-order (path depth) order.
// The only caller, splay(), invokes this only when x has a real splay parent.
void rotate(int x)
{
    int old=s[x].f,fold=s[old].f,which=get(x);
    // Re-hang x under the grandparent only if old is a real splay child of it.
    // Otherwise fold is a path-parent link and must not gain a child edge.
    if(!Isr(old)) s[fold].ch[s[fold].ch[1]==old]=x;
    // x's subtree on the side facing old is handed over to old.
    s[old].ch[which]=s[x].ch[which^1];
    if(s[old].ch[which]) s[s[old].ch[which]].f=old;
    // old becomes x's child; x inherits old's parent (or path-parent) link.
    s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold;
    // old now sits below x, so its aggregates must be refreshed first.
    pushup(old),pushup(x);
}
// Push x's pending lazy add into each existing splay child, then clear it.
void pushdown(int x)
{
    const long long t=s[x].tag;
    if(!t) return;
    for(int k=0;k<2;++k)
        if(s[x].ch[k]) add(s[x].ch[k],t);
    s[x].tag=0;
}
// Splay x to the root of the splay tree that contains it.
void splay(int x)
{
    int u=x,v=0,fa;
    // Record x and each of its splay ancestors in sta[], stopping at the splay root;
    for(sta[++v]=u;!Isr(u);u=s[u].f) sta[++v]=s[u].f;
    // then push lazy tags down from the splay root toward x so every visited node is current.
    for(;v;--v) pushdown(sta[v]);
    // u becomes the splay root's path-parent (0 if none). Rotate until x's parent is that
    // link, i.e. x is the splay root. When x and its parent lean the same way (zig-zig),
    // rotate the parent first; otherwise (zig-zag) rotate x twice.
    for(u=s[u].f;(fa=s[x].f)!=u;rotate(x))
        if(s[fa].f!=u)
            rotate(get(fa)==get(x)?fa:x);
}
// Make the path from the tree root (node 1 here) down to x one preferred path.
// Afterwards a single splay tree holds exactly the nodes of that path.
// x is not necessarily its splay root, so callers splay(x) afterwards.
void Access(int x)
{
    int below=0;   // splay tree of the already-accessed deeper part of the path
    while(x)
    {
        splay(x);
        s[x].ch[1]=below;   // attach it as the deeper side, dropping x's old deeper side
        pushup(x);
        below=x;
        x=s[x].f;           // jump to the path-parent and continue upward
    }
}
void dfs(int u,int ff)
{
    s[u].f=ff;
    s[u].v=val[u];
    for(int i=hd[u];i;i=nex[i])
        if(to[i]!=ff) dfs(to[i],u),s[u].v+=s[to[i]].v;
    ans+=s[u].v*s[u].v;
    pushup(u);
}
int main()
{
    // setIO("input");
    ll tot=0ll;
    int i,j,n,Q,x,y;
    scanf("%d%d",&n,&Q);
    for(i=1;i<n;++i)  scanf("%d%d",&x,&y),Add(x,y),Add(y,x);
    for(i=1;i<=n;++i) scanf("%d",&val[i]),tot+=val[i];
    dfs(1,0);
    for(i=1;i<=Q;++i)
    {
        int op;
        scanf("%d",&op);
        if(op==1)
        {
            scanf("%d%d",&x,&y);
            int d=y-val[x];
            Access(x),splay(x);
            ans+=2ll*s[x].sum*d+1ll*s[x].size*d*d;
            add(x,d);
            tot+=d;
            val[x]=y;

        }
        else
        {
            scanf("%d",&x);
            Access(x),splay(x);
            int siz=s[x].size-1;
            ll sum=s[x].sum-tot;
            printf("%lld\n",ans+1ll*siz*tot*tot-2ll*tot*sum);
        }
    }
    return 0;
}

  

12-24 22:40