题解

加法没写取模然后gg了QwQ,de了半天

思想还是比较自然的,线段树合并的维护方法我是真的很少写,然后没想到

很显然,我们有个很愉快的想法是,对于每个节点枚举它所有的叶子节点,对于一个叶子节点的值为v,然后查询另一棵树小于v的概率和×该节点的p + 大于v的概率和 × 该节点的(1 - p),作为这个v新的概率

我们用线段树合并优化这个操作,我们对于两个树的左右儿子计算四个值

分别是

对于第一棵树的左区间,计算第二棵树的右区间的影响,是第二棵树右区间的概率和×(1 - p)

对于第一棵树的右区间,计算第二棵树的左区间的影响,是第二棵树左区间的概率和×p

对于第二棵树的左区间,计算第一棵树的右区间的影响,是第一棵树右区间的概率和×(1 - p)

对于第二棵树的右区间,计算第一棵树的左区间的影响,是第一棵树左区间的概率和×p

然后当这个节点只有一棵树有值的时候,我们再把这个影响下放下去

最后把第一个点的线段树建出来就好

代码

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>
#include <set>
#include <cmath>
#include <bitset>
#include <queue>
#define enter putchar('\n')
#define space putchar(' ')
//#define ivorysi
#define pb push_back
#define mo 974711
#define pii pair<int,int>
#define mp make_pair
#define fi first
#define se second
#define MAXN 300005
#define eps 1e-12
#define lc(u) tr[u].lc
#define rc(u) tr[u].rc
using namespace std;
typedef long long int64;
typedef long double db;
template<class T>
void read(T &res) {
res = 0;char c = getchar();T f = 1;
while(c < '0' || c > '9') {
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
res = res * 10 - '0' + c;
c = getchar();
}
res = res * f;
}
template<class T>
void out(T x) {
if(x < 0) {x = -x;putchar('-');}
if(x >= 10) out(x / 10);
putchar('0' + x % 10);
}
const int MOD = 998244353;
struct tr_node {
int lc,rc;
int sum,lazy;
}tr[MAXN * 40];
struct node {
int to,next;
}E[MAXN * 2];
int Ncnt,rt[MAXN];
int N,head[MAXN],sumE,p[MAXN],w[MAXN],tot,ans;
bool son[MAXN];
int mul(int a,int b) {
return 1LL * a * b % MOD;
}
int inc(int a,int b) {
return a + b >= MOD ? a + b - MOD : a + b;
}
int fpow(int x,int c) {
int res = 1,t = x;
while(c) {
if(c & 1) res = mul(res,t);
t = mul(t,t);
c >>= 1;
}
return res;
}
void add(int u,int v) {
E[++sumE].to = v;
E[sumE].next = head[u];
head[u] = sumE;
} void addlazy(int u,int v) {
if(!u) return;
tr[u].sum = mul(tr[u].sum,v);
tr[u].lazy = mul(tr[u].lazy,v);
}
void pushdown(int u) {
if(tr[u].lazy != 1) {
addlazy(lc(u),tr[u].lazy);
addlazy(rc(u),tr[u].lazy);
tr[u].lazy = 1;
}
}
void update(int u) {
tr[u].sum = inc(tr[lc(u)].sum,tr[rc(u)].sum);
}
void build(int &u,int L,int R,int pos) {
u = ++Ncnt;
tr[u].lazy = 1;
if(L == R) {tr[u].sum = 1;return;}
int mid = (L + R) >> 1;
if(pos <= mid) build(tr[u].lc,L,mid,pos);
else build(tr[u].rc,mid + 1,R,pos);
update(u);
}
int Merge(int Lrt,int Rrt,int p,int m1,int m2) {
if(!Rrt) {
addlazy(Lrt,m1);
return Lrt;
}
if(!Lrt) {
addlazy(Rrt,m2);
return Rrt;
}
pushdown(Lrt);pushdown(Rrt);
int l1 = mul(tr[rc(Rrt)].sum,MOD + 1 - p);
int r1 = mul(tr[lc(Rrt)].sum,p);
int l2 = mul(tr[rc(Lrt)].sum,MOD + 1 - p);
int r2 = mul(tr[lc(Lrt)].sum,p);
tr[Lrt].lc = Merge(tr[Lrt].lc,tr[Rrt].lc,p,inc(m1,l1),inc(m2,l2));
tr[Lrt].rc = Merge(tr[Lrt].rc,tr[Rrt].rc,p,inc(m1,r1),inc(m2,r2));
update(Lrt);
return Lrt;
}
void Calc(int u,int L,int R) {
if(L == R) {
ans = inc(ans,mul(mul(L,w[L]),mul(tr[u].sum,tr[u].sum)));
return;
}
int mid = (L + R) >> 1;
pushdown(u);
Calc(tr[u].lc,L,mid);
Calc(tr[u].rc,mid + 1,R);
}
void dfs(int u,int fa) {
if(!son[u]) {
int pos = lower_bound(w + 1,w + tot + 1,p[u]) - w;
build(rt[u],1,tot,pos);
return;
}
int s[2] = {0,0},t = 0;
for(int i = head[u] ; i ; i = E[i].next) {
int v = E[i].to;
if(v != fa) {
dfs(v,u);
s[t++] = v;
}
}
if(t == 1) rt[u] = rt[s[0]];
else {
rt[u] = Merge(rt[s[0]],rt[s[1]],p[u],0,0);
}
}
void Solve() {
read(N);
int f;
for(int i = 1 ; i <= N ; ++i) {
read(f);
if(f == 0) continue;
add(i,f);add(f,i);
son[f] = 1;
}
int t = fpow(10000,MOD - 2);
for(int i = 1 ; i <= N ; ++i) {
read(p[i]);
if(son[i]) p[i] = mul(p[i],t);
else w[++tot] = p[i];
}
sort(w + 1,w + tot + 1);
dfs(1,0);
Calc(rt[1],1,tot);
out(ans);enter;
}
int main() {
#ifdef ivorysi
freopen("f1.in","r",stdin);
#endif
Solve();
}
05-11 20:35