code:

#include <bits/stdc++.h>

#define N 200009

#define ll long long

#define setIO(s) freopen(s".in","r",stdin)

using namespace std;

ll Sum[N];

int n,edges,root,sn;

int val[N],hd[N],to[N<<1],nex[N<<1],size[N],mx[N],vis[N],A[N];

inline void add(int u,int v)
{
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
} void getroot(int u,int ff)
{
size[u]=1,mx[u]=0; for(int i=hd[u];i;i=nex[i])
{
int v=to[i]; if(v==ff||vis[v]) continue; getroot(v,u); size[u]+=size[v]; mx[u]=max(mx[u],size[v]);
} mx[u]=max(mx[u],sn-size[u]); if(mx[u]<mx[root]) root=u; } int ou; ll tmp,tot,bu[N]; map<int,ll>cn[N]; map<int,ll>::iterator it; int dep[N],cnt[N],siz[N]; void getnode(int top,int u,int ff,int cur)
{ if(!cnt[val[u]]) ++cur; ++cnt[val[u]]; Sum[top]+=(ll)cur; siz[u]=1; for(int i=hd[u];i;i=nex[i])
{
int v=to[i]; if(v==ff||vis[v]) continue; getnode(top,v,u,cur); siz[u]+=siz[v];
} --cnt[val[u]];
}
void get_col(int top,int u,int ff)
{
if(!cnt[val[u]]) cn[top][val[u]]+=(ll)siz[u]; ++cnt[val[u]]; for(int i=hd[u];i;i=nex[i])
{ int v=to[i]; if(v==ff||vis[v]) continue; get_col(top,v,u); }
--cnt[val[u]];
}
void calc_v(int u,int ff)
{
ll tt=bu[val[u]]; tmp=tmp-bu[val[u]]+ou; bu[val[u]]=ou; Sum[u]+=tmp; for(int i=hd[u];i;i=nex[i])
{
int v=to[i]; if(vis[v]||v==ff) continue; calc_v(v,u);
} tmp=tmp-bu[val[u]]+tt; bu[val[u]]=tt;
}
void clr(int u,int ff)
{
cn[u].clear();
bu[val[u]]=0;
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==ff||vis[v]) continue;
clr(v,u);
}
}
void calc(int u)
{
tot=0; getnode(u,u,0,0); for(int i=hd[u];i;i=nex[i])
{
int v=to[i]; if(vis[v]) continue; // memset(cnt,0,sizeof(cnt)); get_col(v,v,u); for(it=cn[v].begin();it!=cn[v].end();it++)
{
tot+=it->second; bu[it->first]+=it->second;
}
} for(int i=hd[u];i;i=nex[i])
{
int v=to[i]; if(vis[v]) continue; tmp=tot; ou=siz[u]-siz[v]; for(it=cn[v].begin();it!=cn[v].end();it++)
{
bu[it->first]-=it->second; tmp-=it->second;
} ll tt=bu[val[u]]; tmp=tmp-bu[val[u]]+ou; bu[val[u]]=ou; calc_v(v,u); bu[val[u]]=tt; for(it=cn[v].begin();it!=cn[v].end();it++) bu[it->first]+=it->second; } clr(u,0);
}
void dfs(int u)
{
calc(u); vis[u]=1; for(int i=hd[u];i;i=nex[i])
{
int v=to[i]; if(vis[v]) continue; root=0,sn=size[v],getroot(v,u),dfs(root);
}
}
int main()
{
// setIO("input"); int i,j; scanf("%d",&n); for(i=1;i<=n;++i) scanf("%d",&val[i]), A[i]=val[i]; sort(A+1,A+1+n); for(i=1;i<=n;++i) val[i]=lower_bound(A+1,A+1+n,val[i])-A; for(i=1;i<n;++i)
{
int u,v; scanf("%d%d",&u,&v),add(u,v),add(v,u);
} sn=mx[0]=n,root=0,getroot(1,0),dfs(root); for(i=1;i<=n;++i) printf("%lld\n",Sum[i]); return 0;
}

  

05-11 13:02