问题本质是把\(a_i\)作为\(i\)的父亲,然后如果有环就不合法,否则每次要取数,要满足取之前他的父亲都被取过(父亲为0可以直接取),求最大价值
贪心想法显然是要把权值大的尽量放在后面,这等价于把权值小的尽量放在前面.所以如果当前最小的数没有父亲,显然直接取出来最优;如果有父亲,那么这个数应该在它的父亲被取之后马上取出来.这时我们把这两个点合并.之后重复此操作知道所有点被取完,就能得到答案
还有个问题是两个点合并后怎么取权值.两个点合并相当于两个序列合并,序列分别记为\(\{a_1,a_2...a_n\},\{b_1,b_2...b_m\}\),考虑什么时候\(\{a\}\)会放在\(\{b\}\)前面,\(\{a\}\)在前面的答案为\(ans_a+ans_b+n\sum_{j=1}^{m}b_j\),\(\{b\}\)在前面的答案为\(ans_a+ans_b+m\sum_{i=1}^{n}a_i\),\(\{a\}\)在前面当且仅当\(n\sum_{j=1}^{m}b_j\ge m\sum_{i=1}^{n}a_i\),等价于\(\frac{\sum a_i}{n}\le \frac{\sum b_j}{m}\),所以把权值设为里面点点权平均值即可.然后两个点\(a,b\)合并,会产生\(n\sum_{j=1}^{m}b_j\)的贡献,直接往答案里加即可
#include<bits/stdc++.h>
#define LL long long
#define uLL unsigned long long
#define db double
using namespace std;
const int N=5e5+10;
int rd()
{
int x=0,w=1;char ch=0;
while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*w;
}
struct node
{
LL w,sz,i;
bool operator < (const node &bb) const {return w*bb.sz!=bb.w*sz?w*bb.sz>bb.w*sz:i>bb.i;}
bool operator == (const node &bb) const {return w==bb.w&&sz==bb.sz&&i==bb.i;}
}a[N];
bool ban[N];
struct HEAP
{
priority_queue<node> q1;
void mntn(){while(!q1.empty()&&(ban[q1.top().i]||!(q1.top()==a[q1.top().i]))) q1.pop();}
void push(node x){q1.push(x);}
void pop(){mntn();q1.pop();}
node top(){mntn();return q1.top();}
}hp;
int n,fa[N],ff[N];
LL ans,sm;
int findf(int x){return ff[x]==x?x:ff[x]=findf(ff[x]);}
int main()
{
n=rd();
for(int i=1;i<=n;++i) ff[i]=i;
for(int i=1;i<=n;++i)
{
fa[i]=rd();
if(fa[i])
{
int x=findf(i),y=findf(fa[i]);
if(x==y){puts("-1");return 0;}
ff[y]=x;
}
}
for(int i=1;i<=n;++i)
{
int w=rd();
hp.push((a[i]=(node){w,1,i}));
sm+=w;
}
ans=sm;
int gg=0;
for(int i=1;i<=n;++i) ff[i]=i;
for(int i=1;i<=n;++i)
{
int x=hp.top().i;
hp.pop();
if(findf(fa[x]))
{
int xx=findf(fa[x]);
ans+=a[xx].sz*a[x].w;
a[xx].w+=a[x].w,a[xx].sz+=a[x].sz;
hp.push(a[xx]);
ff[x]=xx;
}
else ff[x]=0,sm-=a[x].w,ans+=a[x].sz*sm;
ban[x]=1;
}
for(int i=1;i<=n;++i) ff[i]=findf(i);
printf("%lld\n",ans);
return 0;
}