插头DP学习笔记

用途

有些 状压 \(DP\) 问题要求我们记录状态的连通性信息,这类问题一般被形象的称为插头 \(DP\) 或连通性状态压缩 \(DP\)

例如格点图的哈密顿路径计数,求棋盘的黑白染色方案满足相同颜色之间形成一个连通块的方案数,以及特定图的生成树计数等等。

这些问题通常需要我们对状态的连通性进行编码,讨论状态转移过程中连通性的变化。

例题

洛谷P5056 【模板】插头dp

首先要明确两个概念:

轮廓线:已决策状态和未决策状态的分界线。

插头DP学习笔记-LMLPHP

插头:一个格子某个方向的插头存在,表示这个格子在这个方向与相邻格子相连。

插头DP学习笔记-LMLPHP

我们要状压的就是轮廓线上插头的状态。

具体来说,可以把路径的合并看作括号的匹配。

一般的状压 \(dp\) 只有 \(0,1\) 两个状态,分别表示有和没有,但是这道题需要用三进制的状态压缩,\(0,1,2\) 分别表示没有括号,有左括号,有右括号。

之所以要分左右括号是因为要区分下面这两种情况:

插头DP学习笔记-LMLPHP

插头DP学习笔记-LMLPHP

第一种情况中间的两个括号是不能合并的,因为要恰好形成一条回路,第二种情况则能够合并。

一条轮廓线会由 \(n+1\) 条线段组成,其中 \(n\) 条是左右方向的,另外 \(1\) 条是上下方向的。

为了方便解压状态,我们用四进制来表示,同时减少枚举的状态,要把所有的状态存到哈希表里。

转移的时候大力分类讨论:

\(1\)、当前的位置不能有路径经过

如果没有向右的插头或者向下的插头,直接继承上一个格子的答案,

否则不存在合法的方案。

if(s[i][j]=='*'){
	if(!r && !dow) f[now].ad(nzt,nval);
}

\(2\)、当前的位置必须经过并且没有向右的插头或者向下的插头。

需要在当前的格子新开一个向右的插头和向下的插头,并且把向右的插头标记为右括号,把向下的插头标记为左括号。

我在转移状态之前就去判断这个状态是否合法,这样会比较好写。

else if(!r && !dow){
	if(s[i][j+1]=='.' && s[i+1][j]=='.') f[now].ad(nzt|2|(1<<j*2),nval);
}

\(3\)、当前的位置必须经过并且只有向右的插头或者向下的插头。

可以继续沿着之前的方向或者改变插头的方向,左右括号不变。

 else if(r && !dow){
		if(s[i][j+1]=='.') f[now].ad(nzt,nval);
		if(s[i+1][j]=='.') f[now].ad(nzt^r|(r<<j*2),nval);
} else if(dow && !r){
		if(s[i+1][j]=='.') f[now].ad(nzt,nval);
		if(s[i][j+1]=='.') f[now].ad(nzt^(dow<<j*2)|dow,nval);
}

\(4\)、当前的位置必须经过并且有一个代表左括号的右插头和一个代表右括号的下插头。

如果当前的点是图中右下角的终止节点并且不存在其它匹配的括号更新答案。

else if(r==1 && dow==2){
		if(i==edx && j==edy && (nzt^r^(dow<<j*2))==0) ans+=nval;
}

\(5\)、当前的位置必须经过并且有一个代表左括号的下插头和一个代表右括号的右插头。

将这两个括号匹配。

else if(r==2 && dow==1){
		f[now].ad(nzt^r^(dow<<j*2),nval);
}

\(6\)、当前的位置必须经过并且有一个下插头和一个右插头,并且这两个插头都代表左括号。

一直向右找,找到第一个左括号和右括号恰好匹配的位置,把这个位置的右括号改为左括号,之前的两个左括号直接匹配。

else if(r==1 && dow==1){
	cs1=nzt^r^(dow<<j*2);
	for(rg int o=j+1,p=1;o<=m;o++){
		cs2=cs1>>o*2&3;
		p+=(cs2==1)-(cs2==2);
		if(!p){
			cs1^=3<<o*2;
			break;
		}
	}
	f[now].ad(cs1,nval);
}

