题意:

给一棵树,每次询问删掉两条边,问剩下的三棵树的最大直径

点10W,询问10W,询问相互独立

Solution:

考虑线段树/倍增维护树的直径

考虑一个点集的区间 [l, r]

而我们知道了有 l <= k < r,

且知道 [l, k] 和 [k + 1, r] 两个区间的最长链的端点及长度

假设两个区间的直径端点分别为 (l1, r1) 和 (l2, r2)

那么 [l, r] 这个区间的直径长度为

dis(l1, r1) dis(l1, l1)  dis(l1, r2)

dis(r1, l2) dis(r1, r2) dis(l2, r2)

六个值中的最大值

本题因为操作子树,所以我们维护dfs序的区间最长链即可

证明:

首先有一个结论:

树上任意一个点在树中的最远点是树的直径的某个端点。我们可以用反证法轻易地证明这一点。

再扩展一下,有以下结论:树上任意一个点在树中的一个点集中的最远点是该点集中最长链的一个端点。

其实我们把点集等价地看为一棵虚树,然后就能用相似的证法解决了。

代码:

 #include <stdio.h>
#include <algorithm> using namespace std; const int N = 2e5 + ; int T, n, m; int len, head[N], ST[][N]; struct edge{int u, v, w;}ee[N]; int cnt, fa[N], log_2[N], st[N], en[N], dfn[N], dis[N], dep[N], pos[N]; struct edges{int to, next, cost;}e[N]; inline void add(int u, int v, int w) {
e[++ len] = (edges){v, head[u], w}, head[u] = len;
e[++ len] = (edges){u, head[v], w}, head[v] = len;
} inline void dfs1(int u) {
st[u] = ++ cnt, dfn[cnt] = u;
for (int v, i = head[u]; i; i = e[i].next) {
v = e[i].to;
if (v == fa[u]) continue;
fa[v] = u, dep[v] = dep[u] + ;
dis[v] = dis[u] + e[i].cost, dfs1(v);
}
en[u] = cnt;
} inline void dfs2(int u) {
dfn[++ cnt] = u, pos[u] = cnt;
for (int v, i = head[u]; i; i = e[i].next) {
v = e[i].to;
if (v == fa[u]) continue;
dfs2(v), dfn[++ cnt] = u;
}
} int mmin(int x, int y) {
if (dep[x] < dep[y]) return x;
return y;
} inline int lca(int u, int v) {
static int w;
if (pos[u] > pos[v]) swap(u, v);
w = log_2[pos[v] - pos[u] + ];
return mmin(ST[w][pos[u]], ST[w][pos[v] - ( << w) + ]);
} inline int dist(int u, int v) {
int Lca = lca(u, v);
return dis[u] + dis[v] - dis[Lca] * ;
} inline void build() {
for (int i = ; i <= cnt; i ++)
ST[][i] = dfn[i];
for (int i = ; i < ; i ++)
for (int j = ; j <= cnt; j ++)
if (j + ( << (i - )) > cnt) ST[i][j] = ST[i - ][j];
else ST[i][j] = mmin(ST[i - ][j], ST[i - ][j + ( << (i - ))]);
} int M; struct node {
int l, r, dis;
}tr[N << ]; inline void update(int o, int o1, int o2) {
static int d;
static node tmp;
if (tr[o1].dis == -) {tr[o] = tr[o2]; return;}
if (tr[o2].dis == -) {tr[o] = tr[o1]; return;}
if (tr[o1].dis > tr[o2].dis) tmp = tr[o1];
else tmp = tr[o2];
d = dist(tr[o1].l, tr[o2].l);
if (d > tmp.dis) tmp.l = tr[o1].l, tmp.r = tr[o2].l, tmp.dis = d;
d = dist(tr[o1].l, tr[o2].r);
if (d > tmp.dis) tmp.l = tr[o1].l, tmp.r = tr[o2].r, tmp.dis = d;
d = dist(tr[o1].r, tr[o2].l);
if (d > tmp.dis) tmp.l = tr[o1].r, tmp.r = tr[o2].l, tmp.dis = d;
d = dist(tr[o1].r, tr[o2].r);
if (d > tmp.dis) tmp.l = tr[o1].r, tmp.r = tr[o2].r, tmp.dis = d;
tr[o] = tmp;
} inline void ask(int s, int t) {
if (s > t) return;
for (s += M - , t += M + ; s ^ t ^ ; s >>= , t >>= ) {
if (~s&) update(, , s ^ );
if ( t&) update(, , t ^ );
}
} inline int get_char() {
static const int SIZE = << ;
static char *T, *S = T, buf[SIZE];
if (S == T) {
T = fread(buf, , SIZE, stdin) + (S = buf);
if (S == T) return -;
}
return *S ++;
} inline void in(int &x) {
static int ch;
while (ch = get_char(), ch > || ch < );x = ch - ;
while (ch = get_char(), ch > && ch < ) x = x * + ch - ;
} int main() {
int u, v, w, ans;
log_2[] = ;
for (int i = ; i <= ; i ++)
if (i == << (log_2[i - ] + ))
log_2[i] = log_2[i - ] + ;
else log_2[i] = log_2[i - ];
for (in(T); T --; ) {
in(n), in(m), cnt = len = ;
for (int i = ; i <= n; i ++)
head[i] = ;
for (int i = ; i < n; i ++) {
in(ee[i].u), in(ee[i].v), in(ee[i].w);
add(ee[i].u, ee[i].v, ee[i].w);
}
dfs1();
for (M = ; M < n + ; M <<= );
for (int i = ; i <= n; i ++)
tr[i + M].l = tr[i + M].r = dfn[i], tr[i + M].dis = ;
for (int i = n + M + ; i <= (M << ) + ; i ++)
tr[i].dis = -;
cnt = , dfs2(), build();
for (int i = M; i; i --)
update(i, i << , i << | );
for (int i = ; i < n; i ++)
if (dep[ee[i].u] > dep[ee[i].v])
swap(ee[i].u, ee[i].v);
for (int u, v, i = ; i <= m; i ++) {
in(u), in(v), ans = ;
u = ee[u].v, v = ee[v].v, w = lca(u, v);
if (w == u || w == v) {
if (w != u) swap(u, v);
tr[].dis = -, ask(, st[u] - ), ask(en[u] + , n), ans = max(ans, tr[].dis);
tr[].dis = -, ask(st[u], st[v] - ), ask(en[v] + , en[u]), ans = max(ans, tr[].dis);
tr[].dis = -, ask(st[v], en[v]), ans = max(ans, tr[].dis);
}
else {
if (st[u] > st[v]) swap(u, v);
tr[].dis = -, ask(, st[u] - ), ask(en[u] + , st[v] - ), ask(en[v] + , n), ans = max(ans, tr[].dis);
tr[].dis = -, ask(st[u], en[u]), ans = max(ans, tr[].dis);
tr[].dis = -, ask(st[v], en[v]), ans = max(ans, tr[].dis);
}
printf("%d\n", ans);
}
}
return ;
}

