前几天讲了虚树,今天就来切一道虚树的题喽。
题目概述:
给你一棵树,有\(q\)个询问,每个询问给出\(k\)个点,两两连边,边的长度为其在树上的距离,求出这\(k\)个点连边总长度、最短的一条边以及最长的一条边。
其中\(\sum_{i=1}^qk<=5^4\)
大体思路:
显然是道虚树的题,那就先构建棵虚树。
对于第二三两个询问,我们都可以通过一个非常简单的DP来实现,每个询问分别用两个数组求出以它为根,第一、第二(大/小)的路径的值,但是要注意这里的两个值不能都是从同一个子树中转移过来的。
关于第一个询问,我们考虑每一条边对于答案的贡献,不难发现其被加的次数为\(siz[u]\times siz[v]\),那么我们把每一条边的贡献加起来即可。
具体实现 :
对于非询问点(即通过\(LCA\)关系而加入虚树的点),我们不妨称之为关系户。
关系户不能进入\(siz\)的统计(显然)。
关系户只能作为连接点,即只能以第一、第二(大/小)的路径的值相加来更新答案。(显然)
虚树的根为所有点加入完毕后,剩下的栈中的那个栈顶。(显然)
代码:
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+5,inf=1e9;
int f[N][21],fir[N],sec[N],Fir[N],Sec[N],nxt[N<<1],vet[N<<1],head[N],dep[N],dfn[N],siz[N],stk[N],a[N];
int n,m,ans1,ans2,Q,u,v,tot,tim,top,rt;
long long ans0;
bool flag[N];
struct Edge{
int v,d;
};
vector <Edge> E[N];
void add(int u,int v){
nxt[++tot]=head[u];
vet[tot]=v;
head[u]=tot;
}
void dfs0(int u,int fa){
f[u][0]=fa,dep[u]=dep[fa]+1,dfn[u]=++tim;
for (int i=1;i<=20;i++)
f[u][i]=f[f[u][i-1]][i-1];
for (int i=head[u];i;i=nxt[i]){
int v=vet[i];
if (v==fa) continue;
dfs0(v,u);
}
}
int lca(int a,int b){
if (dep[a]<dep[b]) swap(a,b);
if (b==0) return 0;
int res=inf;
for (int i=20;i>=0;i--)
if (dep[f[a][i]]>=dep[b]) a=f[a][i];
if (a==b) return a;
for (int i=20;i>=0;i--)
if (f[a][i]!=f[b][i])
a=f[a][i],b=f[b][i];
return f[a][0];
}
int dis(int a,int b){return dep[a]+dep[b]-dep[lca(a,b)]*2;}
void ins(int x){
int l=lca(stk[top],x);
while (top>1&&dep[stk[top-1]]>dep[l]){
E[stk[top]].push_back((Edge){stk[top-1],dis(stk[top],stk[top-1])});
E[stk[top-1]].push_back((Edge){stk[top],dis(stk[top],stk[top-1])});
top--;
}
if (dep[l]<dep[stk[top]]){
E[l].push_back((Edge){stk[top],dis(stk[top],l)});
E[stk[top]].push_back((Edge){l,dis(stk[top],l)});
top--;
}
if (stk[top]!=l) stk[++top]=l;
stk[++top]=x;
}
void dfs(int u,int fa){
Fir[u]=Sec[u]=inf,fir[u]=sec[u]=siz[u]=0;
if (flag[u]) siz[u]=1,Fir[u]=0;
for (int i=0;i<E[u].size();i++){
int v=E[u][i].v;
if (v==fa) continue;
dfs(v,u);
ans0+=1ll*E[u][i].d*siz[v]*(m-siz[v]);
siz[u]+=siz[v];
int x=Fir[v]+E[u][i].d;
if (x<=Fir[u])
Sec[u]=Fir[u],Fir[u]=x;
else
if (x<Sec[u]) Sec[u]=x;
x=fir[v]+E[u][i].d;
if (x>=fir[u])
sec[u]=fir[u],fir[u]=x;
else
if (x>sec[u]) sec[u]=x;
}
ans1=min(ans1,Fir[u]+Sec[u]);
ans2=max(ans2,fir[u]+sec[u]);
E[u].clear();
}
bool cmp(int a,int b){return dfn[a]<dfn[b];}
int main(){
scanf("%d",&n);
for (int i=1;i<n;i++){
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs0(1,0);
scanf("%d",&Q);
while (Q--){
ans0=ans2=0,ans1=inf;
scanf("%d",&m);
if (m==1){
puts("0 0 0");
continue;
}
top=0;
for (int i=1;i<=m;i++)
scanf("%d",&a[i]),flag[a[i]]=1;
rt=a[1];
sort(a+1,a+1+m,cmp);
for (int i=1;i<=m;i++) ins(a[i]);
while (top>1){
E[stk[top]].push_back((Edge){stk[top-1],dis(stk[top],stk[top-1])});
E[stk[top-1]].push_back((Edge){stk[top],dis(stk[top],stk[top-1])});
top--;
}
dfs(stk[top],0);
for (int i=1;i<=m;i++) flag[a[i]]=0;
printf("%lld %d %d\n",ans0,ans1,ans2);
}
return 0;
}