题意:求一颗树上三点距离两两相等的三元组对数

n<=1e5

思路:From https://blog.bill.moe/bzoj4543-hotel/

f[i][j]表示以i为根的子树中距离i为j的点的个数

g[i][j]表示以i为根的子树中两点距离他们的lca为d,lca距离i为d-j的两点对数

g[i][j]找到一个子树外的f[i][j]就对答案有贡献

朴素的方程为:设v为u的一个儿子

ans+=f[u][j]*g[v][j+1]+g[u][j]*f[y][j-1]

g[u][j+1]+=f[u][j+1]*f[v][j]

g[u][j-1]+=g[v][j]

f[u][j+1]+=f[v][j]

显然f[i][j]只和深度有关,且f[u]的[1,len[u]]这一段是所有f[v]的[0,len[u]-1]右移一位之和

为了防止同一个子树中的信息算多了,先算ans部分再执行后面三步更新

指针的写法我完全是抄的

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 typedef long long ll;
  4 typedef unsigned int uint;
  5 typedef unsigned long long ull;
  6 typedef pair<int,int> PII;
  7 typedef pair<ll,ll> Pll;
  8 typedef vector<int> VI;
  9 typedef vector<PII> VII;
 10 typedef pair<ll,int>P;
 11 #define N  100010
 12 #define M  200010
 13 #define fi first
 14 #define se second
 15 #define MP make_pair
 16 #define pi acos(-1)
 17 #define mem(a,b) memset(a,b,sizeof(a))
 18 #define rep(i,a,b) for(int i=(int)a;i<=(int)b;i++)
 19 #define per(i,a,b) for(int i=(int)a;i>=(int)b;i--)
 20 #define lowbit(x) x&(-x)
 21 #define Rand (rand()*(1<<16)+rand())
 22 #define id(x) ((x)<=B?(x):m-n/(x)+1)
 23 #define ls p<<1
 24 #define rs p<<1|1
 25
 26 const ll MOD=1e9+7,inv2=(MOD+1)/2;
 27       double eps=1e-6;
 28       int INF=1<<30;
 29       ll inf=5e13;
 30       int dx[4]={-1,1,0,0};
 31       int dy[4]={0,0,-1,1};
 32
 33 int head[M],vet[M],nxt[M],tot;
 34 int len[N],son[N];
 35 ll tmp[N*5],*f[N],*g[N],*now=tmp,ans;
 36
 37 int read()
 38 {
 39    int v=0,f=1;
 40    char c=getchar();
 41    while(c<48||57<c) {if(c=='-') f=-1; c=getchar();}
 42    while(48<=c&&c<=57) v=(v<<3)+v+v+c-48,c=getchar();
 43    return v*f;
 44 }
 45
 46 void add(int a,int b)
 47 {
 48     nxt[++tot]=head[a];
 49     vet[tot]=b;
 50     head[a]=tot;
 51 }
 52
 53 void dfs(int u,int fa,int d)
 54 {
 55     len[u]=0;
 56     int e=head[u];
 57     while(e)
 58     {
 59         int v=vet[e];
 60         if(v!=fa)
 61         {
 62             dfs(v,u,d+1);
 63             if(len[v]>len[son[u]])
 64             {
 65                 son[u]=v;
 66                 len[u]=len[v]+1;
 67             }
 68         }
 69         e=nxt[e];
 70     }
 71 }
 72
 73 void solve(int u,int fa)
 74 {
 75     if(son[u])
 76     {
 77         f[son[u]]=f[u]+1;
 78         g[son[u]]=g[u]-1;
 79         solve(son[u],u);
 80     }
 81     f[u][0]=1;
 82     ans+=g[u][0];
 83     int e=head[u];
 84     while(e)
 85     {
 86         int v=vet[e];
 87         if(v!=fa&&v!=son[u])
 88         {
 89             f[v]=now;
 90             now+=(len[v]<<1)+2;
 91             g[v]=now;
 92             now+=(len[v]<<1)+2;
 93             solve(v,u);
 94             rep(j,0,len[v])
 95             {
 96                 if(j) ans+=f[u][j-1]*g[v][j];
 97                 ans+=g[u][j+1]*f[v][j];
 98             }
 99             rep(j,0,len[v])
100             {
101                 g[u][j+1]+=f[u][j+1]*f[v][j];
102                 if(j) g[u][j-1]+=g[v][j];
103                 f[u][j+1]+=f[v][j];
104             }
105         }
106         e=nxt[e];
107     }
108 }
109 int main()
110 {
111     int n=read();
112     tot=0;
113     rep(i,1,n-1)
114     {
115         int x=read(),y=read();
116         add(x,y);
117         add(y,x);
118     }
119     len[0]=-1;
120     dfs(1,0,1);
121     ans=0;
122     f[1]=now,now+=(len[1]<<1)+2,g[1]=now,now+=(len[1]<<1)+2;
123     solve(1,0);
124     printf("%lld\n",ans);
125     return 0;
126 }
01-16 00:40