一道树的直径

树网的核 BZOJ原题链接

树网的核 洛谷原题链接

消防 BZOJ原题链接

消防 洛谷原题链接

一份代码四倍经验,爽

显然要先随便找一条直径,然后直接枚举核的两个端点,对每一次枚举的核遍历核上的每个点,用\(dfs\)求出核外节点到核的最大值即可,时间复杂度为\(O(n^3)\),这在\(NOIP\)的原数据范围下是可以过的,但对于数据加强版就必须要优化了。

发现当枚举到直径上的某个点时,核的另一端在不超过\(s\)的前提下显然越远越好。这样就直接优化掉一个\(n\)了,但我们还可以继续优化。

设直径上的点为\(u_1,u_2,\dots,u_t\),当前枚举到的核的两端点为\(x_i,x_j\)。

根据直径的最长性,我们可以发现对于该核的偏心距实际上就是\(\max\{\max\limits_{k=1}^{t}\{d[u_k]\},dis(u_1,x_i),dis(x_j,u_t)\}\),数组\(d\)表示直径外节点(不经过直径上的点)到\(u_k\)的最大值,\(dis\)表示两点间的距离。

而\(\max\limits_{k=1}^{t}\{d[u_k]\}\)显然是个定值,至于\(dis\),我们可专门剖出直径上的所有边,然后用在枚举核的左端点时用两个变量维护即可,时间复杂度\(O(n)\)。

#include<cstdio>
using namespace std;
const int N = 5e5 + 10;
struct dd {
int dis, x;
};
dd D[N], a[N];
int fi[N], di[N << 1], da[N << 1], ne[N << 1], l, ma;
bool v[N];
inline int re()
{
int x = 0;
char c = getchar();
bool p = 0;
for (; c<'0' || c>'9'; c = getchar())
p |= c == '-';
for (; c >= '0'&&c <= '9'; c = getchar())
x = x * 10 + (c - '0');
return p ? -x : x;
}
inline int maxn(int x, int y)
{
return x > y ? x : y;
}
inline int minn(int x, int y)
{
return x < y ? x : y;
}
inline void add(int x, int y, int z)
{
di[++l] = y;
da[l] = z;
ne[l] = fi[x];
fi[x] = l;
}
void dfs(int x, int fa, int dis, int la)
{
int i, y;
if (dis > ma)
{
ma = dis;
D[0].x = x;
}
D[x].x = fa;
D[x].dis = la;
for (i = fi[x]; i; i = ne[i])
{
y = di[i];
if (y != fa)
dfs(y, x, dis + da[i], da[i]);
}
}
void dfs_2(int x, int dis)
{
int i, y;
v[x] = 1;
if (dis > ma)
ma = dis;
for (i = fi[x]; i; i = ne[i])
{
y = di[i];
if (!v[y])
dfs_2(y, dis + da[i]);
}
}
int main()
{
int i, j, n, m, x, y, z, s = 0, k = 0, tail = 0, head = 0, an = 1e9;
n = re();
m = re();
for (i = 1; i < n; i++)
{
x = re();
y = re();
z = re();
add(x, y, z);
add(y, x, z);
}
dfs(1, 0, 0, 0);
ma = 0;
dfs(D[0].x, 0, 0, 0);
for (i = D[0].x; i; i = D[i].x)
{
v[i] = 1;
a[++k].x = i;
a[k].dis = D[i].dis;
}
ma = 0;
for (i = 1; i <= k; i++)
dfs_2(a[i].x, 0);
for (j = 1; j < n; j++)
if (s + a[j].dis <= m)
s += a[j].dis;
else
break;
for (i = j; i < n; i++)
tail += a[i].dis;
an = minn(an, maxn(ma, maxn(head, tail)));
for (i = 1; i < n; i++)
{
s -= a[i].dis;
head += a[i].dis;
for (; j < n; j++)
if (s + a[j].dis <= m)
{
s += a[j].dis;
tail -= a[j].dis;
}
else
break;
an = minn(an, maxn(ma, maxn(head, tail)));
}
printf("%d", an);
return 0;
}
05-08 08:29