[xsy2164]theory-LMLPHP

又积累了一个网络流模型:最大权闭合子图,相关证明去看论文,感觉自己不是很懂证明,但现在还是先把建模记下来再说吧

枚举一个点,硬点它一定要被选中,那么以它为根,如果选了$x$就必须要选$fa_x$,这就是闭合图的定义了,再加上权值最大,所以直接上最大权闭合子图即可

最大权闭合子图的建模方法:把原图中的每一条边容量设为$+\infty$,对每个$v_x\gt0$的点,连边$(s,x,v_x)$,对每个$v_x\lt0$的点,连边$(x,t,|v_x|)$,所有正权点的权值之和减去$s\rightarrow t$的最小割就是答案

#include<stdio.h>
#include<string.h>
const int inf=1000000000;
int min(int a,int b){return a<b?a:b;}
int max(int a,int b){return a>b?a:b;}
int h[110],cur[110],nex[810],to[810],cap[810],dis[110],q[110],M,S,T;
void add(int a,int b,int c){
	M++;
	to[M]=b;
	cap[M]=c;
	nex[M]=h[a];
	h[a]=M;
	M++;
	to[M]=a;
	cap[M]=0;
	nex[M]=h[b];
	h[b]=M;
}
bool bfs(){
	int head,tail,i,x;
	head=tail=1;
	q[1]=S;
	memset(dis,-1,sizeof(dis));
	dis[S]=0;
	while(head<=tail){
		x=q[head];
		head++;
		for(i=h[x];i;i=nex[i]){
			if(cap[i]>0&&dis[to[i]]==-1){
				dis[to[i]]=dis[x]+1;
				if(to[i]==T)return 1;
				tail++;
				q[tail]=to[i];
			}
		}
	}
	return dis[T]>0;
}
int dfs(int x,int flow){
	if(x==T)return flow;
	int i,f;
	for(i=cur[x];i;i=nex[i]){
		if(cap[i]>0&&dis[to[i]]==dis[x]+1){
			f=dfs(to[i],min(flow,cap[i]));
			if(f){
				cap[i]-=f;
				cap[i^1]+=f;
				if(cap[i])cur[x]=i;
				return f;
			}
		}
	}
	dis[x]=-1;
	return 0;
}
int dicnic(){
	int ans=0,tmp;
	while(bfs()){
		memcpy(cur,h,sizeof(h));
		while(tmp=dfs(S,inf))ans+=tmp;
	}
	return ans;
}
struct tree{
	int h[110],nex[210],to[210],M;
	void reset(){
		M=0;
		memset(h,0,sizeof(h));
	}
	void add(int a,int b){
		M++;
		to[M]=b;
		nex[M]=h[a];
		h[a]=M;
	}
	void dfs(int f,int x){
		for(int i=h[x];i;i=nex[i]){
			if(to[i]!=f){
				::add(to[i],x,inf);
				dfs(x,to[i]);
			}
		}
	}
}a,b;
int v[110],n;
int get(int x){
	int i,sum;
	M=1;
	memset(h,0,sizeof(h));
	a.dfs(0,x);
	b.dfs(0,x);
	sum=0;
	for(i=1;i<=n;i++){
		if(v[i]>0){
			sum+=v[i];
			add(S,i,v[i]);
		}
		if(v[i]<0)add(i,T,-v[i]);
	}
	return sum-dicnic();
}
int main(){
	int T,i,x,y,ans;
	scanf("%d",&T);
	while(T--){
		a.reset();
		b.reset();
		scanf("%d",&n);
		S=n+1;
		::T=n+2;
		for(i=1;i<=n;i++)scanf("%d",v+i);
		for(i=1;i<n;i++){
			scanf("%d%d",&x,&y);
			a.add(x,y);
			a.add(y,x);
		}
		for(i=1;i<n;i++){
			scanf("%d%d",&x,&y);
			b.add(x,y);
			b.add(y,x);
		}
		ans=-inf;
		for(i=1;i<=n;i++)ans=max(ans,get(i));
		printf("%d\n",ans);
	}
}
05-11 00:11