我丢 之前sun在某校集训给我看过 当时也没想起来 今天补省集的锅的时候发现 wok这题我还听过?!
身败名裂.jpg (可是你记性不好这事情不已经人尽皆知了吗?
咳咳 回归正题
考虑对于两个同色的点:
1)不构成祖先关系
那么两个子树里的点都不能选 相当于矩形覆盖
2)构成祖先关系
祖先刨掉一个子树,儿子子树不能选
拆成两个矩形
最后考虑统计答案,发现对称做然后(总点数-答案)/2就是答案
(因为对角线上的点总是合法的 所以要加上qwq)
然后就是矩形的并数点了 直接扫描线+线段树就好了
//Love and Freedom.
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<cassert>
#define ll long long
#define inf 20021225
#define N 100010
#define pb push_back
using namespace std;
int read()
{
int s=,f=; char ch=getchar();
while(ch<''||ch>''){if(ch=='-') f=-; ch=getchar();}
while(ch>=''&&ch<='') s=s*+ch-'',ch=getchar();
return f*s;
}
//--------- tree -----------
struct edge{int to,lt;}e[N<<];
int in[N],cnt,dep[N],sz[N],f[N][],dfn[N],tms,idfn[N];
void add(int x,int y)
{
e[++cnt].to=y; e[cnt].lt=in[x]; in[x]=cnt;
e[++cnt].to=x; e[cnt].lt=in[y]; in[y]=cnt;
}
void dfs(int x)
{
dfn[x]=++tms; idfn[tms]=x; sz[x]=;
for(int i=;i<;i++) f[x][i]=f[f[x][i-]][i-];
for(int i=in[x];i;i=e[i].lt)
{
int y=e[i].to; if(y==f[x][]) continue;
dep[y]=dep[x]+; f[y][]=x; dfs(y); sz[x]+=sz[y];
}
}
int LCA(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
int len=dep[x]-dep[y];
for(int i=;i<;i++) if(len>>i&)
x=f[x][i];
if(x==y) return x;
for(int i=;~i;i--) if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][];
}
int col[N];
struct node
{
int xd,xu,y,v;
}p[N*];
bool operator<(node a,node b){return a.y<b.y;}
//----------- SGT ------------
// 0->not fully covered >0 -> fully covered
#define ls x<<1
#define rs x<<1|1
int s[N<<],tag[N<<];
void insert(int x,int l,int r,int LL,int RR,int v)
{
if(RR<LL) return;
if(LL<=l&&RR>=r)
{
tag[x]+=v;
if(tag[x]>) s[x]=r-l+;
else if(l==r) s[x]=;
else s[x]=s[ls]+s[rs];
return;
}
int mid=l+r>>;
if(LL<=mid) insert(ls,l,mid,LL,RR,v);
if(mid<RR) insert(rs,mid+,r,LL,RR,v);
if(!tag[x]) s[x]=s[ls]+s[rs];
}
vector<int> cc[N];
int tot;
int main()
{
//freopen("tree3-4.in","r",stdin);
int n=read();
for(int i=;i<=n;i++) col[i]=read(),cc[col[i]].pb(i);
for(int i=;i<n;i++) add(read(),read()); dfs();
for(int i=;i<=n;i++)
for(int j=;j<cc[i].size();j++) for(int k=j+;k<cc[i].size();k++)
{
int x=cc[i][j],y=cc[i][k],z=LCA(x,y);
if(z==x||z==y)
{
if(z==y) swap(x,y);
int len=dep[y]-dep[x]-; x=y;
for(int i=;i<;i++) if(len>>i&)
x=f[x][i];
p[++tot]=(node){,dfn[x]-,dfn[y],};
p[++tot]=(node){,dfn[x]-,dfn[y]+sz[y],-};
p[++tot]=(node){dfn[x]+sz[x],n,dfn[y],};
p[++tot]=(node){dfn[x]+sz[x],n,dfn[y]+sz[y],-}; p[++tot]=(node){dfn[y],dfn[y]+sz[y]-,,};
p[++tot]=(node){dfn[y],dfn[y]+sz[y]-,dfn[x],-};
p[++tot]=(node){dfn[y],dfn[y]+sz[y]-,dfn[x]+sz[x],};
p[++tot]=(node){dfn[y],dfn[y]+sz[y]-,n+,-};
}
else
{
p[++tot]=(node){dfn[x],dfn[x]+sz[x]-,dfn[y],};
p[++tot]=(node){dfn[x],dfn[x]+sz[x]-,dfn[y]+sz[y],-};
p[++tot]=(node){dfn[y],dfn[y]+sz[y]-,dfn[x],};
p[++tot]=(node){dfn[y],dfn[y]+sz[y]-,dfn[x]+sz[x],-};
}
}
sort(p+,p+tot+); ll ans=;
for(int i=;i<=tot;i++)
{
insert(,,n,p[i].xd,p[i].xu,p[i].v);
if(i!=tot && p[i].y!=p[i+].y)
ans+=1ll*(p[i+].y-p[i].y)*s[];
}
printf("%lld\n",(1ll*n*(n+)-ans)>>);
return ;
}