模拟赛的T2,多敲了两行成功爆掉~
写线段树合并的时候一定要注意一下不能随意新开节点.
code:
// For every vertex u of a weighted tree, considering only the vertices in the
// subtree of u and grouping them by their key A[]:
//   ans1[u] = the smallest key whose group has the largest sum of pairwise
//             path lengths inside the subtree;
//   ans2[u] = that sum for the kk[u]-th smallest key present in the subtree,
//             or -1 when fewer than kk[u] distinct keys are present.
// Each vertex owns a dynamically allocated segment tree indexed by key; a
// child's tree is merged into its parent's tree once the child is finished.
#include <algorithm>
#include <cstdio>

#define N 100009
#define ll long long
#define setIO(s) freopen(s ".in", "r", stdin), freopen(s ".out", "w", stdout)

int n, edges;
// hd/nex/to/val: adjacency lists (edge ids start at 1, 0 terminates a list).
// A: key of each vertex, kk: query rank per vertex, rt: segment-tree root per vertex.
int A[N], hd[N], to[N << 1], nex[N << 1], kk[N], rt[N], ans1[N];
ll val[N << 1];
ll ans2[N];

// Prepends the directed edge u -> v with weight c to u's adjacency list.
void add(int u, int v, int c) {
    nex[++edges] = hd[u];
    hd[u] = edges;
    to[edges] = v;
    val[edges] = c;
}

// Dynamically allocated segment tree over key positions [1, n].
// Node 0 is the shared "empty" node: it is never written, so size == 0 and maxx == 0.
struct Segment_Tree {
    int tot;
    struct Node {
        int ls, rs;  // child indices, 0 = absent
        int size;    // number of non-empty leaves (distinct keys) below this node
        ll dis;      // leaf: sum of pairwise path lengths among the vertices stored here
        ll num;      // leaf: number of vertices stored here
        ll rt;       // leaf: sum of path lengths from those vertices to the current subtree root
        ll mul;      // lazy tag: path length still to be added to every vertex below
        ll maxx;     // maximum dis over the leaves below this node
    } p[N * 20];
    // Pool bound: update() is the only allocator and creates at most one node per
    // level of a root-to-leaf path, i.e. at most 18 nodes per vertex for n < N <= 2^17.

    int newnode() { return ++tot; }

    // Moves the reference vertex of subtree x a distance d further away.
    void mark(int x, ll d) {
        p[x].mul += d;
        p[x].rt += d * p[x].num;
    }

    // Hands a pending lazy tag down to the existing children.
    void pushdown(int x) {
        if (!p[x].mul) return;
        if (p[x].ls) mark(p[x].ls, p[x].mul);
        if (p[x].rs) mark(p[x].rs, p[x].mul);
        p[x].mul = 0;
    }

    // Recomputes the aggregates of an internal node from its children.
    void pull(int x) {
        p[x].maxx = std::max(p[p[x].ls].maxx, p[p[x].rs].maxx);
        p[x].size = p[p[x].ls].size + p[p[x].rs].size;
    }

    // Merges tree v into tree u over [l, r] and returns the root of the result.
    // Node u is reused as the result, so no node is allocated here; neither
    // input tree may be used on its own afterwards.
    int merge(int l, int r, int u, int v) {
        if (!u || !v) return u + v;
        pushdown(u);
        pushdown(v);
        if (l == r) {
            Node &a = p[u];
            const Node &b = p[v];
            // Every new pair (x in a, y in b) is joined through the common
            // reference vertex, so the pairs contribute num_a*rt_b + rt_a*num_b.
            a.dis += b.dis + a.num * b.rt + a.rt * b.num;
            a.num += b.num;
            a.rt += b.rt;
            a.size = a.num ? 1 : 0;
            a.maxx = a.dis;
            return u;
        }
        const int mid = (l + r) >> 1;
        p[u].ls = merge(l, mid, p[u].ls, p[v].ls);
        p[u].rs = merge(mid + 1, r, p[u].rs, p[v].rs);
        pull(u);
        return u;
    }

    // Returns the smallest position in [l, r] whose leaf attains p[x].maxx.
    // Expects subtree x to contain at least one non-empty leaf.
    int solve(int l, int r, int x) {
        if (l == r) return l;
        pushdown(x);
        const int mid = (l + r) >> 1;
        const int left = p[x].ls;
        // Prefer the left half on ties so that the smallest key wins.
        if (p[left].size && p[left].maxx == p[x].maxx) return solve(l, mid, left);
        return solve(mid + 1, r, p[x].rs);
    }

    // Inserts one vertex with key pp into tree x (created on demand) over [l, r].
    void update(int &x, int l, int r, int pp) {
        if (!x) x = newnode();
        if (l == r) {
            p[x].size = 1;
            p[x].num = 1;
            return;
        }
        pushdown(x);
        const int mid = (l + r) >> 1;
        if (pp <= mid) update(p[x].ls, l, mid, pp);
        else update(p[x].rs, mid + 1, r, pp);
        pull(x);
    }

    // Returns dis of the kth non-empty leaf (in increasing position) of tree x.
    // The caller must ensure 1 <= kth <= p[x].size.
    ll dfss(int l, int r, int x, int kth) {
        if (l == r) return p[x].dis;
        pushdown(x);
        const int mid = (l + r) >> 1;
        const int sz = p[p[x].ls].size;
        if (sz >= kth) return dfss(l, mid, p[x].ls, kth);
        return dfss(mid + 1, r, p[x].rs, kth - sz);
    }
} seg;

// Processes the subtree of u (parent ff, weight pp of the edge ff-u): builds
// u's tree, answers both queries for u, then folds the tree into rt[ff].
// Recursion depth equals the height of the tree.
void dfs(int u, int ff, ll pp) {
    seg.update(rt[u], 1, n, A[u]);
    for (int i = hd[u]; i; i = nex[i]) {
        const int v = to[i];
        if (v == ff) continue;
        dfs(v, u, val[i]);
    }
    if (seg.p[rt[u]].size < kk[u]) ans2[u] = -1;
    else ans2[u] = seg.dfss(1, n, rt[u], kk[u]);
    ans1[u] = seg.solve(1, n, rt[u]);
    // Re-express all stored path lengths relative to the parent before merging.
    seg.mark(rt[u], pp);
    // For the root (ff == 0) this only stores the finished tree in the unused rt[0].
    rt[ff] = seg.merge(1, n, rt[u], rt[ff]);
}

// Input: n, then n-1 edges "u v c", then A[1..n], then kk[1..n].
// Output: one line "ans1 ans2" per vertex.
// Assumes vertex ids and keys lie in [1, n]; these are not validated.
int main() {
    // setIO("input");
    // n < 1 would make update() recurse forever; n >= N would overrun the arrays.
    if (scanf("%d", &n) != 1 || n < 1 || n >= N) return 1;
    for (int i = 1; i < n; ++i) {
        int u, v, c;
        if (scanf("%d%d%d", &u, &v, &c) != 3) return 1;
        add(u, v, c);
        add(v, u, c);
    }
    for (int i = 1; i <= n; ++i)
        if (scanf("%d", &A[i]) != 1) return 1;
    for (int i = 1; i <= n; ++i)
        if (scanf("%d", &kk[i]) != 1) return 1;
    dfs(1, 0, 0);
    for (int i = 1; i <= n; ++i) printf("%d %lld\n", ans1[i], ans2[i]);
    return 0;
}