采蘑菇

Time Limit: 20 Sec  Memory Limit: 256 MB

Description

  【Foreign】采蘑菇 [点分治]-LMLPHP

Input

  【Foreign】采蘑菇 [点分治]-LMLPHP

Output

  【Foreign】采蘑菇 [点分治]-LMLPHP

Sample Input

  5
  1 2 3 2 3
  1 2
  1 3
  2 4
  2 5

Sample Output

  10
  9
  12
  9
  11

HINT

  【Foreign】采蘑菇 [点分治]-LMLPHP

Main idea

  询问从以每个点为起始点时,各条路径上的颜色种类的和。

Solution

  我们看到题目,立马想到了O(n^2)的做法,然后从这个做法研究一下本质,我们确定了可以以点分治作为框架。

  我们先用点分治来确定一个center(重心)。然后计算跟这个center有关的路径。设现在要统计的是经过center,对x提供贡献的路径。

  我们先记录一个记录Sum[x]表示1~i-1子树中 颜色x 第一次出现的位置的那个点 的子树和,然后我们就利用这个Sum来解题。

  我们显然可以分两种情况来讨论:

  (1)统计center->x出现颜色的贡献
    显然,这时候,对于center->x这一段,直接像O(n^2)做法那样记录一个color表示到目前为止出现的颜色个数,然后加一下即可。再记录一个record表示当前可有的贡献和,一旦出现过一个颜色,那么这个颜色在1~i-1子树上出现第一次以下的点,对于x就不再提供贡献了,record减去Sum[这个颜色],然后这样深搜往下计算即可。

  (2)统计center->x没出现过的颜色的贡献
    显然,对于center->x上没出现过的颜色,直接往下深搜,一开始为record为(All - Sum[center]),一旦出现了一个颜色,record则减去这个Sum。同样表示不再提供贡献即可。

  我们这样做就可以求出每个子树前缀对于其的贡献了,倒着再做一边即可求出全部的贡献。统计x的时候,顺便统计一下center。可以满足效率,成功AC这道题。

Code

 #include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
using namespace std; const int ONE = ;
const int INF = ;
const int MOD = 1e9+; int n,x,y;
int Val[ONE];
int next[ONE],first[ONE],go[ONE],tot;
int vis[ONE];
int Ans[ONE],Sum[ONE];
int All; int get()
{
int res,Q=; char c;
while( (c=getchar())< || c>)
if(c=='-')Q=-;
if(Q) res=c-;
while((c=getchar())>= && c<=)
res=res*+c-;
return res*Q;
} void Add(int u,int v)
{
next[++tot]=first[u]; first[u]=tot; go[tot]=v;
next[++tot]=first[v]; first[v]=tot; go[tot]=u;
} namespace Point
{
int center;
int Stack[ONE],top;
int total,Max,center_vis[ONE];
int num,V[ONE]; struct power
{
int size,maxx;
}S[ONE]; void Getsize(int u,int father)
{
S[u].size=;
S[u].maxx=;
for(int e=first[u];e;e=next[e])
{
int v=go[e];
if(v==father || center_vis[v]) continue;
Getsize(v,u);
S[u].size += S[v].size;
S[u].maxx = max(S[u].maxx,S[v].size);
}
} void Getcenter(int u,int father,int total)
{
S[u].maxx = max(S[u].maxx,total-S[u].size);
if(S[u].maxx < Max)
{
Max = S[u].maxx;
center = u;
} for(int e=first[u];e;e=next[e])
{
int v=go[e];
if(v==father || center_vis[v]) continue;
Getcenter(v,u,total);
}
} void Ad_sum(int u,int father)
{
if(!vis[Val[u]])
{
Stack[++top] = Val[u];
All += S[u].size; Sum[Val[u]] += S[u].size;
}
vis[Val[u]]++;
for(int e=first[u];e;e=next[e])
{
int v=go[e];
if(v==father || center_vis[v]) continue;
Ad_sum(v,u);
}
vis[Val[u]]--;
} void Calc_in(int u,int father,int center,int Size,int f_time,int record)
{
if(!vis[Val[u]]) f_time++, record += Size, record -= Sum[Val[u]];
Ans[u] += record; Ans[center]+=f_time;
Ans[u] += f_time; vis[Val[u]] ++;
for(int e=first[u];e;e=next[e])
{
int v=go[e];
if(v==father || center_vis[v]) continue;
Calc_in(v,u,center,Size,f_time,record);
}
vis[Val[u]] --;
} void Calc_not(int u,int father,int record)
{
if(!vis[Val[u]]) record -= Sum[ Val[u] ];
Ans[u] += record; vis[Val[u]] ++;
for(int e=first[u];e;e=next[e])
{
int v=go[e];
if(v==father || center_vis[v]) continue;
Calc_not(v,u,record);
}
vis[Val[u]] --;
} void Dfs(int u)
{
Max = n;
Getsize(u,);
Getcenter(u,,S[u].size);
Getsize(center,);
center_vis[center] = ; int num=; for(int e=first[center];e;e=next[e]) if(!center_vis[go[e]]) V[++num]=go[e]; for(int i=;i<=num;i++)
{
int v=V[i];
int Size = S[center].size - S[v].size - ;
vis[Val[center]] = ;
Calc_in(v,center,center, Size,,All - Sum[Val[center]] + Size);
vis[Val[center]] = ;
Ad_sum(v,center);
}
while(top) Sum[Stack[top--]]=; All=; for(int i=num;i>=;i--)
{
int v=V[i];
vis[Val[center]] = ;
Calc_not(v,center, All-Sum[Val[center]]);
vis[Val[center]] = ;
Ad_sum(v,center);
} while(top) Sum[Stack[top--]]=; All=;
for(int e=first[center];e;e=next[e])
{
int v=go[e];
if(center_vis[v]) continue;
Dfs(v);
}
} } int main()
{
n=get();
for(int i=;i<=n;i++) Val[i]=get(); for(int i=;i< n;i++)
{
x=get(); y=get();
Add(x,y);
} Point:: Dfs();
for(int i=;i<=n;i++)
printf("%d\n",Ans[i]+);
}
05-11 20:41