参考:http://hzwer.com/6888.html

把k条道路权值设为0,和其他边一起跑MST,然后把此时选中的其他边设为必选,在新图中加上必选变缩成k个点,把所有边重标号,枚举k跳边的选取情况,和其他边做MST,建出树,k条边的权值在树上取min

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=300005,inf=1e9;
int n,m,k,top,cnt,s,f[N],fa[N],p[N],po[N],de[N],h[N],mn[N];
long long va[N],sum[N],ans;
bool mk[N];
struct qwe
{
int ne,to;
}e[N];
struct bian
{
int u,v,w;
}a[N],b[25],q[N];
bool cmp(const bian &a,const bian &b)
{
return a.w<b.w;
}
int read()
{
int r=0,f=1;
char p=getchar();
while(p>'9'||p<'0')
{
if(p=='-')
f=-1;
p=getchar();
}
while(p>='0'&&p<='9')
{
r=r*10+p-48;
p=getchar();
}
return r*f;
}
inline void add(int u,int v)
{
cnt++;
e[cnt].ne=h[u];
e[cnt].to=v;
h[u]=cnt;
}
inline void ins(int u,int v)
{
add(u,v);
add(v,u);
}
inline int zhao(int x)
{
return x==fa[x]?x:fa[x]=zhao(fa[x]);
}
inline int zh(int x)
{
return x==f[x]?x:f[x]=zh(f[x]);
}
void dp(int u)
{
sum[u]=va[u];
for(int i=h[u];i;i=e[i].ne)
if(e[i].to!=f[u])
{
de[e[i].to]=de[u]+1;
f[e[i].to]=u;
dp(e[i].to);
sum[u]+=sum[e[i].to];
}
}
void wk()
{
cnt=0;
for(int i=1;i<=k+1;i++)
{
int p=po[i];
h[p]=f[p]=0;
fa[p]=p,mn[p]=inf;
}
for(int i=1;i<=k;i++)
if(mk[i])
{
int fu=zhao(b[i].u),fv=zhao(b[i].v);
if(fu==fv)
return;
fa[fu]=fv;
ins(b[i].u,b[i].v);
}
for(int i=1;i<=k;i++)
{
int fu=zhao(q[i].u),fv=zhao(q[i].v);
if(fu!=fv)
{
fa[fu]=fv;
ins(q[i].u,q[i].v);
}
}
dp(s);
for(int i=1;i<=k;i++)
{
int u=q[i].u,v=q[i].v;
if(de[u]>de[v])
swap(u,v);
while(de[u]!=de[v])
mn[v]=min(mn[v],q[i].w),v=f[v];
while(u!=v)
{
mn[v]=min(mn[v],q[i].w);
mn[u]=min(mn[u],q[i].w);
u=f[u],v=f[v];
}
}
long long con=0;
for(int i=1;i<=k;i++)
if(mk[i])
{
int u=b[i].u,v=b[i].v;
if(de[u]>de[v])
swap(u,v);
con+=mn[v]*sum[v];
}
ans=max(ans,con);
}
void dfs(int x)
{
if(x==k+1)
{
wk();
return;
}
mk[x]=0;
dfs(x+1);
mk[x]=1;
dfs(x+1);
}
int main()
{
n=read(),m=read(),k=read();
for(int i=1;i<=m;i++)
a[i].u=read(),a[i].v=read(),a[i].w=read();
sort(a+1,a+1+m,cmp);
for(int i=1;i<=k;i++)
b[i].u=read(),b[i].v=read();
for(int i=1;i<=n;i++)
p[i]=read();
for(int i=1;i<=n;i++)
fa[i]=f[i]=i;
for(int i=1;i<=k;i++)
fa[zhao(b[i].u)]=zhao(b[i].v);
for(int i=1;i<=m;i++)
{
int u=a[i].u,v=a[i].v;
if(zhao(u)!=zhao(v))
{
fa[zhao(u)]=fa[zhao(v)];
f[zh(u)]=f[zh(v)];
}
}
s=zh(1);
for(int i=1;i<=n;i++)
{
int z=zh(i);
va[z]+=p[i];
if(z==i)
po[++po[0]]=i;
}
for(int i=1;i<=k;i++)
b[i].u=zh(b[i].u),b[i].v=zh(b[i].v);
for(int i=1;i<=m;i++)
a[i].u=zh(a[i].u),a[i].v=zh(a[i].v);
for(int i=1;i<=m;i++)
{
int fu=zh(a[i].u),fv=zh(a[i].v);
if(fu!=fv)
mk[i]=1,f[fu]=fv;
}
for(int i=1;i<=m;i++)
if(mk[i])
q[++top]=a[i];
dfs(1);
printf("%lld\n",ans);
return 0;
}
05-11 22:36