题目链接：Free tour II

题意：给定一棵有 $N$ 个顶点的树，每条边带有权值，顶点分为黑点和白点。找一条最长（边权之和最大）的路径，使得路径上黑点的数量不超过 $K$ 个。

这是树上点分治的一道比较基础的题，其中用到了启发式合并的思想（各子树按其最大黑点数从小到大依次合并）……仰望了黄学长的博客之后稍微有点明白了（还没有很深入地理解）。

#include <bits/stdc++.h>

using namespace std;

// Inclusive loop shorthands used throughout the solution.
// Each #define must stay on its own line: everything after the parameter list up to the
// end of the line becomes the macro body, so any declaration written on the same line
// would be swallowed by the macro instead of being declared.
#define rep(i, a, b) for (int i(a); i <= (b); ++i)
#define dec(i, a, b) for (int i(a); i >= (b); --i)

const int N = 200010;  // capacity of every per-vertex array (vertex ids index directly)

// Scalar state shared by the centroid-decomposition routines (zero-initialised globals).
int ans;      // best path weight found so far; printed by main
int n, k, m;  // vertex count, max black vertices allowed on a path, number of black vertices
              // (solve temporarily decrements k while the current centroid is black)
int cnt;      // currently unused
int root;     // centroid candidate maintained by getroot (0 = no candidate yet)
int sum;      // size of the component in which getroot searches for a centroid
int deep_mx;  // largest black count seen while scanning the current subtree
int x;        // scratch variable main uses while reading the black vertices

int sz[N];    // subtree size computed by the last getroot DFS
int f[N];     // size of the largest piece left if the vertex were removed
int deep[N];  // black vertices from the centroid's child down to the vertex (centroid excluded)
int dis[N];   // summed edge weight from the centroid to the vertex
int tmp[N];   // per black count: best distance inside the subtree being processed (reset to 0)
int mx[N];    // per black count: best distance over already-processed subtrees (reset to 0)
bool color[N];  // true if the vertex is black
bool vis[N];    // true once the vertex has been used as a centroid
std::vector<std::pair<int, int>> v[N];  // adjacency list: (neighbor, edge weight)
vector <pair <int, int > > st; void getroot(int x, int fa){
// st (global above): one (max black count, child vertex) pair per unvisited child of the
// current centroid; filled and sorted in solve.
//
// getroot: DFS over the unvisited component containing x, entered from fa.
// Fills sz[] (subtree sizes) and f[] (largest piece left if the vertex were removed,
// using `sum` as the component size) and keeps in `root` the vertex with the smallest
// f[] seen so far, i.e. the centroid candidate.
// NOTE(review): getroot/getdis/getmx/solve are all recursive, so the recursion depth can
// reach the component size on a path-shaped tree; confirm the stack limit is sufficient.
sz[x] = 1; f[x] = 0;
for (auto u : v[x]){
int to = u.first;
if (vis[to] || to == fa) continue;
getroot(to, x);
f[x] = max(f[x], sz[to]);
sz[x] += sz[to];
} f[x] = max(f[x], sum - sz[x]); // the piece "above" x in this DFS
if (f[x] < f[root]) root = x;
} void getdis(int x, int fa){
// getdis: DFS below the centroid. Expects deep[x] and dis[x] to be set already for x.
// Fills them for the rest of the subtree: deep = black vertices excluding the centroid,
// dis = summed edge weight from the centroid. Also records the largest deep in deep_mx.
deep_mx = max(deep_mx, deep[x]);
for (auto u : v[x]){
int to = u.first;
if (vis[to] || to == fa) continue;
deep[to] = deep[x] + color[to];
dis[to] = dis[x] + u.second;
getdis(to, x);
}
} void getmx(int x, int fa){
// getmx: for every vertex of the subtree, keep the best distance per black count in tmp[].
tmp[deep[x]] = max(tmp[deep[x]], dis[x]);
for (auto u : v[x]){
int to = u.first;
if (vis[to] || to == fa) continue;
getmx(to, x);
}
} void solve(int x){
// solve: use x as the centroid. Update ans with the best admissible paths through x,
// then recurse into every remaining component.
vis[x] = 1;
st.clear();
// A black centroid is shared by both halves of a path, so charge it once here by
// shrinking the budget; k is restored before recursing.
if (color[x]) --k;
for (auto u : v[x]){
int to = u.first;
if (vis[to]) continue;
deep_mx = 0;
deep[to] = color[to];
dis[to] = u.second;
getdis(to, x);
st.push_back({deep_mx, to});
} sort(st.begin(), st.end()); // handle subtrees in increasing order of their max black count
// tmp[] and mx[] entries left at 0 behave like an arm that stops at the centroid itself
// (length 0, no extra black vertex), which is always admissible.
for (int i = 0; i < (int)st.size(); ++i){
getmx(st[i].second, x);
int now = 0;
// Pair this subtree with the earlier ones. j is the black count on this side, walked
// downward, so the count `now` allowed on the other side only grows (two pointers).
// mx[] is turned into a prefix maximum lazily (mx[now] = best over counts <= now).
// The walk is capped at st[i - 1].first, the largest count among earlier subtrees,
// because the list is sorted.
if (i != 0)
dec(j, st[i].first, 0){
while (now + j < k && now < st[i - 1].first)
++now, mx[now] = max(mx[now], mx[now - 1]);
if (now + j <= k) ans = max(ans, mx[now] + tmp[j]);
}
// Not the last subtree: fold tmp[] into mx[] and reset tmp[].
// Last subtree (largest count): also score single-arm paths (centroid to one vertex),
// then clear both arrays for the next centroid.
if (i != (int)st.size() - 1)
rep(j, 0, (int)st[i].first)
mx[j] = max(mx[j], tmp[j]), tmp[j] = 0;
else
rep(j, 0, (int)st[i].first){
if (j <= k) ans = max(ans, max(tmp[j], mx[j]));
tmp[j] = mx[j] = 0;
}
} if (color[x]) ++k;
// Recurse: find the centroid of each remaining component and solve it.
for (auto u : v[x]){
int to = u.first;
if (vis[to]) continue;
root = 0;
// NOTE(review): sz[] comes from the DFS that selected x, so for the neighbor that was
// x's DFS parent this is not that component's exact size. It only influences which
// vertex is picked as the next centroid, not which paths are examined.
sum = sz[to];
getroot(to, x);
solve(root);
}
} int main(){ scanf("%d%d%d", &n, &k, &m);
// Input: n k m, then m black vertex ids, then n - 1 edges "x y z" (z = edge weight).
rep(i, 1, m){
scanf("%d", &x);
color[x] = 1;
} rep(i, 1, n - 1){
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
v[x].push_back({y, z});
v[y].push_back({x, z});
} sum = n; f[0] = n; // vertex 0 is a sentinel (root = 0 means "no candidate"); vertices appear 1-indexed
getroot(1, 0);
solve(root);
printf("%d\n", ans);
return 0;
}
05-15 02:17