CF573E (平衡树)

扫码查看

CF573E

题意概要

给出一个长度为\(n\)的数列,从中选出一个子序列\(b[1...m]\)可以为空

使得\[ \sum_{i=1}^m{b_i*i}\]最大,输出这个最大值。

其中\(n\le10^5\)

题解

\(dp_{i,j}\)表示前\(i\)个数选择\(j\)个数的最大值

那么,转移方程则为:
\[dp_{i,j}=max(dp_{i-1,j},dp_{i-1,j-1}+j*a_i)\]
于是我们就得到了一个\(n^2\)的做法

我们考虑优化这个式子。

\(dalao\)证明,发现总有存在一个分界线,这之前的取前者,这之后的取后者

大佬的证明在这里

我们二分分界线,然后用平衡树维护就好了

顺带一提,我今日方知\(splay\)没事多$ splay $几下还会变快

代码

#include<bits/stdc++.h>
#include<windows.h>
#define lch c[x][0]
#define rch c[x][1]
using namespace std;
typedef long long ll;
const int sz=1e5+7;
int n;
ll v,ans;
int rt,cnt;
int f[sz];
int c[sz][2];
int siz[sz];
ll val[sz],tag1[sz],tag2[sz];
inline int newnode(ll v){
    int x=++cnt;
    val[x]=v;
    siz[x]=1;
    return x;
}
inline void pushup(int x){
    siz[x]=siz[lch]+siz[rch]+1;
}
inline void add(int x,ll tg1,ll tg2){
    val[x]+=siz[lch]*tg1+tg2;
    tag1[x]+=tg1;
    tag2[x]+=tg2;
}
inline void pd(int x){
    if(tag1[x]==0&&tag2[x]==0) return;
    if(lch) add(lch,tag1[x],tag2[x]);
    if(rch) add(rch,tag1[x],tag2[x]+(siz[lch]+1)*tag1[x]);
    tag1[x]=0;
    tag2[x]=0;
}
inline void pushdn(int x){
    if(f[x]) pushdn(f[x]);
    pd(x);
}
inline int get(int x){
    return c[f[x]][1]==x;
}
inline void dfs(int x){
    ans=max(ans,val[x]);
    pd(x);
    if(lch) dfs(lch);
    if(rch) dfs(rch);
}
inline void rotate(int x){
    int y=f[x],z=f[y],k=get(x),w=c[x][!k];
    if(z) c[z][get(y)]=x;c[x][!k]=y;c[y][k]=w;
    if(w) f[w]=y;if(y) f[y]=x;f[x]=z;
    pushup(y);
}
inline void splay(int x,int t){
    pushdn(x);
    while(f[x]!=t){
        int y=f[x];
        if(f[y]!=t) rotate(get(x)^get(y)?x:y);
        rotate(x);
    }
    pushup(x);
    if(t==0) rt=x;
}
inline int find(int k){
    int x=rt;
    while(1){
        pd(x);
        if(siz[lch]+1==k) return x;
        if(k<=siz[lch]) x=lch;
        else k-=siz[lch]+1,x=rch;
    }
}
inline void insert(int k,ll v){
    int x=find(k-1);splay(x,0);
    int y=find(k);splay(y,x);
    c[y][0]=newnode(v);
    f[c[y][0]]=y;
    pushup(y);
    pushup(x);
}
inline void modify(int l,int r,ll a,ll b){
    int x=find(l-1);splay(x,0);
    int y=find(r+1);splay(y,x);
    add(c[y][0],a,b);
    pushup(y);
    pushup(x);
}
int main(){
    scanf("%d",&n);
    rt=newnode(0);
    c[rt][0]=newnode(INT_MIN);
    c[rt][1]=newnode(INT_MIN);
    f[c[rt][0]]=f[c[rt][1]]=rt;
    pushup(rt);
    for(int i=1;i<=n;i++){
        scanf("%lld",&v);
        int l=2,r=i+1;
        while(l<r){
            int mid=(l+r)>>1;
            if(val[find(mid)]+(mid-1)*v>=val[find(mid+1)]) r=mid;
            else l=mid+1;
            splay(find(mid),0);
        }
        int x=find(l);
        ll y=val[x];
        modify(l,i+1,v,v*(l-1));
        insert(l,y);
    }
    dfs(rt);
    printf("%lld\n",ans);
}
01-25 21:05
查看更多