A tree is a connected graph that doesn't contain any cycles.
The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.
You are given a tree with n vertices and a positive integer k. Find the number of distinct pairs of vertices that are at distance exactly k from each other. Note that pairs (v, u) and (u, v) are considered to be the same pair.
The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.
Each of the next n − 1 lines describes an edge as "a_i b_i" (without the quotes) (1 ≤ a_i, b_i ≤ n, a_i ≠ b_i), where a_i and b_i are the vertices connected by the i-th edge. All given edges are different.
Print a single integer — the number of distinct pairs of the tree's vertices which have a distance of exactly k between them.
Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin and cout streams or the %I64d specifier.
Sample input 1:
5 2
1 2
2 3
3 4
2 5
Sample output 1:
4
Sample input 2:
5 3
1 2
2 3
3 4
4 5
Sample output 2:
2
In the first sample, the pairs of vertices at distance 2 from each other are (1, 3), (1, 5), (3, 5) and (2, 4).
这道题就是求树上距离恰好为 K 的点对数量。以前写过求距离 <= K 的点对数量的题，所以本想直接用「<= K 的数量 − < K 的数量」来算，理论上应该也是可以的，但一直在第 11 和第 17 个测试点 TLE……
最后换了一种写法：直接统计距离恰好为 K 的点对，不再需要「减去同一子树内部点对」的过程，结果就过了，有点迷……（直接统计时，每个重心只需把各棵子树依次遍历一遍，并按深度累加计数即可；而相减的写法要分别算两次答案，每次还要对每棵子树额外做一遍减法，常数大了不少。）
代码:
// Tree divide and conquer (centroid decomposition).
// Counts the unordered pairs of vertices whose distance is exactly k.
//
// For every centroid c, the pairs whose path passes through c are counted
// directly. The remaining subtrees of c are scanned one at a time, and a vertex
// at depth d in the current subtree is paired with every vertex at depth k - d
// recorded from the earlier subtrees (depth 0 is c itself). Pairs inside a
// single subtree are never formed at this level, so no "subtract the
// same-subtree pairs" pass is needed.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

using ll = long long;

const int inf = 1e9 + 7;
const int maxn = 1e5 + 10;  // vertex capacity (the statement guarantees n <= 50000)
const int maxm = 500 + 10;  // depth capacity (the statement guarantees k <= 500)

int head[maxn], tot;      // forward-star adjacency: head[u] = index of u's last added edge, -1 if none
int root, allnode, n, k;  // root: best centroid candidate so far; allnode: size of the component being searched
bool vis[maxn];           // vis[u]: u has already been used as a centroid, i.e. removed from the tree
int dis[maxn];            // dis[u]: distance from the current centroid to u
int siz[maxn];            // siz[u]: subtree size of u in the most recent get_root traversal
int maxv[maxn];           // maxv[u]: largest component left after removing u; maxv[0] = inf is a sentinel
int num[maxm];            // num[d]: vertices at depth d in subtrees already merged for the current centroid
int cnt[maxm];            // cnt[d]: vertices at depth d in the subtree currently being scanned
int curdep;               // deepest depth (<= k) reached while scanning the current subtree
ll ans = 0;               // the answer can reach about n^2 / 2, which overflows int

// Reads the next non-negative integer from stdin, skipping non-digit characters.
// Returns 0 at end of input instead of looping forever on EOF.
inline int read()
{
    int ch = std::getchar();  // int, not char, so EOF is distinguishable from a real byte
    while (ch != EOF && (ch < '0' || ch > '9')) ch = std::getchar();
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x;
}

struct node {
    int to, next, val;  // target vertex, index of the next edge with the same source, edge weight
} edge[maxn << 1];      // each undirected edge is stored twice

// Appends the directed edge u -> v with weight w to the forward-star lists.
void add(int u, int v, int w)
{
    edge[tot].to = v;
    edge[tot].next = head[u];
    edge[tot].val = w;
    head[u] = tot++;
}

// Resets the graph: empty adjacency lists, no vertex removed.
void init()
{
    std::memset(head, -1, sizeof head);
    std::memset(vis, 0, sizeof vis);
    tot = 0;
}

// DFS over the component containing u (vertices with vis set are treated as
// removed). Fills siz[] and maxv[], and moves `root` to the vertex with the
// smallest maxv seen so far. `allnode` must hold the component's size.
void get_root(int u, int father)
{
    siz[u] = 1;
    maxv[u] = 0;
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v == father || vis[v]) continue;
        get_root(v, u);
        siz[u] += siz[v];
        maxv[u] = std::max(maxv[u], siz[v]);  // component hanging below u through v
    }
    maxv[u] = std::max(maxv[u], allnode - siz[u]);  // component above u (everything outside its subtree)
    if (maxv[u] < maxv[root]) root = u;
}

// Walks one subtree of the current centroid. Each vertex at depth dis[u] <= k
// first pairs with the num[k - dis[u]] vertices from earlier subtrees, then is
// recorded in cnt[]. Deeper vertices are pruned because they cannot be in a pair.
void get_dis(int u, int father)
{
    if (dis[u] > k) return;
    ans += num[k - dis[u]];
    cnt[dis[u]]++;
    curdep = std::max(curdep, dis[u]);
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v == father || vis[v]) continue;
        dis[v] = dis[u] + edge[i].val;
        get_dis(v, u);
    }
}

// Adds to `ans` every distance-k pair whose path passes through centroid u.
// `now` is the distance assigned to u's direct neighbours (1 here, since every
// edge has weight 1). num[] and cnt[] are all-zero on entry and on exit, and
// only the depths actually reached are touched.
void cal(int u, int now)
{
    int maxdep = 0;  // deepest depth written into num[] for this centroid
    num[0] = 1;      // the centroid itself, at depth 0
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (vis[v]) continue;
        curdep = 0;
        dis[v] = now;
        get_dis(v, u);
        // Merge this subtree only after scanning it, so it never pairs with itself.
        for (int d = 0; d <= curdep; d++) {
            num[d] += cnt[d];
            cnt[d] = 0;
        }
        maxdep = std::max(maxdep, curdep);
    }
    for (int d = 0; d <= maxdep; d++) num[d] = 0;
}

// Divide and conquer on centroid u: count the pairs through u, remove u, then
// recurse into the centroid of each remaining component.
void solve(int u)
{
    const int total = allnode;  // size of u's component; saved because recursion overwrites allnode
    vis[u] = 1;
    cal(u, 1);
    for (int i = head[u]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (vis[v]) continue;
        // siz[] comes from the get_root DFS that found u, which was rooted
        // elsewhere. A neighbour with a smaller siz is a DFS child of u, so
        // siz[v] is its component's exact size. The neighbour with a larger siz
        // is u's DFS parent, whose component is everything outside u's subtree.
        allnode = siz[v] < siz[u] ? siz[v] : total - siz[u];
        root = 0;
        get_root(v, u);
        solve(root);
    }
}

int main()
{
    n = read();
    k = read();
    init();
    for (int i = 1; i < n; i++) {
        int u = read();
        int v = read();
        add(u, v, 1);  // unweighted tree: every edge has length 1
        add(v, u, 1);
    }
    root = 0;
    allnode = n;
    maxv[0] = inf;
    get_root(1, 0);
    solve(root);
    // The statement asks not to use %lld for 64-bit output, so print via iostream.
    std::cout << ans << '\n';
    return 0;
}