题目大意:
给你一颗$n(n\le5000)$个点的树,选3个点使得它们两两距离相等,问共有几种选法。
思路:
首先我们不难发现一个性质:对于每3个符合条件的点,我们总能找到一个点使得这个点到那3个点距离相等。
我们不妨称之为“中转点”。
显然答案就是对于每个中转点,不同子树中到这个点距离相等的三元点对的数量。
我们可以先枚举每个点作为中转点的情况。
暴力求出以这个点的每个子结点为根的子树,不同深度的结点的数量(显然深度就是到这个中转点的距离)。
我们可以用calc[i][j]表示对于当前中转点,来自j个不同子树的深度为i的结点共有多少种不同的组合。
转移方程为calc[i][j]+=calc[i][j-1]*cnt[i]。
#include<cstdio>
#include<cctype>
#include<vector>
#include<cstring>
typedef long long int64;
inline int getint() {
register char ch;
while(!isdigit(ch=getchar()));
register int x=ch^'';
while(isdigit(ch=getchar())) x=(((x<<)+x)<<)+(ch^'');
return x;
}
const int N=;
std::vector<int> e[N];
inline void add_edge(const int &u,const int &v) {
e[u].push_back(v);
e[v].push_back(u);
}
int n,cnt[N];
int64 calc[N][];
void dfs(const int &x,const int &par,const int &dep) {
cnt[dep]++;
for(unsigned i=;i<e[x].size();i++) {
const int &y=e[x][i];
if(y==par) continue;
dfs(y,x,dep+);
}
}
int main() {
n=getint();
for(register int i=;i<n;i++) {
add_edge(getint(),getint());
}
int64 ans=;
for(register int x=;x<=n;x++) {
memset(calc,,sizeof calc);
for(register int i=;i<=n;i++) calc[i][]=;
for(register unsigned i=;i<e[x].size();i++) {
memset(cnt,,sizeof cnt);
const int &y=e[x][i];
dfs(y,x,);
for(register int j=;j;j--) {
for(register int i=;i<=n;i++) {
calc[i][j]+=calc[i][j-]*cnt[i];
}
}
}
for(register int i=;i<=n;i++) {
ans+=calc[i][];
}
}
printf("%lld\n",ans);
return ;
} 现在考虑当$n\le10^5$的情况。
考虑$n\le10^5$的情况。
$f[i][j]$标示以$i$为根的子树中,与$i$距离为$j$的点数。$g[i][j]$标示以$i$为根的子树中,与$i$距离为$j$的点对数。则不难想到一种$O(n^2)$的转移:
$$
\begin{align*}
&g[x][i-1]+=g[y][i]\\
&g[x][i+1]+=f[x][i+1]\times f[y][i]\\
&f[x][i+1]+=f[y][i]
\end{align*}
$$
边界为$f[x][0]=1$。
考虑优化这个转移,不难发现,若$y$是$x$枚举到的第一个子结点,则转移时只进行第一、第三个转移。因此我们可以考虑通过指针来实现,免去转移的过程。
将原树进行长链剖分,对于重边直接修改指针,对于轻边暴力转移,可以证明这样是$O(n)$的。
#include<list>
#include<cstdio>
#include<cctype>
typedef long long int64;
inline int getint() {
register char ch;
while(!isdigit(ch=getchar()));
register int x=ch^'';
while(isdigit(ch=getchar())) x=(((x<<)+x)<<)+(ch^'');
return x;
}
const int N=;
std::list<int> e[N];
int dep[N],bot[N];
int64 mem[N*],ans,*f[N],*g[N],*ptr=mem;
inline void add_edge(const int &u,const int &v) {
e[u].push_back(v);
e[v].push_back(u);
}
void dfs(const int &x,const int &par) {
dep[bot[x]=x]=dep[par]+;
for(std::list<int>::iterator i=e[x].begin();i!=e[x].end();i++) {
const int &y=*i;
if(y==par) continue;
dfs(y,x);
if(dep[bot[y]]>dep[bot[x]]) bot[x]=bot[y];
}
for(register std::list<int>::iterator i=e[x].begin();i!=e[x].end();i++) {
const int &y=*i;
if(y==par||(bot[y]==bot[x]&&x!=)) continue;
f[bot[y]]=ptr+=dep[bot[y]]-dep[x]+;
g[bot[y]]=++ptr;
ptr+=(dep[bot[y]]-dep[x])*+;
}
}
void dp(const int &x,const int &par) {
for(std::list<int>::iterator i=e[x].begin();i!=e[x].end();i++) {
const int &y=*i;
if(y==par) continue;
dp(y,x);
if(bot[y]==bot[x]) {
f[x]=f[y]-;
g[x]=g[y]+;
}
}
ans+=g[x][];
f[x][]=;
for(register std::list<int>::iterator i=e[x].begin();i!=e[x].end();i++) {
const int &y=*i;
if(y==par||bot[y]==bot[x]) continue;
for(register int i=;i<=dep[bot[y]]-dep[x];i++) {
ans+=f[x][i-]*g[y][i]+g[x][i+]*f[y][i];
}
for(register int i=;i<=dep[bot[y]]-dep[x];i++) {
g[x][i-]+=g[y][i];
g[x][i+]+=f[x][i+]*f[y][i];
f[x][i+]+=f[y][i];
}
}
}
int main() {
const int n=getint();
for(register int i=;i<n;i++) {
add_edge(getint(),getint());
}
dfs(,);
dp(,);
printf("%lld\n",ans);
return ;
}