模拟赛的T2,多敲了两行成功爆掉~

写线段树合并的时候一定要注意一下不能随意新开节点. 

code: 

#include <bits/stdc++.h>
#define N 100009
#define ll long long
#define setIO(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
using namespace std;
// n: number of vertices (read in main; also the upper bound of the segment-tree key range).
// edges: number of adjacency-list entries created so far by add(); entry ids start at 1.
int n,edges;
// A[i]  : per-vertex key inserted into the vertex's segment tree in dfs().
// hd[u] : id of the most recently added adjacency entry of u (0 = none).
// to[e] : endpoint of adjacency entry e;  nex[e]: next entry in the same list (0 = end).
// kk[i] : per-vertex rank passed to Segment_Tree::dfss() in dfs().
// rt[i] : root node of vertex i's segment tree (0 = empty); rt[0] receives the root vertex's tree.
// ans1[i]: value returned by Segment_Tree::solve() for vertex i.
int A[N],hd[N],to[N<<1],nex[N<<1],kk[N],rt[N],ans1[N];
// val[e]: weight stored with adjacency entry e.
ll val[N<<1];
// ans2[i]: value returned by Segment_Tree::dfss() for vertex i, or -1 when the tree
// holds fewer than kk[i] occupied leaves.
ll ans2[N];
// Prepends a directed adjacency entry u -> v with weight c to u's list.
// Entry ids are handed out sequentially starting from 1, so id 0 can serve
// as the list terminator.
void add(int u,int v,int c)
{
    ++edges;
    to[edges]=v;
    val[edges]=c;
    nex[edges]=hd[u];   // link the previous head behind the new entry
    hd[u]=edges;
}
/*
 * Dynamically allocated segment tree over an integer key range, with a lazy
 * additive tag and a destructive merge operation.
 *
 * Node 0 is the shared "null" node: all of its fields stay zero, so reading
 * p[0].size / p[0].maxx for an absent child is safe.
 *
 * Per-leaf bookkeeping (what the code maintains):
 *   num  - how many points have been inserted/merged at this key
 *   rt   - grows by d*num on every mark(x,d) that reaches the leaf
 *   dis  - combined on merge as dis_u + dis_v + num_u*rt_v + rt_u*num_v
 *          (presumably the sum of pairwise path lengths between vertices
 *          sharing this key -- inferred from the formula, confirm against
 *          the problem statement)
 * Per-node aggregates:
 *   size - number of occupied leaves in the subtree
 *   maxx - maximum leaf dis in the subtree
 *   mul  - pending lazy amount not yet applied to the children
 */
struct Segment_Tree
{
    #define lson p[x].ls
    #define rson p[x].rs
    int tot;    // id of the last allocated node; ids start at 1
    struct Node
    {
        int ls,rs,size;
        ll dis,num,rt,mul,maxx;
    }p[N*90];   // pool bound kept from the original; merge() no longer allocates,
                // so only update() consumes nodes (at most one per tree level per call)

    // Hands out the next unused node; its fields are still zero-initialised.
    int newnode() { return ++tot; }

    // Applies a lazy shift of d to node x: records it for the children and
    // updates x's own rt. Internal nodes never accumulate num through update(),
    // so the rt change is only meaningful at leaves.
    void mark(int x,ll d)
    {
        p[x].mul+=d;
        p[x].rt+=d*p[x].num;
    }

    // Pushes x's pending tag to whichever children exist.
    // l and r are unused; they are kept so existing call sites stay valid.
    void pushdown(int x,int l,int r)
    {
        if(p[x].mul)
        {
            if(lson) mark(lson, p[x].mul);
            if(rson) mark(rson, p[x].mul);
            p[x].mul=0;
        }
    }

