【CF833D】Red-Black Cobweb
题面
题解
看到这种统计路径的题目当然是淀粉质啦。
考虑先转化一下条件:设一条路径上有红边\(a\)条,黑边\(b\)条,
则合法的条件是\(2\min(a,b)\geq \max(a,b)\)
\(\Leftrightarrow 2a\geq b\)且\(2b\geq a\)
现在我们需要将过一个点的两条路径合并
设第一条为红\(a_1\),黑\(b_1\),第二条为红\(a_2\),黑\(b_2\)
则有
\[\begin{aligned}
2(a_1+a_2)&\geq b_1+b_2\\
2(b_1+b_2)&\geq a_1+a_2
\end{aligned}\]
将一个下标的放一边以便维护
\[\begin{aligned}
2a_2-b_2&\geq b_1-2a_1\\
2b_2-a_2&\geq a_1-2b_1
\end{aligned}\]
每次遍历完一棵子树,按时间顺序加入其中所有的路径:将不等式左边看作对二维平面的查询,
右边看作插入的坐标,再加上时间这一维,就是一个\(cdq\)分治
复杂度是\(O(n\log^3 n)\)(点分治和\(cdq\)分治各一个\(\log\);每次查询中树状数组和快速幂是并列的,合起来只再多一个\(\log\)),而且常数很小
代码
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
using namespace std;
// Reads the next decimal integer from stdin.
//
// Every character before the first digit or '-' is skipped, one optional
// leading '-' is accepted, and then consecutive digits are consumed.
// Returns 0 if the input ends before any digit is seen (the loop used to
// spin forever on EOF).
inline int gi() {
    // Keep the character as int: getchar() returns int so EOF stays
    // distinguishable, and isdigit() must not be given a negative char.
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch) && ch != '-') ch = std::getchar();

    int sign = 1;
    if (ch == '-') {
        sign = -1;
        ch = std::getchar();
    }

    int data = 0;
    // isdigit(EOF) is false, so this loop also stops at end of input.
    while (std::isdigit(ch)) {
        data = 10 * data + (ch - '0');
        ch = std::getchar();
    }
    return sign * data;
}
// Capacity used for the per-vertex arrays of this file (100000 plus a small margin).
const int MAX_N = 100005;
// Prime modulus under which the answer and all weight products are kept.
const int Mod = 1000000007;
// Binary exponentiation: returns x raised to the power y, reduced modulo Mod.
// The only caller in this file passes a Fenwick count as y, i.e. y >= 0.
int fpow(int x, int y) {
    long long result = 1;
    long long base = x;
    // Walk the bits of y from least to most significant.
    for (; y; y >>= 1) {
        if (y & 1) result = result * base % Mod;
        base = base * base % Mod;
    }
    return static_cast<int>(result);
}
// Record shared by the DFS stack and the offline divide-and-conquer queue.
// Member order matters: the rest of the file builds Points positionally.
struct Point {
    int x;   // first coordinate (dfs stores the first colour count here)
    int y;   // second coordinate (dfs stores the second colour count here)
    int op;  // 0 = insertion, non-zero = query (as interpreted by Div)
    int v;   // product of edge weights along the path, modulo Mod
};