一开始没带脑子算错了复杂度,少算了个log开心的写了树剖LCA,还在dfs的时候求siz忘记把儿子的siz加上了

T到死...发现是带2个log,该死出题人多组数据不给数据组数,改写ST表O(1)求LCA,复杂度只带1个log过了

理论上线段树也可以用ST表代替,复杂度O(n)...当然不可能啦,预处理nlogn,回答O(1)

附加训练 51nod 1766

 #include <stdio.h>
#include <algorithm> using namespace std; const int N = 1e5 + ; int n, m, M, tot, head[N], st[][N << ], log_2[N << ]; int cnt, dis[N], dep[N], pos[N], dfn[N << ]; struct edge{int to, next, cost;}e[N << ]; int mmin(int x, int y) {
return dep[x] < dep[y] ? x : y;
} void add(int u, int v, int w) {
e[++ tot] = (edge){v, head[u], w}, head[u] = tot;
e[++ tot] = (edge){u, head[v], w}, head[v] = tot;
} void dfs(int u, int fr) {
dfn[++ cnt] = u, pos[u] = cnt;
for (int v, i = head[u]; i; i = e[i].next) {
v = e[i].to;
if (v == fr) continue;
dep[v] = dep[u] + , dis[v] = dis[u] + e[i].cost;
dfs(v, u), dfn[++ cnt] = u;
}
} int lca(int u, int v) {
if (pos[u] > pos[v]) swap(u, v);
int w = log_2[pos[v] - pos[u] + ];
return mmin(st[w][pos[u]], st[w][pos[v] - ( << w) + ]);
} int dist(int u, int v) {
return dis[u] + dis[v] - dis[lca(u, v)] * ;
} struct node {
int l, r, dis; node operator + (const node &a) const {
node res;
if (dis == -) return a;
if (a.dis == -) return *this;
if (dis > a.dis) res = *this;
else res = a;
int d = dist(l, a.l);
if (d > res.dis) res.l = l, res.r = a.l, res.dis = d;
d = dist(l, a.r);
if (d > res.dis) res.l = l, res.r = a.r, res.dis = d;
d = dist(r, a.l);
if (d > res.dis) res.l = r, res.r = a.l, res.dis = d;
d = dist(r, a.r);
if (d > res.dis) res.l = r, res.r = a.r, res.dis = d;
return res;
} node operator * (const node &a) const {
node res; res.dis = -;
int d = dist(l, a.l);
if (d > res.dis) res.l = l, res.r = a.l, res.dis = d;
d = dist(l, a.r);
if (d > res.dis) res.l = l, res.r = a.r, res.dis = d;
d = dist(r, a.l);
if (d > res.dis) res.l = r, res.r = a.l, res.dis = d;
d = dist(r, a.r);
if (d > res.dis) res.l = r, res.r = a.r, res.dis = d;
return res;
}
}tr[N << ]; node ask(int s, int t) {
node res; res.dis = -;
for (s += M - , t += M + ; s ^ t ^ ; s >>= , t >>= ) {
if (~s&) res = res + tr[s ^ ];
if ( t&) res = res + tr[t ^ ];
}
return res;
} int main() {
scanf("%d", &n);
for (int u, v, w, i = ; i < n; i ++)
scanf("%d %d %d", &u, &v, &w), add(u, v, w);
dfs(, ); for (int i = ; i <= cnt; i ++)
st[][i] = dfn[i];
for (int i = ; i < ; i ++)
for (int j = ; j <= cnt; j ++)
if (j + ( << (i - )) > cnt) st[i][j] = st[i - ][j];
else st[i][j] = mmin(st[i - ][j], st[i - ][j + ( << (i - ))]);
log_2[] = ;
for (int i = ; i <= cnt; i ++)
log_2[i] = log_2[i - ] + (i == ( << (log_2[i - ] + ))); for (M = ; M < n + ; M <<= );
for (int i = ; i <= n; i ++) tr[i + M] = (node){i, i, };
for (int i = n + ; i <= M + ; i ++) tr[i + M].dis = -;
for (int i = M; i; i --) tr[i] = tr[i << ] + tr[i << | ]; node tmp; int a, b, c, d;
for (scanf("%d", &m); m --; ) {
scanf("%d %d %d %d", &a, &b, &c, &d);
tmp = ask(a, b) * ask(c, d);
printf("%d\n", tmp.dis);
}
return ;
}

相对简单一点了

05-17 01:45