题目:
题目链接:https://www.luogu.com.cn/problem/P5666
小简单正在学习离散数学,今天的内容是图论基础,在课上他做了如下两条笔记:
- 一个大小为 \(n\) 的树由 \(n\) 个结点与 \(n − 1\) 条无向边构成,且满足任意两个结点间有且仅有一条简单路径。在树中删去一个结点及与它关联的边,树将分裂为若干个子树;而在树中删去一条边(保留关联结点,下同),树将分裂为恰好两个子树。
- 对于一个大小为 \(n\) 的树与任意一个树中结点 \(c\),称 \(c\) 是该树的重心当且仅当在树中删去 \(c\) 及与它关联的边后,分裂出的所有子树的大小均不超过 \(\lfloor \frac{n}{2} \rfloor\)(其中 \(\lfloor x \rfloor\) 是下取整函数)。对于包含至少一个结点的树,它的重心只可能有 1 或 2 个。
课后老师给出了一个大小为 \(n\) 的树 \(S\),树中结点从 \(1 \sim n\) 编号。小简单的课后作业是求出 \(S\) 单独删去每条边后,分裂出的两个子树的重心编号和之和。即:
\]
上式中,\(E\) 表示树 \(S\) 的边集,\((u,v)\) 表示一条连接 \(u\) 号点和 \(v\) 号点的边。\(S'_u\) 与 \(S'_v\) 分别表示树 \(S\) 删去边 \((u,v)\) 后,\(u\) 号点与 \(v\) 号点所在的被分裂出的子树。
小简单觉得作业并不简单,只好向你求助,请你教教他。
思路:
总算\(A\)了\(qwq\),我太菜了。
考虑每一个点能有哪些边可以对他做贡献。
假设现在树根为\(1\),分两种情况:
如果我们不在\(1\)号节点的重儿子中割边,那么我们只要保证割边之后这棵树的大小不小于\(2\times\)重儿子大小即可。
我们设割去边的另一棵树的大小为\(t\),那么也就是说我们只要保证\(2\times max1\leq n-t\),也就是\(t\leq n-2\times max1\)。如果我们在\(1\)号节点的重儿子重割边,那么\(1\)号节点的重儿子可能改变也可能不改变。
如果改变,设\(max2\)表示\(1\)号节点原来的第二大子树的大小,那么我们需要满足\(2\times max2\leq n-t\),也就是\(t\leq n-2\times max2\)。
如果不改变,那么我们需要满足\(2\times (max1-t)\leq n-t\),也就是\(t\geq 2\times max1-n\)
综上,在\(1\)号节点的重儿子重割边只要满足\(2\times max1-n\leq t\leq n-2\times max2\)即可。
那么我们如果可以求出\(1\)号节点每一颗子树的大小,以及在每一棵子树内的有多少个大小为\(t\)的子树,并且支持区间查询(这样就可以求出一颗子树内有多少个子树大小取值在任意区间\([l,r]\)了),那么就可以完成这道题。
我们可以用主席树来维护以\(1\)为根时,\(dfs\)序在\([x,y]\)之间的所有节点,有多少个大小在\([l,r]\)之内。这样就可以直接完成\(1\)为根的计算。
考虑换根,我们把根从\(1\to x\)时,我们发现,以\(x\)为根的子树分为两种:以\(1\)为根时,在\(x\)的子树下的所有子树 和 换根后\(1\)与\(1\)的其他子树所构成的一颗子树。此时求前者的\(t\)是没有问题的,但是要求后者的\(t\),我们考虑用整棵树的\(t\)的取值个数\(-\)前者的\(t\)的取值个数即可。
整棵树的\(t\)的取值个数可以用树状数组动态维护。
累计答案时分类讨论一下当前节点的重儿子是前者还是后者即可。
时间复杂度\(O(n\log n)\)
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=300010;
int T,n,tot,head[N],size[N],id[N],rk[N],root[N];
ll ans;
struct edge
{
int next,to,dis;
}e[N*2];
struct Treenode
{
int lc,rc,cnt;
};
struct BIT
{
int c[N];
void clr()
{
memset(c,0,sizeof(c));
}
void add(int x,int val)
{
if (x<=0) return;
for (int i=x;i<=n;i+=i&-i)
c[i]+=val;
}
int ask(int x)
{
if (x<=0) return 0;
int sum=0;
for (int i=x;i;i-=i&-i)
sum+=c[i];
return sum;
}
}bit;
struct Tree
{
Treenode tree[N*50];
int tot;
void clr()
{
memset(tree,0,sizeof(tree));
tot=0;
}
int build(int l,int r)
{
int p=++tot;
if (l==r) return p;
int mid=(l+r)>>1;
tree[p].lc=build(l,mid);
tree[p].rc=build(mid+1,r);
return p;
}
int update(int now,int l,int r,int k)
{
int p=++tot;
tree[p]=tree[now]; tree[p].cnt++;
if (l==r) return p;
int mid=(l+r)>>1;
if (k<=mid) tree[p].lc=update(tree[now].lc,l,mid,k);
else tree[p].rc=update(tree[now].rc,mid+1,r,k);
return p;
}
int ask(int nowl,int nowr,int l,int r,int ql,int qr)
{
if (ql==l && qr==r)
return tree[nowr].cnt-tree[nowl].cnt;
if (ql>qr) return 0;
int mid=(l+r)>>1;
if (qr<=mid) return ask(tree[nowl].lc,tree[nowr].lc,l,mid,ql,qr);
else if (ql>mid) return ask(tree[nowl].rc,tree[nowr].rc,mid+1,r,ql,qr);
else return ask(tree[nowl].lc,tree[nowr].lc,l,mid,ql,mid)+ask(tree[nowl].rc,tree[nowr].rc,mid+1,r,mid+1,qr);
}
}Tree;
void add(int from,int to)
{
e[++tot].to=to;
e[tot].next=head[from];
head[from]=tot;
}
int dfs1(int x,int fa)
{
size[x]=1; id[x]=++tot; rk[tot]=x;
for (int i=head[x];~i;i=e[i].next)
if (e[i].to!=fa) size[x]+=dfs1(e[i].to,x);
bit.add(size[x],1);
// root[id[x]]=Tree.update(root[id[x]-1],1,n,size[x]);
return size[x];
}
void add_ans(int x,int fa)
{
int max1=0,max2=0,pos;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (size[v]>max1) max2=max1,max1=size[v],pos=v;
else if (size[v]>max2) max2=size[v];
}
if (pos!=fa)
{
int cnt_in=Tree.ask(root[id[pos]-1],root[id[pos]+size[pos]-1],1,n,1,n-max1*2);
int cnt_all=bit.ask(n-max1*2);
int cnt=Tree.ask(root[id[pos]-1],root[id[pos]+size[pos]-1],1,n,max(max1*2-n,1),n-max2*2);
ans+=1LL*x*(cnt_all-cnt_in+cnt);
}
else
{
int cnt=Tree.ask(root[id[x]],root[id[x]+size[x]-1],1,n,1,n-max1*2);
int cnt_all=bit.ask(n-max2*2)-bit.ask(max(max1*2-n-1,0));
int cnt_in=Tree.ask(root[id[x]],root[id[x]+size[x]-1],1,n,max(max1*2-n,1),n-max2*2);
ans+=1LL*x*(cnt_all-cnt_in+cnt);
}
}
void dfs2(int x,int fa)
{
bit.add(size[x],-1);
add_ans(x,fa);
int Cpy=size[x];
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
size[x]=n-size[v];
bit.add(size[x],1);
dfs2(e[i].to,x);
bit.add(size[x],-1);
}
}
size[x]=Cpy;
bit.add(size[x],1);
}
int main()
{
scanf("%d",&T);
while (T--)
{
memset(head,-1,sizeof(head));
tot=ans=0;
scanf("%d",&n);
Tree.clr(); bit.clr();
root[0]=Tree.build(1,n);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
tot=0;
dfs1(1,0);
for (int i=1;i<=n;i++)
root[i]=Tree.update(root[i-1],1,n,size[rk[i]]);
dfs2(1,0);
printf("%lld\n",ans);
}
return 0;
}