题目大意:给出一棵n个点的树和一棵m个点的树,问第一棵树有多少个连通子树与第二棵树同构。(n<=1000,m<=12)
做法:先找出第二棵树的重心(可能为边),以这个重心为根,可以避免重复计算,顺便对第二棵树的每个子树算出判同构的哈希值。枚举第一棵树的一个点/边与第二棵树的根对应,用f[i][j][k]表示以j为父亲的i的子树内,选出子树哈希值为k的方案数,合并的时候用状压DP。前两维合在一起是O(n)级别的,所以总复杂度是O(nm*2^m)。
代码:
#include<cstdio>
#include<algorithm>
#include<vector>
#include<map>
using namespace std;
inline int read()
{
int x;char c;
while((c=getchar())<''||c>'');
for(x=c-'';(c=getchar())>=''&&c<='';)x=x*+c-'';
return x;
}
#define MN 1000
#define MM 13
#define MOD 1000000007
struct edge{int nx,t;}e[MN*+MM*+];
int h[MN+],H1[MM+],H2[MM+],en;
int m,s[MM+],rts,rtx,rty,t[MM+],cnt;
vector<int> v[MM+],vv[MM+];
map<long long,int> mp;
int f[MN+][MN+][MM+];
inline void ins(int*h,int x,int y)
{
e[++en]=(edge){h[x],y};h[x]=en;
e[++en]=(edge){h[y],x};h[y]=en;
}
void dfs(int x,int fa)
{
s[x]=;
int mx=;
for(int i=H1[x];i;i=e[i].nx)if(e[i].t!=fa)
{
dfs(e[i].t,x);
s[x]+=s[e[i].t];
mx=max(mx,s[e[i].t]);
}
mx=max(mx,m-s[x]);
if(mx<rts)rts=mx,rtx=x,rty=;
else if(mx==rts)rty=x;
}
void solve(int x,int fa)
{
long long hash=;
for(int i=H2[x];i;i=e[i].nx)if(e[i].t!=fa)
{
solve(e[i].t,x);
vv[x].push_back(t[e[i].t]);
}
sort(vv[x].begin(),vv[x].end());
for(int i=;i<vv[x].size();++i)hash=hash*+vv[x][i];
t[x]=mp[hash]?mp[hash]:(v[++cnt]=vv[x],mp[hash]=cnt);
}
inline void rw(int&a,int b){if((a+=b)>=MOD)a-=MOD;}
int cal(int x,int fa,int t)
{
if(f[x][fa][t])return f[x][fa][t]-;
int *F=new int[<<v[t].size()];
for(int i=F[]=;i<<<v[t].size();++i)F[i]=;
for(int i=h[x];i;i=e[i].nx)if(e[i].t!=fa)
for(int j=<<v[t].size();j--;)
for(int k=;k<v[t].size();++k)
if(!(j&(<<k))&&(!k||(j&(<<k-))||v[t][k]!=v[t][k-]))
rw(F[j|(<<k)],1LL*F[j]*cal(e[i].t,x,v[t][k]));
f[x][fa][t]=F[(<<v[t].size())-]+;delete F;
return f[x][fa][t]-;
}
int main()
{
int n,i,j,ans=;
for(n=read(),i=;i<n;++i)ins(h,read(),read());
for(m=read(),i=;i<m;++i)ins(H1,read(),read());
rts=m;dfs(,);
if(rty)
{
if(rtx>rty)swap(rtx,rty);
for(i=;i<=m;++i)for(j=H1[i];j;j=e[j].nx)
if(i<e[j].t&&(i!=rtx||e[j].t!=rty))ins(H2,i,e[j].t);
ins(H2,rtx,++m);ins(H2,rty,m);rtx=m;
}
else for(i=;i<=m;++i)for(j=H1[i];j;j=e[j].nx)if(i<e[j].t)ins(H2,i,e[j].t);
solve(rtx,);
if(rty)for(i=;i<=n;++i)for(j=h[i];j;j=e[j].nx)if(i<e[j].t)
{
rw(ans,1LL*cal(i,e[j].t,v[t[m]][])*cal(e[j].t,i,v[t[m]][])%MOD);
if(v[t[m]][]!=v[t[m]][])
rw(ans,1LL*cal(i,e[j].t,v[t[m]][])*cal(e[j].t,i,v[t[m]][])%MOD);
}else;
else for(i=;i<=n;++i)rw(ans,cal(i,,t[rtx]));
printf("%d",ans);
}