2017国家集训队作业[agc008f]Black Radius
时隔4个月,经历了省赛打酱油和中考各种被吊打后,我终于回想起了我博客园的密码= =
题意:
给你一棵树,树上有若干个关键点。选中某个关键点和一个参数d,把所有与关键点距离不超过距离d的点染黑,问一共有多少种染色方案,两种染色方案不同当且仅当存在一个节点在两种方案中的颜色不同,初始全为白色。(点数\(N\leq2*10^5\))
题解:
好题啊,模拟赛遇到这题真的orz不会做:)
做这种题,套路就是找到一种不重不漏的计数方式。假设所有点都是关键点,考虑用二元组\((u,d)\)来表示与\(u\)距离不超过\(d\)的集合,那么,我们要求的其实是所有本质不同的\((u,d)\)的个数。
\(\bullet\)考虑本质不同的\((u,d)\)有什么性质
\(\bullet1.\)要钦定\((u,d)\)不能是全集,最后\(ans\)再加上一就行了
\(\bullet2.\)对于所有相邻的点\(u\)和\(v\),\((u,d)\)不能与\((v,d-1)\)相同
我们记\(u\)与最远点的距离为\(mxd(u)\),\(u\)与次远点的距离为\(sxd(u)\),则这两点就等价于:
d< mxd(u) \\
d<sxd(u)+2
\end{cases}
\]
证明:
一式显然。
不妨设\(u\)是\(v\)的父亲,当前\(d\)使得\(u\)在以它为起点的最长链上点集包括的最远点为\(k\),我们考虑如何让\((u,d)\)与\(v\)的某个二元组出现重复。
有两种情况:
\(\bullet v\)在以它为起点的最长链上,此时为了保证让\((v,d')\)在这条链上的最远点依然为\(k\),则有\(d'=d-1\)。但这时,\((v,d')\)在\(u\)的其它儿子的子树中的最远点向上移了2位,为了保证它们是相同的,就有\(d\ge sxd(u)+2\)。
\(\bullet v\)不在最长链上,同上则\(d'=d+1\),此时一定会满足\(d+1\ge sxd(v)+2\)。
反之即\(d<sxd(u)+2\),证毕。
于是我们就可以愉快地拿到全部是关键点的部分分。
现在我们考虑当它不全是关键点的情况。
对于一个非关键点\(u\),和关键点\(v\),为了满足\((u,d)\)是\(v\)的染色的方案,\((u,d)\)必须包涵\(v\)子树内的所有点,这是\(d\)的下界。对于关键点\(u\)它的\(d\)的下界为\(0\)。树型DP处理即可,时间复杂度\(O(N)\)。
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define fo(i,l,r) for(int i=l;i<=r;i++)
#define of(i,l,r) for(int i=l;i>=r;i--)
#define fe(i,u) for(int i=head[u];i;i=e[i].next)
using namespace std;
typedef long long ll;
inline int rd()
{
static int x,f;
x=0,f=1;
char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
return x*f;
}
const int N=200010,Inf=1000000000;
struct edge{
int v,next;
edge(int v=0,int next=0):v(v),next(next){}
}e[N<<1];
int n;ll ans=0;
int head[N],tot=0;
int mx[N],sx[N],mn[N],siz[N];
char s[N];
inline void add(int u,int v){e[++tot]=edge(v,head[u]);head[u]=tot;}
void dfs1(int u,int fat)
{
if(s[u]=='1')mn[u]=0,siz[u]=1;
else mn[u]=Inf;
fe(i,u){
int v=e[i].v;
if(v==fat)continue;
dfs1(v,u);siz[u]+=siz[v];
int dis=mx[v]+1;
if(mx[u]<dis)sx[u]=mx[u],mx[u]=dis;
else sx[u]=max(sx[u],dis);
if(siz[v])mn[u]=min(mn[u],dis);
}
}
void dfs2(int u,int fat)
{
int mxd=min(mx[u]-1,sx[u]+1);
if(mxd>=mn[u])ans+=(ll)mxd-mn[u]+1;
fe(i,u){
int v=e[i].v;
if(v==fat)continue;
int dis=mx[u]==mx[v]+1?sx[u]+1:mx[u]+1;
if(mx[v]<dis)sx[v]=mx[v],mx[v]=dis;
else sx[v]=max(sx[v],dis);
if(siz[v]<siz[1])mn[v]=min(mn[v],dis);
dfs2(v,u);
}
}
int main()
{
n=rd();
fo(i,1,n-1){
int x=rd(),y=rd();
add(x,y);add(y,x);
}
scanf("%s",s+1);
dfs1(1,0);dfs2(1,0);
printf("%lld\n",ans+1ll);
return 0;
}