有一个地方简单贪心证明一下就好了
不大难的树形dp
code:
#include <bits/stdc++.h> #define N 500004 #define LL long long #define setIO(s) freopen(s".in","r",stdin) using namespace std; char *p1,*p2,buf[100000]; #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++) int rd() {int x=0; char c=nc(); while(c<48) c=nc(); while(c>47) x=(((x<<2)+x)<<1)+(c^48),c=nc(); return x;} struct data { int f,size,id; data(int f=0,int size=0,int id=0):f(f),size(size),id(id){} }; bool cmp(data a,data b) { return a.f-2*a.size==b.f-2*b.size?a.f>b.f:a.f-2*a.size>b.f-2*b.size; } int n,edges; vector<data>G[N]; int hd[N],to[N<<1],nex[N<<1],val[N],f[N],size[N]; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } void dfs(int u,int ff) { size[u]=1; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff) continue; dfs(v,u); G[u].push_back(data(f[v]+1,size[v],v)); size[u]+=size[v]; } sort(G[u].begin(),G[u].end(),cmp); int cur=0; if(u!=1) f[u]=val[u]; for(int i=0;i<G[u].size();++i) { f[u]=max(f[u],cur+G[u][i].f); cur+=2*G[u][i].size; } } int main() { // setIO("input"); int i,j; n=rd(); for(i=1;i<=n;++i) val[i]=rd(); for(i=1;i<n;++i) { int u,v; u=rd(),v=rd(); add(u,v), add(v,u); } dfs(1,0); f[1]=max(f[1], size[1]*2-2+val[1]); printf("%d\n",f[1]); return 0; }