Description
国家有一个大工程,要给一个非常大的交通网络里建一些新的通道。
我们这个国家位置非常特殊,可以看成是一个单位边权的树,城市位于顶点上。
在 2 个国家 a,b 之间建一条新通道需要的代价为树上 a,b 的最短路径。
现在国家有很多个计划,每个计划都是这样,我们选中了 k 个点,然后在它们两两之间 新建 C(k,2)条 新通道。
现在对于每个计划,我们想知道:
1.这些新通道的代价和
2.这些新通道中代价最小的是多少
3.这些新通道中代价最大的是多少
Input
第一行 n 表示点数。
接下来 n-1 行,每行两个数 a,b 表示 a 和 b 之间有一条边。
点从 1 开始标号。 接下来一行 q 表示计划数。
对每个计划有 2 行,第一行 k 表示这个计划选中了几个点。
第二行用空格隔开的 k 个互不相同的数表示选了哪 k 个点。
Output
输出 q 行,每行三个数分别表示代价和,最小代价,最大代价。
题解:建出来虚树后就不是很难了
#include<bits/stdc++.h>
#define setIO(s) freopen(s".in","r",stdin), freopen(s".out","w",stdout)
#define maxn 2000001
#define inf 1000000000
#define ll long long
using namespace std;
vector<int>G[maxn];
int edges,tim,root,top;
int hd[maxn], to[maxn<<1], val[maxn<<1], nex[maxn<<1];
int dep[maxn],Top[maxn],hson[maxn],siz[maxn],dfn[maxn],fa[maxn],arr[maxn],S[maxn],mk[maxn];
inline void add(int u,int v)
{
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v,val[edges]=1;
}
void dfs1(int u,int ff)
{
fa[u]=ff,siz[u]=1,dep[u]=dep[ff]+1,dfn[u]=++tim;
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==ff) continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[hson[u]]) hson[u]=v;
}
}
void dfs2(int u,int tp)
{
Top[u]=tp;
if(hson[u]) dfs2(hson[u],tp);
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==fa[u]||v==hson[u]) continue;
dfs2(v,v);
}
}
inline int LCA(int x,int y)
{
while(Top[x]!=Top[y])
{
dep[Top[x]] > dep[Top[y]] ? x = fa[Top[x]] : y = fa[Top[y]];
}
return dep[x] < dep[y] ? x : y;
}
inline int getdis(int x,int y)
{
return dep[x] + dep[y] - (dep[LCA(x,y)] << 1);
}
inline void addvir(int u,int v)
{
G[u].push_back(v);
}
void insert(int x)
{
if(top<=1) { S[++top]=x; return; }
int lca=LCA(x, S[top]);
if(lca == S[top]) { S[++top] = x; return; }
while(top > 1 && dep[S[top - 1]] >= dep[lca]) addvir(S[top - 1], S[top]), --top;
if(lca != S[top]) addvir(lca, S[top]), S[top] = lca;
S[++top] = x;
}
bool cmp(int i,int j)
{
return dfn[i] < dfn[j];
}
ll ans=0,a1,a2;
int size[maxn],d1[maxn],d2[maxn],dmin1[maxn],k,dmin2[maxn];
void DP(int x)
{
size[x]=mk[x];
d1[x]=d2[x]=0;
if(!mk[x]) d1[x]=d2[x]=-inf;
dmin1[x]=dmin2[x]=inf;
if(mk[x]) dmin1[x]=0;
for(int i=0;i<G[x].size();++i)
{
int v = G[x][i],w = dep[G[x][i]] - dep[x];
DP(v);
if(mk[v])
{
if(w <= dmin1[x]) dmin2[x]=dmin1[x], dmin1[x]=w;
else if(w < dmin2[x]) dmin2[x]=w;
}
else
{
if(w + dmin1[v] <= dmin1[x]) dmin2[x]=dmin1[x], dmin1[x]=w + dmin1[v];
else if(w + dmin1[v] < dmin2[x]) dmin2[x] = w + dmin1[v];
}
int curd=w+d1[v];
if(curd >= d1[x])
{
d2[x]=d1[x], d1[x]=curd;
}
else if(curd > d2[x])
{
d2[x] = curd;
}
ans+=1ll*size[v]*w*(k-size[v]),size[x]+=size[v];
}
a1=max(a1, 1ll*(d1[x] + d2[x]));
a2=min(a2, 1ll*(dmin1[x] + dmin2[x]));
}
void init(int x)
{
d1[x]=d2[x]=0;
dmin1[x]=dmin2[x]=inf;
for(int i=0;i<G[x].size();++i) init(G[x][i]);
G[x].clear();
}
inline void work()
{
scanf("%d",&k);
for(int i=1;i<=k;++i) scanf("%d",&arr[i]);
for(int i=1;i<=k;++i) mk[arr[i]] = 1;
sort(arr+1,arr+1+k,cmp);
top=S[0]=root=ans=0;
if(arr[1]!=1) S[top=1]=1;
for(int i=1;i<=k;++i) insert(arr[i]);
while(top > 1) addvir(S[top-1], S[top]),--top;
a1=-inf, a2=inf, DP(1);
printf("%lld %lld %lld\n",ans,a2,a1);
init(1);
for(int i=1;i<=k;++i) mk[arr[i]]=0;
}
int main()
{
// setIO("input");
int n;
scanf("%d",&n);
for(int i=1;i<n;++i)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b), add(b,a);
}
dfs1(1,0),dfs2(1,1);
int Q;
scanf("%d",&Q);
for(int i=1;i<=Q;++i) work();
return 0;
}