第一次接触树分治,看了论文又照挑战上抄的代码,也就理解到这个层次了。。
以后做题中再慢慢体会学习。
题目链接:
http://poj.org/problem?id=1741
题意:
给定树和树边的权重,求有多少对顶点之间的边的权重之和小于等于K。
分析:
树分治。
直接枚举不可,我们将树划分成若干子树。
那么两个顶点有两种情况:
- u,v属于同一子树的顶点对
- u,v属于不同子树的顶点对
第一种情况,对子树递归即可求得。
第二种情况,从u到v的路径必然经过了顶点s,只要先求出每个顶点到s的距离再做统计即可。(注意在第二种情况中减去第一种重复计算的部分)
当树退化成链的形式时,递归的深度则退化为O(n),所以选择每次都找到树的重心作为分隔顶点。重心就是删掉此结点后得到的最大子树的顶点数最少的顶点,删除重心后得到的所有子树顶点数必然不超过n/2。
查找重心的时候,假设根为v,先在v的子树中找到一个顶点,使删除该顶点后的最大子树的顶点数最少,然后考虑删除v的情况,获得最大子树的顶点数。
两者选择最小的一个,此时选中的顶点即为重心。
递归的每一层都做了排序O(nlogn),递归深度O(logn),总体时间复杂度O(nlog2n)。
代码:
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#define sa(a) scanf("%d", &a)
#define mem(a,b) memset(a, b, sizeof(a))
using namespace std;
const int maxn = 1e4 + 5, oo = 0x3f3f3f3f;
int cnt[maxn], vis[maxn];
int K, ans;
struct EDGE{int to; int length;int next;};
int head[maxn];
EDGE edge[maxn * 2];
typedef pair<int, int>pii;
int tot = 0;
void addedge(int u, int v, int l)
{
edge[tot].to = v;
edge[tot].length = l;
edge[tot].next = head[u];
head[u] = tot++;
edge[tot].to = u;
edge[tot].length = l;
edge[tot].next = head[v];
head[v] = tot++;
}
int countsubtree(int v, int p)
{
int ans = 1;
for(int i = head[v]; i != -1; i = edge[i].next){
int w = edge[i].to;
if(w == p||vis[w]) continue;
ans += countsubtree(w, v);
}
return cnt[v] = ans;
}
pii findc(int v, int p, int t)
{
pii res = pii(oo, 0);
int s = 1, m = 1;
for(int i = head[v]; i != -1; i = edge[i].next){
int w = edge[i].to;
if(w == p || vis[w]) continue;
res = min(res, findc(w, v, t));
m = max(m, cnt[w]);
s += cnt[w];
}
m = max(m, t - s);
res = min(res, pii(m, v));
return res;
}
void findpath(int v, int p, int d, vector<int>&ds)
{
ds.push_back(d);
for(int i = head[v]; i != -1; i = edge[i].next){
int w = edge[i].to;
if(w == p || vis[w]) continue;
findpath(w, v, d +edge[i].length, ds);
}
}
int count_pair(vector<int>&ds)
{
int res = 0;
sort(ds.begin(), ds.end());
int j = ds.size() - 1;
int i = 0;
while(i < j){
while(j > i &&ds[i] + ds[j] > K) j--;
res += j - i;
i++;
}
return res;
/*
int j = ds.size();
for(int i = 0; i < ds.size(); i++){
while(j > 0 && ds[i] + ds[j - 1] > K) j--;
res += j - (j > i?1:0);
}
return res / 2;*/
}
void solve(int v)
{
vector<int>ds;
countsubtree(v, -1);
int s = findc(v, -1, cnt[v]).second;
vis[s] =true;
//(1)
for(int i = head[s]; i != -1; i = edge[i].next){
int w = edge[i].to;
if(vis[w]) continue;
solve(w);
}
//(2)
ds.push_back(0);
for(int i = head[s]; i != -1; i = edge[i].next){
int w = edge[i].to;
if(vis[w]) continue;
vector<int>ts;
findpath(w, s, edge[i].length, ts);
ans -= count_pair(ts);
ds.insert(ds.end(), ts.begin(), ts.end());
}
vis[s] = false;
ans += count_pair(ds);
}
void init()
{
tot = 0;
ans = 0;
mem(head, -1);
mem(vis, 0);
mem(cnt, 0);
}
int main (void)
{
int n;
while(scanf("%d%d", &n, &K)== 2 && n + K){
int u, v, l;
init();
for(int i = 0; i < n - 1; i++){
sa(u),sa(v),sa(l);
addedge(u, v, l);
}
solve(1);
printf("%d\n", ans);
}
return 0;
}