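// For each centroid, sort its child subtrees by black-node count, then merge them in that order (small-to-large "heuristic" merging).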
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity limits: MAXN bounds node ids; the edge arrays hold MAXM << 1 entries
// because every undirected edge is stored as two directed edges.
const int MAXN = 2e6 + 10;
const int MAXM = 2e6 + 10;
// Linked-list adjacency: Head[u] is u's first edge id (0 = none), nxt[e] the next
// edge id, to[e] the target node, cost[e] the weight; ed is the last id handed out.
int to[MAXM << 1], nxt[MAXM << 1], Head[MAXN], ed = 0;
int cost[MAXM << 1];
const int INF = ~0u >> 1;  // largest int; -INF marks "no path" in solve's h[]/g[]
inline void addedge(int u, int v, int c)
{
to[++ed] = v;
cost[ed] = c;
nxt[ed] = Head[u];
Head[u] = ed;
}
// Insert an undirected weighted edge as two opposite directed edges.
inline void ADD(int a, int b, int w)
{
    addedge(a, b, w);  // a -> b
    addedge(b, a, w);  // b -> a
}
// Read the next integer from stdin, skipping any non-digit characters before it.
// A '-' anywhere in that skipped prefix makes the result negative.
// Returns 0 if end of input is reached before any digit (instead of spinning forever).
inline int readin()
{
    int r = 0, sign = 1;
    int c = getchar();  // int, not char, so EOF is distinguishable from real bytes
    for (; c != EOF && (c < '0' || c > '9'); c = getchar()) {
        if (c == '-') {
            sign = -1;
        }
    }
    for (; c >= '0' && c <= '9'; c = getchar()) {
        r = r * 10 + c - '0';
    }
    return sign * r;
}
// Input sizes: n nodes, black budget k, m black nodes. kk is the budget left for
// the centroid being solved; anser is the best path value found so far.
int n;
int k;
int kk;
int m;
int anser;
int cnt;        // number of child subtrees collected in o[] for the current centroid
int maxdep;     // not read anywhere in the visible code
int summaxdep;  // highest black count with data in g[] for the current centroid
// Centroid search: sz = subtree sizes, f = largest piece left when a node is removed,
// sumsz = size of the current component, root = best centroid candidate.
int sz[MAXN];
int f[MAXN];
int dep[MAXN];  // not referenced elsewhere in the visible code
int sumsz;
int root;
bool vis[MAXN];   // nodes already used as centroids (removed from the tree)
int ok[MAXN];     // nonzero for black nodes; added to black counts during traversals
int blasz[MAXN];  // number of black nodes in a subtree, filled by getdeep
int h[MAXN];      // per black count: best distance into the subtree being merged
int g[MAXN];      // per black count: best distance into already-merged subtrees
// One child subtree of the current centroid: blaval = number of black nodes in it,
// id = edge id leading from the centroid into it. o[1..cnt] holds these records.
struct node {
    int blaval;
    int id;
};
node o[MAXN];
// Ordering for std::sort: ascending by black-node count.
bool cmp(node a, node b)
{
    return b.blaval > a.blaval;
}
// DFS over the current component (nodes not marked vis) from x, entered from fa.
// Sets sz[x] to the size of x's subtree in this DFS and f[x] to the largest piece
// that would remain if x were removed (sumsz is the component size). Moves `root`
// to x whenever f[x] is smaller than f[root]; callers (main, solve) reset `root`
// beforehand to a starting value whose f must exceed every real f.
void getroot(int x, int fa)
{
    sz[x] = 1;
    f[x] = 0;
    for (int i = Head[x]; i; i = nxt[i]) {
        int v = to[i];
        if (v == fa || vis[v]) {
            continue;
        }
        getroot(v, x);
        sz[x] += sz[v];
        f[x] = max(f[x], sz[v]);
    }
    // The part of the component "above" x also becomes a piece when x is removed.
    f[x] = max(f[x], sumsz - sz[x]);
    if (f[x] < f[root]) {
        root = x;
    }
}
// Depth-first walk below the centroid. x is the current node, blanum the number of
// black nodes on the path so far (x included), deep the summed edge cost so far, and
// fa the node we came from. Records in h[blanum] the largest deep seen for that black
// count; a branch is abandoned as soon as blanum exceeds kk.
void update(int x, int blanum, int deep, int fa)
{
    if (blanum <= kk) {
        if (deep > h[blanum]) {
            h[blanum] = deep;
        }
        for (int e = Head[x]; e != 0; e = nxt[e]) {
            const int y = to[e];
            if (y != fa && !vis[y]) {
                update(y, blanum + ok[y], deep + cost[e], x);
            }
        }
    }
}
// Fill blasz[] for the subtree rooted at x (entered from fa), ignoring removed (vis)
// nodes: blasz[x] becomes the number of black nodes (sum of ok[]) in that subtree.
void getdeep(int x, int fa)
{
    int total = ok[x];
    for (int e = Head[x]; e != 0; e = nxt[e]) {
        const int y = to[e];
        if (y != fa && !vis[y]) {
            getdeep(y, x);
            total += blasz[y];
        }
    }
    blasz[x] = total;
}
// Collect the child subtrees of centroid x. For every neighbour v that has not been
// removed, count the black nodes in v's subtree (via getdeep) and append the record
// (black count, edge id x->v) to o[1..cnt]. The parameter d is not used.
void calc(int x, int d)
{
    cnt = 0;
    for (int i = Head[x]; i; i = nxt[i]) {
        int v = to[i];
        if (vis[v]) {
            continue;
        }
        getdeep(v, x);
        node now;
        now.blaval = blasz[v];
        now.id = i;
        o[++cnt] = now;
    }
}
void solve(int x)
{
summaxdep = -;
kk = k - ok[x];
int s;
vis[x] = ;
calc(x, );
sort(o + , o + cnt + , cmp);
for (int i = ; i <= cnt; i++) {
maxdep = -;
int depnow = o[i].blaval;
int v = to[o[i].id];
int c = cost[o[i].id];
s = min(depnow, kk);
for (int j = ; j <= s; j++) {
h[j] = -INF;
}
update(v, ok[v], c, x);
if (i == ) {
for (int j = ; j <= s; j++) {
g[j] = h[j];
}
} else {
for (int j = ; j <= s; j++) {
int aim = kk - j;
aim = min(aim, summaxdep);
if (h[j] != -INF && g[aim] != -INF) {
anser = max(anser, h[j] + g[aim]);
}
}
for (int j = ; j <= s; j++) {
g[j] = max(h[j], g[j]);
}
}
summaxdep = s;
for (int j = ; j <= summaxdep; j++) {
g[j] = max(g[j], g[j - ]);
}
}
anser = max(anser, g[min(kk, summaxdep)]);
int totsz = sumsz;
for (int i = Head[x]; i; i = nxt[i]) {
int v = to[i];
if (vis[v]) {
continue;
}
root = ;
sumsz = sz[v] > sz[x] ? totsz - sz[x] : sz[v];
getroot(v, );
solve(root);
}
}
// Read n, k, m and then m black node ids, followed by n - 1 weighted edges.
// Run the centroid decomposition starting from node 1's component and print the
// best path value found (0 if nothing better than an empty path).
int main()
{
    cnt = anser = 0;
    n = readin(), k = readin(), m = readin();
    for (int now, i = 1; i <= m; i++) {
        now = readin();
        ok[now] = 1;  // mark as black
    }
    int u, v, c;
    for (int i = 1; i < n; i++) {
        u = readin(), v = readin(), c = readin();
        ADD(u, v, c);
    }
    // Node ids are assumed to start at 1, so node 0 serves as the initial `root`
    // sentinel. f[0] = n exceeds every real f value (at most n - 1), so getroot
    // replaces it with the first node it examines.
    root = 0, sumsz = f[0] = n;
    getroot(1, 0);
    solve(root);
    printf("%d\n", anser);
    return 0;
}