将两棵树分别按照深度分块,即每$\sqrt{n}$深度分为一块
对于两棵树任意两个块的根对,统计其中公共的节点并计算出答案(用按秩合并并查集),复杂度为$o(n\sqrt{n}\logn)$
(其实这玩意是可以被卡掉的,因为这种分块无法保证块的大小和数量)
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 20005 4 #define K 505 5 struct ji{ 6 int nex,to; 7 }edge[N<<1]; 8 struct ji2{ 9 int son,fa,sh; 10 }; 11 vector<ji2>v; 12 int E,t,n,m,x,y,xx[N],yy[N],head[N],ff[N],bl[N],ans[N],f[N*5],sh[N*5]; 13 int find(int k){ 14 if (k==f[k])return k; 15 return find(f[k]); 16 } 17 ji2 update(int x,int y){ 18 x=find(x); 19 y=find(y); 20 if (x==y)return ji2{0,0,0}; 21 ans[0]++; 22 if (sh[x]>sh[y])swap(x,y); 23 ji2 o={x,y,sh[y]}; 24 f[x]=y; 25 sh[y]=max(sh[y],sh[x]+1); 26 return o; 27 } 28 void clean(ji2 o){ 29 if (!o.son)return; 30 f[o.son]=o.son; 31 sh[o.fa]=o.sh; 32 ans[0]--; 33 } 34 void add(int x,int y){ 35 edge[E].nex=head[x]; 36 edge[E].to=y; 37 head[x]=E++; 38 } 39 void tj(int k,int fa,int p){ 40 ji2 o=update(xx[k],yy[k]); 41 if (bl[k-n]==p){ 42 v.clear(); 43 for(int i=k-n;bl[i]==p;i=ff[i])v.push_back(update(xx[i],yy[i])); 44 ans[k-n]=m-ans[0]; 45 for(int i=0;i<v.size();i++)clean(v[i]); 46 } 47 for(int i=head[k];i!=-1;i=edge[i].nex) 48 if (edge[i].to!=fa)tj(edge[i].to,k,p); 49 clean(o); 50 } 51 void dfs(int k,int fa,int sh){ 52 ff[k]=fa; 53 if (sh%K)bl[k]=bl[fa]; 54 else bl[k]=++x; 55 for(int i=head[k];i!=-1;i=edge[i].nex) 56 if (edge[i].to!=fa)dfs(edge[i].to,k,sh+1); 57 } 58 void calc(int k,int fa,int p){ 59 ji2 o=update(xx[k],yy[k]); 60 if (bl[k]!=bl[fa]){ 61 if (k<=n)calc(n+1,0,bl[k]); 62 else tj(k,fa,p); 63 } 64 for(int i=head[k];i!=-1;i=edge[i].nex) 65 if (edge[i].to!=fa)calc(edge[i].to,k,p); 66 clean(o); 67 } 68 int main(){ 69 scanf("%d",&t); 70 while (t--){ 71 scanf("%d%d",&n,&m); 72 E=0; 73 memset(head,-1,sizeof(head)); 74 memset(sh,0,sizeof(sh)); 75 for(int i=1;i<=m;i++)f[i]=i; 76 for(int i=1;i<=n;i++)scanf("%d%d",&xx[i],&yy[i]); 77 for(int i=1;i<n;i++){ 78 scanf("%d%d",&x,&y); 79 add(x,y); 80 add(y,x); 81 } 82 for(int i=n+1;i<=2*n;i++)scanf("%d%d",&xx[i],&yy[i]); 83 for(int i=1;i<n;i++){ 84 scanf("%d%d",&x,&y); 85 add(x+n,y+n); 86 add(y+n,x+n); 87 } 88 x=0; 89 dfs(1,0,0); 90 x=0; 91 dfs(n+1,0,0); 92 calc(1,0,0); 93 for(int i=1;i<=n;i++)printf("%d\n",ans[i]); 94 } 95 }