题意:求一颗树上三点距离两两相等的三元组对数
n<=1e5
思路:From https://blog.bill.moe/bzoj4543-hotel/
f[i][j]表示以i为根的子树中距离i为j的点的个数
g[i][j]表示以i为根的子树中两点距离他们的lca为d,lca距离i为d-j的两点对数
g[i][j]找到一个子树外的f[i][j]就对答案有贡献
朴素的方程为:设v为u的一个儿子
ans+=f[u][j]*g[v][j+1]+g[u][j]*f[y][j-1]
g[u][j+1]+=f[u][j+1]*f[v][j]
g[u][j-1]+=g[v][j]
f[u][j+1]+=f[v][j]
显然f[i][j]只和深度有关,且f[u]的[1,len[u]]这一段是所有f[v]的[0,len[u]-1]右移一位之和
为了防止同一个子树中的信息算多了,先算ans部分再执行后面三步更新
指针的写法我完全是抄的
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 typedef unsigned int uint; 5 typedef unsigned long long ull; 6 typedef pair<int,int> PII; 7 typedef pair<ll,ll> Pll; 8 typedef vector<int> VI; 9 typedef vector<PII> VII; 10 typedef pair<ll,int>P; 11 #define N 100010 12 #define M 200010 13 #define fi first 14 #define se second 15 #define MP make_pair 16 #define pi acos(-1) 17 #define mem(a,b) memset(a,b,sizeof(a)) 18 #define rep(i,a,b) for(int i=(int)a;i<=(int)b;i++) 19 #define per(i,a,b) for(int i=(int)a;i>=(int)b;i--) 20 #define lowbit(x) x&(-x) 21 #define Rand (rand()*(1<<16)+rand()) 22 #define id(x) ((x)<=B?(x):m-n/(x)+1) 23 #define ls p<<1 24 #define rs p<<1|1 25 26 const ll MOD=1e9+7,inv2=(MOD+1)/2; 27 double eps=1e-6; 28 int INF=1<<30; 29 ll inf=5e13; 30 int dx[4]={-1,1,0,0}; 31 int dy[4]={0,0,-1,1}; 32 33 int head[M],vet[M],nxt[M],tot; 34 int len[N],son[N]; 35 ll tmp[N*5],*f[N],*g[N],*now=tmp,ans; 36 37 int read() 38 { 39 int v=0,f=1; 40 char c=getchar(); 41 while(c<48||57<c) {if(c=='-') f=-1; c=getchar();} 42 while(48<=c&&c<=57) v=(v<<3)+v+v+c-48,c=getchar(); 43 return v*f; 44 } 45 46 void add(int a,int b) 47 { 48 nxt[++tot]=head[a]; 49 vet[tot]=b; 50 head[a]=tot; 51 } 52 53 void dfs(int u,int fa,int d) 54 { 55 len[u]=0; 56 int e=head[u]; 57 while(e) 58 { 59 int v=vet[e]; 60 if(v!=fa) 61 { 62 dfs(v,u,d+1); 63 if(len[v]>len[son[u]]) 64 { 65 son[u]=v; 66 len[u]=len[v]+1; 67 } 68 } 69 e=nxt[e]; 70 } 71 } 72 73 void solve(int u,int fa) 74 { 75 if(son[u]) 76 { 77 f[son[u]]=f[u]+1; 78 g[son[u]]=g[u]-1; 79 solve(son[u],u); 80 } 81 f[u][0]=1; 82 ans+=g[u][0]; 83 int e=head[u]; 84 while(e) 85 { 86 int v=vet[e]; 87 if(v!=fa&&v!=son[u]) 88 { 89 f[v]=now; 90 now+=(len[v]<<1)+2; 91 g[v]=now; 92 now+=(len[v]<<1)+2; 93 solve(v,u); 94 rep(j,0,len[v]) 95 { 96 if(j) ans+=f[u][j-1]*g[v][j]; 97 ans+=g[u][j+1]*f[v][j]; 98 } 99 rep(j,0,len[v]) 100 { 101 g[u][j+1]+=f[u][j+1]*f[v][j]; 102 if(j) g[u][j-1]+=g[v][j]; 103 f[u][j+1]+=f[v][j]; 104 } 105 } 106 e=nxt[e]; 107 } 108 } 109 int main() 110 { 111 int n=read(); 112 tot=0; 113 rep(i,1,n-1) 114 { 115 int x=read(),y=read(); 116 add(x,y); 117 add(y,x); 118 } 119 len[0]=-1; 120 dfs(1,0,1); 121 ans=0; 122 f[1]=now,now+=(len[1]<<1)+2,g[1]=now,now+=(len[1]<<1)+2; 123 solve(1,0); 124 printf("%lld\n",ans); 125 return 0; 126 }