题目链接:

参考了 AgOH 大佬的递归版 Splay 写法。

#include <bits/stdc++.h>
using namespace std;
#define ll long long
// `re` used to expand to `register`, which is ill-formed since C++17 (it was only
// ever an optimisation hint). It now expands to nothing, so `re int i` still
// compiles as a plain `int i`.
#define re
// Capacity of the node pool `spl`; slot 0 is reserved as the null node.
const int N=1e6+10;
// Reads one decimal integer from `in` (stdin by default) into `a`.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. The character that ends the number is
// consumed. If end-of-file is reached before any digit, `a` is left as 0 and
// the function returns instead of looping forever.
void read(int &a, FILE *in = stdin)
{
    a = 0;
    int sign = 1;
    int ch = getc(in);   // int, not char, so EOF stays distinguishable from data
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-')
            sign = -1;
        ch = getc(in);
    }
    if (ch == EOF)
        return;
    while (ch >= '0' && ch <= '9')
    {
        a = a * 10 + (ch - '0');
        ch = getc(in);
    }
    a *= sign;
}
// One splay-tree node. Nodes live in the static pool `spl`; index 0 plays the
// role of the null node (its siz is read as 0 by the size updates).
struct note
{
    int l,r;   // left / right child indices, 0 = no child
    int siz;   // total multiplicity of all keys in this subtree
    int cnt;   // number of copies of `val` held by this node
    int val;   // the key
};
note spl[N];
// rt: index of the current root (0 = empty tree).
// cnt: number of pool slots handed out so far (the last used index).
int rt,cnt;
void newnode(int &now,int val)
{
    spl[now=++cnt].val=val;
    spl[now].siz=1;
    spl[now].cnt=1;
}
void update(int now){spl[now].siz=spl[spl[now].l].siz+spl[spl[now].r].siz+spl[now].cnt;}
// Right rotation at the slot `now`: the left child of the node stored there is
// lifted into the slot, and the old node becomes its right child. The lifted
// node's former right subtree moves under the old node. Sizes are refreshed
// bottom-up (the lowered node first).
void zig(int &now)
{
    const int top=now;
    const int pivot=spl[top].l;
    spl[top].l=spl[pivot].r;
    spl[pivot].r=top;
    now=pivot;      // rewrites the parent's child link (or rt) in place
    update(top);
    update(pivot);
}
// Left rotation at the slot `now`: the right child of the node stored there is
// lifted into the slot, and the old node becomes its left child. The lifted
// node's former left subtree moves under the old node. Sizes are refreshed
// bottom-up (the lowered node first).
void zag(int &now)
{
    const int top=now;
    const int pivot=spl[top].r;
    spl[top].r=spl[pivot].l;
    spl[pivot].l=top;
    now=pivot;      // rewrites the parent's child link (or rt) in place
    update(top);
    update(pivot);
}
// Splays node x upward until it occupies the slot `y`. `y` is a reference to a
// child link or to rt, and is rewritten by the rotations.
// The path is chosen by comparing keys (each key lives in exactly one node), so
// x must be a node inside the subtree currently rooted at y.
// For a double step the recursion first lifts x to the grandchild position of y;
// then two rotations finish a zig-zig or zig-zag. Every node on the path is
// rotated once, and each rotation recomputes siz for the nodes it moves.
// The recursion depth is proportional to x's depth below y.
void splaying(int x,int &y)
{
    if(x==y) return; // x already sits in the target slot
    // l / r stay bound to the child fields of the node that was at y on entry.
    int &l=spl[y].l,&r=spl[y].r;
    if(x==l) zig(y);        // x is the left child: one right rotation
    else if(x==r) zag(y);   // x is the right child: one left rotation
    else if(spl[x].val<spl[y].val) // x is deeper in the left subtree
    {
        // left-left: bring x to l's left slot, then rotate right twice (zig-zig)
        if(spl[x].val<spl[l].val) splaying(x,spl[l].l),zig(y),zig(y);
        // left-right: bring x to l's right slot, rotate l left, then y right (zig-zag)
        else splaying(x,spl[l].r),zag(l),zig(y);
    }
    else // x is deeper in the right subtree
    {
        // right-left: bring x to r's left slot, rotate r right, then y left (zig-zag)
        if(spl[x].val<spl[r].val) splaying(x,spl[r].l),zig(r),zag(y);
        // right-right: bring x to r's right slot, then rotate left twice (zig-zig)
        else splaying(x,spl[r].r),zag(y),zag(y);
    }
}
// Removes one copy of the key held by node `now`, which must be a node currently
// in the tree. The node is first splayed to the root.
//  - If it holds several copies, only its multiplicity (and size) shrinks.
//  - If it has no right subtree, its left subtree becomes the whole tree.
//  - Otherwise the leftmost node of the right subtree (the in-order successor)
//    is splayed to the top of that subtree, so it has no left child. It then
//    adopts the old root's left subtree and becomes the new root.
// The removed node's slot is simply left unused.
void delnode(int now)
{
    splaying(now,rt);
    if(spl[now].cnt>1)
    {
        --spl[now].cnt;
        --spl[now].siz;
        return;
    }
    if(spl[now].r==0)
    {
        rt=spl[now].l;
        return;
    }
    int succ=spl[now].r;
    while(spl[succ].l!=0) succ=spl[succ].l;
    splaying(succ,spl[now].r);
    const int oldRoot=rt;
    const int newRoot=spl[oldRoot].r;
    spl[newRoot].l=spl[oldRoot].l;
    rt=newRoot;
    update(rt);
}
// Inserts one copy of `val` into the subtree whose root is stored in slot `now`,
// then splays the affected node to the tree root `rt`.
// An existing key only has its multiplicity bumped; a new key gets a fresh leaf
// linked into the empty child slot reached by the descent.
// Ancestor sizes are not adjusted on the way down. The splay to rt rotates every
// ancestor of the node, and each rotation recomputes siz for the nodes it moves.
void ins(int &now,int val)
{
    if(now==0)
    {
        newnode(now,val);
        splaying(now,rt);
        return;
    }
    if(val<spl[now].val)
    {
        ins(spl[now].l,val);
        return;
    }
    if(val>spl[now].val)
    {
        ins(spl[now].r,val);
        return;
    }
    ++spl[now].cnt;
    ++spl[now].siz;
    splaying(now,rt);
}
// Removes one copy of `val` from the subtree whose root is stored in slot `now`.
// It walks down by key; when the node holding `val` is found, delnode() splays it
// to the root and removes one copy.
// If `val` is not present this is a no-op. Without the null check the walk would
// reach index 0 and then either recurse on spl[0]'s (zero) children forever, or
// call delnode(0) when val equals spl[0].val.
void del(int &now,int val)
{
    if(!now) return;
    if(spl[now].val==val) delnode(now);
    else if(spl[now].val>val) del(spl[now].l,val);
    else del(spl[now].r,val);
}
// Returns the key of rank `rk` (1-based, counting multiplicities) and splays its
// node to the root. If rk is outside [1, total size], the walk falls off the
// tree and spl[0].val (the null node's key) is returned without any splaying.
int getnum(int rk)
{
    int cur=rt;
    for(;cur;)
    {
        const int leftSize=spl[spl[cur].l].siz;
        if(rk<=leftSize)
        {
            cur=spl[cur].l;
            continue;
        }
        if(rk<=leftSize+spl[cur].cnt)
        {
            splaying(cur,rt);
            break;
        }
        rk-=leftSize+spl[cur].cnt;   // skip the left subtree and this node's copies
        cur=spl[cur].r;
    }
    return spl[cur].val;
}
// Returns 1 + the number of stored values (with multiplicity) strictly less than
// `val`. If a node holding `val` is found, it is splayed to the root; otherwise
// the tree shape is left unchanged.
int getrank(int val)
{
    int rank=1;
    for(int cur=rt;cur;)
    {
        const int leftSize=spl[spl[cur].l].siz;
        if(val<spl[cur].val)
        {
            cur=spl[cur].l;
        }
        else if(val>spl[cur].val)
        {
            rank+=leftSize+spl[cur].cnt;   // everything left of and at cur is smaller
            cur=spl[cur].r;
        }
        else
        {
            rank+=leftSize;
            splaying(cur,rt);
            break;
        }
    }
    return rank;
}
// Reads n operations "op x" and applies them to the splay tree:
//   1 x : insert x
//   2 x : delete one copy of x
//   3 x : print 1 + (number of stored values < x)
//   4 x : print the x-th smallest stored value
//   5 x : print the largest stored value < x  (predecessor)
//   any other op : print the smallest stored value > x  (successor)
// The successor is computed as getnum(getrank(x+1)), so x must be below INT_MAX.
int main()
{
    int n;
    read(n);
    // The loop index is a plain int: the `register` specifier is ill-formed in C++17.
    for(int i=1;i<=n;i++)
    {
        int op,x;
        read(op),read(x);
        switch(op)
        {
            case 1: ins(rt,x); break;
            case 2: del(rt,x); break;
            case 3: printf("%d\n",getrank(x)); break;
            case 4: printf("%d\n",getnum(x)); break;
            case 5: printf("%d\n",getnum(getrank(x)-1)); break;
            default: printf("%d\n",getnum(getrank(x+1))); break;
        }
    }
    return 0;
}
01-08 03:27