大概就是二分+树上差分...
题意:给定一棵带边权的树和树上的 m 条路径,可以把其中一条边的边权改为 0,使这 m 条路径中最长的一条尽量短。
求“最大值最小”,可以想到二分答案(事实上我一开始并没有看出来...)
然后二分答案 k:对于所有长度大于 k 的路径,用树上差分求出它们共同经过的边中最长的一条,再判断最长路径减去这条边的边权后是否不超过 k。
(yy的解法②:边按照长度排序,然后二分。删除最长公共边。据logeadd juru说是三分)
代码量3.6k,180行,还是有点长的。
#include <cstdio>
#include <algorithm>
#include <cstring>
// Array capacity for vertices and paths.
// NOTE(review): the original numeric literals were lost in extraction; 300010
// assumes the usual n, m <= 300000 bounds of this problem — confirm.
const int N = 300010;
// Binary-lifting levels: getlca() picks the smallest lm with 2^lm >= n, so
// indices 0..lm must fit below LOG for every n < N.
const int LOG = 20;
static_assert((1 << (LOG - 1)) >= N, "LOG too small for N");

/// Reads a non-negative decimal integer from stdin into x, first skipping any
/// non-digit characters. Returns with x == 0 at end of input instead of
/// spinning forever.
inline void read(int &x) {
    int c = getchar();  // int, not char: EOF must stay distinguishable from a byte
    x = 0;
    while (c != EOF && (c > '9' || c < '0')) {
        c = getchar();
    }
    while (c <= '9' && c >= '0') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
}

/// Adjacency-list edge; lists are threaded through `nex`, heads live in e[].
struct Edge {
    int v, nex, len;
} edge[N << 1];  // every undirected tree edge is stored in both directions
int top;         // number of stored directed edges; index 0 terminates lists
// Per-vertex data:
//   e[x]      head of x's adjacency list
//   fa[x][i]  2^i-th ancestor of x (0 once past the root)
//   d[x]      depth; the root has depth 1 so the sentinel vertex 0 (depth 0)
//             is strictly above it
//   lenth[x]  weighted distance from the root to x
int e[N], n, m, lm, fa[N][LOG], d[N], lenth[N];
// Per-path data: endpoints l/r, their LCA mid, and weighted length len.
int l[N], r[N], mid[N], len[N];
// num:   number of paths longer than the candidate answer (set by check)
// large: heaviest edge shared by all of those paths (set by check)
// R:     length of the longest path
// f:     tree-difference counters used by check
int num, large, R, f[N];
// Vertices in BFS order from the root (every parent precedes its children).
int order[N], orderCnt;

/// Appends the directed edge x -> y with weight z to x's adjacency list.
inline void add(int x, int y, int z) {
    edge[++top] = Edge{y, e[x], z};
    e[x] = top;
}

/// Roots the tree at `root`: fills fa[][0], d[], lenth[] and records a BFS
/// order in order[]. Iterative, so a path-shaped tree cannot overflow the
/// call stack.
inline void rootTree(int root) {
    fa[root][0] = 0;
    d[root] = 1;  // keep the root below sentinel 0 so lca() never lifts onto 0
    lenth[root] = 0;
    orderCnt = 0;
    order[orderCnt++] = root;
    for (int head = 0; head < orderCnt; ++head) {
        const int x = order[head];
        for (int i = e[x]; i; i = edge[i].nex) {
            const int y = edge[i].v;
            if (y == fa[x][0]) {
                continue;
            }
            fa[y][0] = x;
            d[y] = d[x] + 1;
            lenth[y] = lenth[x] + edge[i].len;
            order[orderCnt++] = y;
        }
    }
}

/// Roots the tree at vertex 1 and builds the binary-lifting table fa[][1..lm].
inline void getlca() {
    while ((1 << lm) < n) {
        ++lm;
    }
    rootTree(1);
    for (int i = 1; i <= lm; ++i) {
        for (int x = 1; x <= n; ++x) {
            fa[x][i] = fa[fa[x][i - 1]][i - 1];
        }
    }
}

/// Lowest common ancestor of x and y via binary lifting.
inline int lca(int x, int y) {
    if (d[x] > d[y]) {
        std::swap(x, y);
    }
    // Lift y up to x's depth.
    for (int t = lm; t >= 0 && d[y] > d[x]; --t) {
        if (d[fa[y][t]] >= d[x]) {
            y = fa[y][t];
        }
    }
    if (x == y) {
        return x;
    }
    // Lift both while their ancestors differ; they end as children of the LCA.
    for (int t = lm; t >= 0 && fa[x][0] != fa[y][0]; --t) {
        if (fa[x][t] != fa[y][t]) {
            x = fa[x][t];
            y = fa[y][t];
        }
    }
    return fa[x][0];
}

/// Returns true if zeroing one edge can make every path length <= k.
/// Marks all paths longer than k with tree differences (+1 at each endpoint,
/// -2 at the LCA), then sums counters bottom-up: an edge (x, parent) whose
/// sum equals num lies on every long path. The longest such edge is the best
/// one to zero, and only the longest path (length R) needs re-checking.
/// Each call is O(n + m).
inline bool check(int k) {
    num = 0;
    memset(f, 0, sizeof(f));
    for (int i = 1; i <= m; ++i) {
        if (len[i] > k) {
            ++num;
            ++f[l[i]];
            ++f[r[i]];
            f[mid[i]] -= 2;
        }
    }
    large = 0;
    // Reverse BFS order visits children before parents; order[0] is the root,
    // which has no parent edge.
    for (int j = orderCnt - 1; j > 0; --j) {
        const int x = order[j];
        if (f[x] == num) {
            large = std::max(large, lenth[x] - lenth[fa[x][0]]);
        }
        f[fa[x][0]] += f[x];
    }
    return R - large <= k;
}

/// Weight of the heaviest edge on path i (0 for an empty path), found by
/// walking each endpoint up to the path's LCA one parent at a time.
inline int getlong(int i) {
    int best = 0;
    const int ends[2] = {l[i], r[i]};
    for (int x : ends) {
        for (; x != mid[i]; x = fa[x][0]) {
            best = std::max(best, lenth[x] - lenth[fa[x][0]]);
        }
    }
    return best;
}

/// Reads the tree and the m paths, then binary-searches the smallest k for
/// which check(k) holds and prints it.
int main() {
    if (scanf("%d%d", &n, &m) != 2) {
        return 1;  // malformed input
    }
    int x, y, z;
    for (int i = 1; i < n; ++i) {
        read(x);
        read(y);
        read(z);
        add(x, y, z);
        add(y, x, z);
    }
    getlca();

    int dr = 0;  // upper bound: the longest path length
    int A = 0;   // index of a longest path (0 if none is positive)
    for (int i = 1; i <= m; ++i) {
        read(l[i]);
        read(r[i]);
        mid[i] = lca(l[i], r[i]);
        len[i] = lenth[l[i]] + lenth[r[i]] - 2 * lenth[mid[i]];
        if (len[i] > dr) {
            dr = len[i];
            A = i;
        }
    }
    R = dr;
    // Lower bound: even zeroing the heaviest edge of the longest path leaves
    // it at least this long. Non-negative because edge weights are summed
    // along that same path.
    int dl = dr - getlong(A);
    while (dl < dr) {
        const int dm = dl + (dr - dl) / 2;
        if (check(dm)) {
            dr = dm;
        } else {
            dl = dm + 1;
        }
    }
    printf("%d", dr);
    return 0;
}
AC代码