\(7\)、当前的位置必须经过并且有一个下插头和一个右插头,并且这两个插头都代表右括号。

和上面的情况一样,但是需要改成向左找。

else if(r==2 && dow==2){
	cs1=nzt^r^(dow<<j*2);
	for(rg int o=j-1,p=1;o;o--){
		cs2=cs1>>o*2&3;
		p+=(cs2==2)-(cs2==1);
		if(!p){
			cs1^=3<<o*2;
			break;
		}
	}
	f[now].ad(cs1,nval);
}

代码

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#define rg register
const int maxn=14,mod=1e5+3,maxm=1e5+5;
int n,m,edx,edy;
char s[maxn][maxn];
long long ans=0;
struct has{
	struct asd{
		int nxt,zt;
		long long val;
	}b[maxm];
	has(){
		memset(h,-1,sizeof(h));
		tot=1;
	}
	int tot,h[maxm];
	void cls(){
		memset(h,-1,sizeof(h));
		tot=1;
	}
	void ad(rg int zt,rg long long val){
		rg int now=zt%mod;
		for(rg int i=h[now];i!=-1;i=b[i].nxt){
			if(b[i].zt==zt){
				b[i].val+=val;
				return;
			}
		}
		b[tot].val=val;
		b[tot].zt=zt;
		b[tot].nxt=h[now];
		h[now]=tot++;
	}
}f[2];
int main(){
	scanf("%d%d",&n,&m);
	for(rg int i=1;i<=n;i++){
		scanf("%s",s[i]+1);
		for(rg int j=1;j<=m;j++){
			if(s[i][j]=='.') edx=i,edy=j;
		}
	}
	rg int now=0,nzt,r,dow,cs1,cs2;
	rg long long nval;
	f[0].ad(0,1);
	for(rg int i=1;i<=n;i++){
		for(rg int j=1;j<=m;j++){
			now^=1;
			f[now].cls();
			for(rg int k=1;k<f[now^1].tot;k++){
				nzt=f[now^1].b[k].zt,nval=f[now^1].b[k].val;
				r=nzt&3,dow=nzt>>j*2&3;
				if(s[i][j]=='*'){
					if(!r && !dow) f[now].ad(nzt,nval);
				} else if(!r && !dow){
					if(s[i][j+1]=='.' && s[i+1][j]=='.') f[now].ad(nzt|2|(1<<j*2),nval);
				} else if(r && !dow){
					if(s[i][j+1]=='.') f[now].ad(nzt,nval);
					if(s[i+1][j]=='.') f[now].ad(nzt^r|(r<<j*2),nval);
				} else if(dow && !r){
					if(s[i+1][j]=='.') f[now].ad(nzt,nval);
					if(s[i][j+1]=='.') f[now].ad(nzt^(dow<<j*2)|dow,nval);
				} else if(r==1 && dow==2){
					if(i==edx && j==edy && (nzt^r^(dow<<j*2))==0) ans+=nval;
				} else if(r==2 && dow==1){
					f[now].ad(nzt^r^(dow<<j*2),nval);
				} else if(r==1 && dow==1){
					cs1=nzt^r^(dow<<j*2);
					for(rg int o=j+1,p=1;o<=m;o++){
						cs2=cs1>>o*2&3;
						p+=(cs2==1)-(cs2==2);
						if(!p){
							cs1^=3<<o*2;
							break;
						}
					}
					f[now].ad(cs1,nval);
				} else if(r==2 && dow==2){
					cs1=nzt^r^(dow<<j*2);
					for(rg int o=j-1,p=1;o;o--){
						cs2=cs1>>o*2&3;
						p+=(cs2==2)-(cs2==1);
						if(!p){
							cs1^=3<<o*2;
							break;
						}
					}
					f[now].ad(cs1,nval);
				}
			}
		}
	}
	printf("%lld\n",ans);
	return 0;
}
04-06 02:52