    // Merges the trees rooted at u and v (both covering [l,r]) and returns the
    // root of the result. The merge is destructive: v is folded into u in place
    // and no new node is allocated, so neither input tree may be used afterwards
    // except through the returned root.
    int merge(int l,int r,int u,int v)
    {
        if(!u||!v) return u+v;      // at most one side exists: adopt it as is
        pushdown(u,l,r);
        pushdown(v,l,r);
        // dis must be combined first: it needs u's pre-merge num and rt.
        p[u].dis+=p[v].dis+p[u].num*p[v].rt+p[u].rt*p[v].num;
        p[u].num+=p[v].num;
        p[u].rt+=p[v].rt;
        if(l==r)
        {
            p[u].size=p[u].num?1:0;
            p[u].maxx=p[u].dis;
            return u;
        }
        int mid=(l+r)>>1;
        p[u].ls=merge(l,mid,p[u].ls,p[v].ls);
        p[u].rs=merge(mid+1,r,p[u].rs,p[v].rs);
        p[u].size=p[p[u].ls].size+p[p[u].rs].size;
        p[u].maxx=max(p[p[u].ls].maxx, p[p[u].rs].maxx);
        return u;
    }

    // Returns the smallest key in [l,r] whose leaf dis equals the subtree
    // maximum p[x].maxx. Expects x to contain at least one occupied leaf.
    int solve(int l,int r,int x)
    {
        if(l==r) return l;
        int mid=(l+r)>>1;
        pushdown(x,l,r);
        // Prefer the left half whenever it is non-empty and attains the maximum.
        if(p[lson].size && p[lson].maxx==p[x].maxx) return solve(l,mid,lson);
        else return solve(mid+1,r,rson);
    }

    // Inserts a single point at key pp, creating missing nodes on the way down.
    // The leaf is set to num=1,size=1 (not incremented), so this is intended for
    // keys that are not yet present in the tree rooted at x.
    void update(int &x,int l,int r,int pp)
    {
        if(!x) x=newnode();
        if(l==r)
        {
            p[x].size=1;
            p[x].num=1;
            return;
        }
        pushdown(x, l, r);
        int mid=(l+r)>>1;
        if(pp<=mid) update(lson,l,mid,pp);
        else update(rson,mid+1,r,pp);
        p[x].maxx=max(p[lson].maxx, p[rson].maxx);
        p[x].size=p[lson].size+p[rson].size;
    }

    // Returns dis of the kth occupied leaf (1-based, increasing key order).
    // The caller is expected to ensure 1 <= kth <= p[x].size.
    ll dfss(int l,int r,int x,int kth)
    {
        if(l==r) return p[x].dis;
        int mid=(l+r)>>1;
        pushdown(x,l,r);
        int sz=p[lson].size;
        if(sz>=kth) return dfss(l,mid,lson,kth);
        else return dfss(mid+1,r,rson,kth-sz);
    }
    #undef lson
    #undef rson
}seg;
void dfs(int u,int ff,int pp)
{
    seg.update(rt[u],1,n,A[u]);
    for(int i=hd[u];i;i=nex[i])
    {
        int v=to[i];
        if(v==ff) continue;
        dfs(v,u,val[i]);
    }
    if(seg.p[rt[u]].size<kk[u]) ans2[u]=-1;
    else
    {
        ans2[u]=seg.dfss(1,n,rt[u],kk[u]);
    }
    ans1[u]=seg.solve(1,n,rt[u]);
    seg.mark(rt[u], 1ll*pp);
    rt[ff]=seg.merge(1,n,rt[u], rt[ff]);
}
// Reads the tree (n, then n-1 weighted edges), the per-vertex keys A and
// ranks kk, runs the traversal from vertex 1 and prints one
// "ans1 ans2" line per vertex.
int main()
{
    // setIO("input");   // enable to redirect stdin/stdout to input.in / input.out
    scanf("%d",&n);
    for(int e=1;e<n;++e)
    {
        int u,v,c;
        scanf("%d%d%d",&u,&v,&c);
        // Store the undirected edge as two directed adjacency entries.
        add(u,v,c);
        add(v,u,c);
    }
    for(int v=1;v<=n;++v) scanf("%d",&A[v]);
    for(int v=1;v<=n;++v) scanf("%d",&kk[v]);

    dfs(1,0,0);

    for(int v=1;v<=n;++v) printf("%d %lld\n",ans1[v],ans2[v]);
    return 0;
}

  

01-19 23:18