Description

传送门

给出一个n个点的树,1号节点为根节点,每个点有一个权值
你需要支持以下操作
1.将以u为根的子树内节点(包括u)的权值加val

2.将(u, v)路径上的节点权值加val
3.询问(u, v)路径上节点的权值两两相乘的和
 

Solution

维护区间的平方和与区间和（数值和）。

修改：假设区间内有三个节点，权值分别为 $a, b, c$，都增加 $t$，那么新的平方和为 $sumsq' = (a+t)^2 + (b+t)^2 + (c+t)^2 = a^2 + b^2 + c^2 + 2t(a + b + c) + 3t^2$。

推广到任意区间 $[l, r]$ 整体加 $t$：$sumsq' = sumsq + 2t \cdot sum + (r-l+1)t^2$，同时 $sum' = sum + (r-l+1)t$（注意要先用旧的 $sum$ 更新 $sumsq$）。

这样平方和与数值和就可以维护了

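查询：由 $\left(\sum x_i\right)^2 = \sum x_i^2 + 2\sum_{i<j} x_i x_j$ 可知，答案为 $(sum^2 - sumsq) / 2$；在模意义下，除以 2 要改为乘以 2 的逆元 $2^{mod-2}$（费马小定理，要求模数为质数）。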

Code

 #include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define rd read()
#define lson nd << 1
#define rson nd << 1 | 1
using namespace std; const int N = 1e5 + ;
const ll mod = 1e9 + ; int n, m;
int top[N], size[N], son[N], f[N], dep[N], cnt;
int head[N], tot;
int A[N], a[N], id[N];
int Li[N << ], Ri[N << ];
ll sum[N << ], sum2[N << ], inv, add[N << ]; struct edge{
int nxt, to;
}e[N << ]; int read() {
int X = , p = ; char c = getchar();
for(; c > '' || c < ''; c = getchar()) if(c == '-') p = -;
for(; c >= '' && c <= ''; c = getchar()) X = X * + c - '';
return X * p;
} void added(int u, int v) {
e[++tot].to = v;
e[tot].nxt = head[u];
head[u] = tot;
} void dfs1(int u) {
size[u] = ;
for(int i = head[u]; i; i = e[i].nxt) {
int nt = e[i].to;
if(nt == f[u]) continue;
dep[nt] = dep[u] + ;
f[nt] = u;
dfs1(nt);
size[u] += size[nt];
if(size[nt] > size[son[u]]) son[u] = nt;
}
} void dfs2(int u) {
id[u] = ++cnt;
A[cnt] = a[u];
if(!son[u]) return;
top[son[u]] = top[u];
dfs2(son[u]);
for(int i = head[u]; i; i = e[i].nxt) {
int nt = e[i].to;
if(nt == f[u] || nt == son[u]) continue;
top[nt] = nt;
dfs2(nt);
}
} void pushdown(int nd) {
if(add[nd]) {
sum2[lson] += ( * sum[lson] * add[nd] % mod + (Ri[lson] - Li[lson] + ) * add[nd] % mod * add[nd] % mod) % mod;
sum2[lson] %= mod;
sum2[rson] += ( * sum[rson] * add[nd] % mod + (Ri[rson] - Li[rson] + ) * add[nd] % mod * add[nd] % mod) % mod;
sum2[rson] %= mod; sum[lson] += (Ri[lson] - Li[lson] + ) * add[nd] % mod;
sum[lson] %= mod;
sum[rson] += (Ri[rson] - Li[rson] + ) * add[nd] % mod;
sum[rson] %= mod; add[lson] += add[nd];
add[rson] += add[nd];
add[lson] %= mod;
add[rson] %= mod;
add[nd] = ;
}
} void update(int nd) {
sum[nd] = (sum[lson] + sum[rson]) % mod;
sum2[nd] = (sum2[lson] + sum2[rson]) % mod;
} void build(int l, int r, int nd) {
if(l == r) {
Li[nd] = l;
Ri[nd] = r;
sum[nd] = A[l];
sum2[nd] = A[l] * A[l] % mod;
return;
}
Li[nd] = l;
Ri[nd] = r;
int mid =(l + r) >> ;
build(l, mid, lson);
build(mid + , r, rson);
update(nd);
} void modify(int L, int R, int d, int l, int r, int nd) {
if(L <= l && r <= R) {
sum2[nd] += ( * sum[nd] * d % mod + (r - l + ) * d % mod * d % mod) % mod;
sum2[nd] %= mod;
sum[nd] += 1LL * d * (r - l + ) % mod;
sum[nd] %= mod;
add[nd] += d;
add[nd] %= mod;
return;
}
pushdown(nd);
int mid = (l + r) >> ;
if(mid >= L) modify(L, R, d, l, mid, lson);
if(mid < R) modify(L, R, d, mid + , r, rson);
update(nd);
} ll query(int L, int R, int l, int r, int nd) {
if(L <= l && r <= R) return sum[nd];
int mid = (l + r) >> ;
ll re = ;
pushdown(nd);
if(mid >= L) re = (re + query(L, R, l, mid, lson)) % mod;
if(mid < R) re = (re + query(L, R, mid + , r, rson)) % mod;
return re;
} ll query2(int L, int R, int l, int r, int nd) {
if(L <= l && r <= R) return sum2[nd];
int mid = (l + r) >> ;
ll re = ;
pushdown(nd);
if(mid >= L) re = (re + query2(L, R, l, mid, lson)) % mod;
if(mid < R) re = (re + query2(L, R, mid + , r, rson)) % mod;
return re;
} ll query_po(int x, int y) {
ll re = , tmp, sum = ;
for(; top[x] != top[y];) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
tmp = query(id[top[x]], id[x], , n, );
sum = (sum + tmp % mod) % mod;
tmp = query2(id[top[x]], id[x], , n, );
re = (re - tmp) % mod;
x = f[top[x]];
}
if(dep[x] < dep[y]) swap(x, y);
tmp = query(id[y], id[x], , n, );
sum = (sum + tmp) % mod;
tmp = query2(id[y], id[x], , n, );
re = (re - tmp) % mod;
re = (re + sum * sum % mod) % mod;
re = (re % mod + mod) % mod;
re = re * inv % mod;
return re;
} void modify_po(int x, int y, int d) {
for(; top[x] != top[y];) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
modify(id[top[x]], id[x], d, , n, );
x = f[top[x]];
}
if(dep[x] < dep[y]) swap(x, y);
modify(id[y], id[x], d, , n, );
} ll fc(ll ta, ll b) {
ll re = ;
for(; b; b >>= , ta = (ta + ta) % mod) if( b & ) re = (re + ta) % mod;
return re;
} ll fpow(ll ta, ll b) {
ll re = ;
for(; b; b >>= , ta = fc(ta, ta)) if(b & ) re = fc(re, ta);
return re;
} int main()
{
n = rd; m = rd;
for(int i = ; i <= n; ++i) a[i] = rd;
for(int i = ; i < n; ++i) {
int u = rd, v = rd;
added(u, v); added(v, u);
}
dfs1();
top[] = ;
dfs2();
build(, n, );
inv = fpow(, mod - );
for(int i = ; i <= m; ++i) {
int k = rd;
if(k == ) {
int u = rd, val = rd;
modify(id[u], id[u] + size[u] - , val, , n, );
}
if(k == ) {
int u = rd, v = rd, val = rd;
modify_po(u, v, val);
}
if(k == ) {
int u = rd, v = rd;
printf("%lld\n", query_po(u, v));
}
}
}
04-26 12:37