// Orders points by x, breaking ties by y; op and v do not take part.
bool operator < (const Point &l, const Point &r) {
    if (l.x != r.x) return l.x < r.x;
    return l.y < r.y;
}
// One directed edge of the adjacency-list graph.
struct Graph {
    int to;    // vertex this edge points to
    int cost;  // edge weight
    int col;   // colour flag as read from the input (0 or 1 in solve/dfs)
    int next;  // index of the next edge leaving the same vertex, -1 if none
};
// Edge pool; every undirected edge is stored as two directed entries.
Graph e[MAX_N << 1];
// fir[u] is the index of the most recently added edge leaving u (-1 when empty).
int fir[MAX_N];
// Number of directed edges stored so far.
int e_cnt = 0;
void clearGraph() { memset(fir, -1, sizeof(fir)); e_cnt = 0; }
void Add_Edge(int u, int v, int w, int c) {
e[e_cnt].to = v, e[e_cnt].cost = w, e[e_cnt].col = c, e[e_cnt].next = fir[u];
fir[u] = e_cnt++;
}
// N: number of vertices, read in main. ans: running answer, a product kept
// modulo Mod (starts at the empty product 1). size[v]: subtree size written by
// search_centroid.
// NOTE(review): with `using namespace std;` in effect, the unqualified name
// `size` collides with std::size when compiling as C++17 or later; refer to it
// as `::size` or rename it.
int N, ans = 1, size[MAX_N];
// used[v] is set by solve() once v has been taken as a centroid.
bool used[MAX_N];
// centroid / sz / rmx: output and working state of search_centroid (best
// vertex found, component size, best "largest part" so far).
// c1 / c2: Fenwick trees over shifted y coordinates, holding products of path
// weights (c1) and element counts (c2).
int centroid, sz, rmx, c1[MAX_N << 2], c2[MAX_N << 2];
// stk: paths collected by dfs for one subtree. q: insertions and queries that
// Div processes for one centroid.
Point stk[MAX_N], q[MAX_N << 2];
// top: number of entries in stk; cnt: number of entries in q (both 1-based).
int top, cnt;
// Lowest set bit of x (the usual Fenwick-tree step); yields 0 for x == 0.
inline int lb(int x) {
    const int lowest = x & -x;
    return lowest;
}
// Fenwick point update at index x: multiplies v into the product tree c1
// (modulo Mod) and records one more element in the count tree c2.
void add(int x, int v) {
    const int limit = N * 4 + 1;
    for (int i = x; i <= limit; i += lb(i)) {
        c1[i] = static_cast<int>(1ll * c1[i] * v % Mod);
        ++c2[i];
    }
}
// Fenwick prefix query on c1: product (modulo Mod) of all values inserted at
// indices 1..x. Returns 1 when x <= 0.
int Sum(int x) {
    int product = 1;
    for (int i = x; i > 0; i -= lb(i)) {
        product = static_cast<int>(1ll * c1[i] * product % Mod);
    }
    return product;
}
// Fenwick prefix query on c2: how many values were inserted at indices 1..x.
// Returns 0 when x <= 0.
int Cnt(int x) {
    int total = 0;
    for (int i = x; i > 0; i -= lb(i)) total += c2[i];
    return total;
}
// Undoes insertions at index x by resetting every node on its update path to
// the neutral state (product 1, count 0). Cheaper than clearing whole arrays.
void Set(int x) {
    const int limit = N * 4 + 1;
    for (int i = x; i <= limit; i += lb(i)) {
        c1[i] = 1;
        c2[i] = 0;
    }
}
void search_centroid(int x, int fa) {
size[x] = 1; int mx = 0;
for (int i = fir[x]; ~i; i = e[i].next) {
int v = e[i].to;
if (v == fa || used[v]) continue;
search_centroid(v, x);
size[x] += size[v];
mx = max(mx, size[v]);
}
mx = max(mx, sz - size[x]);
if (mx < rmx) rmx = mx, centroid = x;
}
// Collects every path that starts at the current centroid and ends inside the
// subtree of x, pushing one stk entry per vertex reached.
//   x, fa : current vertex and the vertex it was entered from
//   R, B  : number of colour-0 and non-zero-colour edges on the path so far
//   val   : product of the edge weights on the path, modulo Mod
// Vertices already marked in used[] are not entered.
void dfs(int x, int fa, int R, int B, int val) {
    // Standard aggregate initialisation instead of the C compound literal
    // `(Point){...}`, which is only a compiler extension in C++.
    stk[++top] = Point{R, B, 0, val};
    for (int i = fir[x]; ~i; i = e[i].next) {
        const int v = e[i].to;
        if (v == fa || used[v]) continue;
        const int extended = static_cast<int>(1ll * val * e[i].cost % Mod);
        if (e[i].col == 0) dfs(v, x, R + 1, B, extended);
        else dfs(v, x, R, B + 1, extended);
    }
}
// Offline divide and conquer over q[l..r], whose index order is the time order.
//
// After the two recursive calls each half is sorted by (x, y). Every query in
// the right half (op != 0) is then matched against the insertions of the left
// half (op == 0) whose x does not exceed the query's x; among those, the
// Fenwick trees select the ones with y <= the query's y. For such a query with
// weight w and k matching insertions of weights v_1..v_k the answer is
// multiplied by (v_1 * ... * v_k) * w^k, i.e. by the product of v_i * w over
// all matched pairs.
// On return q[l..r] is sorted by (x, y) and the Fenwick trees are neutral again.
void Div(int l, int r) {
    if (l >= r) return;
    const int mid = (l + r) >> 1;
    Div(l, mid);
    Div(mid + 1, r);

    int j = l;  // next left-half element not yet considered
    for (int i = mid + 1; i <= r; i++) {
        if (!q[i].op) continue;
        // Test the bound first so q[j] is only read while j is in the left half.
        while (j <= mid && q[j].x <= q[i].x) {
            if (!q[j].op) add(q[j].y, q[j].v);
            ++j;
        }
        ans = 1ll * ans * Sum(q[i].y) % Mod * fpow(q[i].v, Cnt(q[i].y)) % Mod;
    }
    // Undo exactly the insertions made above.
    for (int i = l; i < j; i++) {
        if (!q[i].op) Set(q[i].y);
    }
    std::inplace_merge(&q[l], &q[mid + 1], &q[r + 1]);
}
void solve(int x) {
used[x] = 1;
cnt = 0; int Pls = 2 * N + 1;
for (int i = fir[x]; ~i; i = e[i].next) {
int v = e[i].to;
if (used[v]) continue;
top = 0;
if (e[i].col == 0) dfs(v, x, 1, 0, e[i].cost);
else if (e[i].col == 1) dfs(v, x, 0, 1, e[i].cost);
for (int j = 1; j <= top; j++) {
int a = stk[j].x, b = stk[j].y;
q[++cnt] = (Point){2 * a - b + Pls, 2 * b - a + Pls, 1, stk[j].v};
}
for (int j = 1; j <= top; j++) {
int a = stk[j].x, b = stk[j].y;
q[++cnt] = (Point){b - 2 * a + Pls, a - 2 * b + Pls, 0, stk[j].v};
if (2 * min(a, b) >= max(a, b)) ans = 1ll * ans * stk[j].v % Mod;
}
}
Div(1, cnt);
for (int i = fir[x]; ~i; i = e[i].next) {
int v = e[i].to;
if (used[v]) continue;
sz = rmx = size[v];
search_centroid(v, x);
solve(centroid);
}
}
// Reads the tree (N, then N - 1 lines "u v w c"), runs the centroid
// decomposition from vertex 1 and prints the accumulated answer.
int main() {
    clearGraph();
    N = gi();
    for (int edge = 1; edge < N; ++edge) {
        const int u = gi();
        const int v = gi();
        const int w = gi();
        const int c = gi();
        // Store the undirected edge in both directions.
        Add_Edge(u, v, w, c);
        Add_Edge(v, u, w, c);
    }

    // Bring both Fenwick trees to their neutral state (product 1, count 0).
    const int limit = 4 * N + 1;
    for (int i = 1; i <= limit; ++i) {
        c1[i] = 1;
        c2[i] = 0;
    }

    sz = N;
    rmx = N;
    search_centroid(1, 0);
    solve(centroid);
    printf("%d\n", ans);
    return 0;
}