JSOI 2016 独特的树叶

题面描述

有两颗大小分别为\(n\)\(n+1\)的树,问删除后者的一个叶子,两者能否同构。如果能,输出最小的叶子标号使得删去这个叶子两者同构。

数据范围:\(n\le 10^5\)

思路

用一些玄学方法把第一棵树每个点及其子树\(hash\)起来,再\(up-down\)一下,求出每个点作为根这整棵树的\(hash\)值,存入\(map\)中。

再用类似的方法操作第二棵树即可。

关于树哈希的方法,建议使用异或,这样不仅撤销方便,而且无视顺序,无需排序。

代码

#include<bits/stdc++.h>
#define pii pair<int,int>
#define mkp(x,y) make_pair(x,y)
#define go(x) for(int i=head[x];i;i=edge[i].nxt)
#define now edge[i].v
using namespace std;
const int sz=1e5+7;
const int p1=1e8+7;
const int p2=998244353;
const int q1=52437;
const int q2=9813475;
int n;
int ans;
int u,v,cnt;
int r[sz];
int siz[sz];
int head[sz];
int w[sz][2];
int hs[sz][2];
map<pii,int>mp;
struct Edge{
    int v,nxt;
}edge[sz<<1];
void make_edge(int u,int v){
    edge[++cnt]=(Edge){v,head[u]};head[u]=cnt;
    edge[++cnt]=(Edge){u,head[v]};head[v]=cnt;
}
void dfs(int x,int fa){
    siz[x]=1;
    go(x) if(now!=fa){
        dfs(now,x);
        siz[x]+=siz[now];
        hs[x][0]^=1ll*w[siz[now]][0]*hs[now][0]%p1;
        hs[x][1]^=1ll*w[siz[now]][1]*hs[now][1]%p2;
    }
    hs[x][0]^=siz[x];
    hs[x][1]^=siz[x];
}
void Dfs(int x,int fa,int n0,int n1){
    mp[mkp(n0,n1)]=1;
    int k0,k1,t0,t1;
    go(x) if(now!=fa){
        k0=n0^n^(n-siz[now])^(1ll*w[siz[now]][0]*hs[now][0]%p1);
        k1=n1^n^(n-siz[now])^(1ll*w[siz[now]][1]*hs[now][1]%p2);
        t0=hs[now][0]^siz[now]^n^(1ll*w[n-siz[now]][0]*k0%p1);
        t1=hs[now][1]^siz[now]^n^(1ll*w[n-siz[now]][1]*k1%p2);
        Dfs(now,x,t0,t1);
    }
}
void getans(int x,int fa,int n0,int n1){
    int k0,k1,t0,t1;
    go(x) if(now!=fa){
        if(siz[now]==1){
            t0=n0^n^(n-1)^(1ll*w[siz[now]][0]*hs[now][0]%p1);
            t1=n1^n^(n-1)^(1ll*w[siz[now]][1]*hs[now][1]%p2);
            if(mp.find(mkp(t0,t1))!=mp.end()) ans=min(ans,now);
            continue;
        }
        k0=n0^n^(n-siz[now])^(1ll*w[siz[now]][0]*hs[now][0]%p1);
        k1=n1^n^(n-siz[now])^(1ll*w[siz[now]][1]*hs[now][1]%p2);
        t0=hs[now][0]^siz[now]^n^(1ll*w[n-siz[now]][0]*k0%p1);
        t1=hs[now][1]^siz[now]^n^(1ll*w[n-siz[now]][1]*k1%p2);
        getans(now,x,t0,t1);
    }
}
int main(){
    scanf("%d",&n);
    w[0][0]=w[0][1]=1;
    for(int i=1;i<=n+1;i++){
        w[i][0]=1ll*w[i-1][0]*q1%p1;
        w[i][1]=1ll*w[i-1][1]*q2%p2;
    }
    for(int i=1;i<n;i++){
        scanf("%d%d",&u,&v);
        make_edge(u,v);
    }
    dfs(1,0);
    Dfs(1,0,hs[1][0],hs[1][1]);
    n++;
    cnt=0;
    ans=INT_MAX;
    memset(head,0,sizeof(head));
    memset(hs,0,sizeof(hs));
    for(int i=1;i<n;i++){
        scanf("%d%d",&u,&v);
        r[u]++,r[v]++;
        make_edge(u,v);
    }
    dfs(1,0);
    getans(1,0,hs[1][0],hs[1][1]);
    if(r[1]==1){
        int v=edge[head[1]].v;
        if(mp.find(mkp(hs[v][0],hs[v][1]))!=mp.end()) ans=min(ans,1);
    }
    if(ans!=INT_MAX) printf("%d\n",ans);
    else puts("-1");
}
12-14 09:36