平衡树SPLAY

平衡树这个东西,非常有用,而且锻炼码力,是个非常好的模板题!

那么SPLAY这个东西 十分的毒瘤 我只调了一上午就调出来了!(我真棒)

首先 我们要知道平衡树是一棵二叉查找树 他可以处理:加点,删点,前驱,后继,num的排名,某排名的num,合并两棵平衡树,分离两棵平衡树

关于其原理 网上有很多 而且比较简单

对于splay来说 原理并不是最难的 码代码才是最难的!

那么就提一提细节吧:

当这一棵树为空的时候,\(tot\)不一定=0,而\(root=0\)

当这一棵树为空的时候,如果新加入一个节点,那么没有必要对于这个节点进行splay操作,(反正我的代码会死循环QAQ)

在rotate完了时候 记得update呀

对于该有返回值的函数 一定要返回值!(可怜孩子调了半天)

还有就是在:\(insert(build),find,dele,oprank\)函数里面要\(splay\)

注意:在我的函数\(rank\)中 我只能找到在平衡树中有的\(val\)的排名 而不存在的\(val\)这个函数是查不到的

注释掉的代码也可以查询rank 但是 好丑啊!!

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
const int maxn=150005;
const int INF=2147480000;
int n,op,num;
inline int read()
{
    int x=0,f=1;
    char c=getchar();
    while(c>57||c<48)
    {
        if(c=='-') f=0;
         c=getchar();
    }
    while(c>47&&c<58)
    x=(x<<3)+(x<<1)+(c&15),c=getchar();
    return f?x:~x+1;
}
struct wow{
    int chi[2],size,v,fa,times;
}node[maxn];
int tot,tot_size;
#define root node[0].chi[1]
int identify(int x){
    return (x==node[node[x].fa].chi[0])?0:1;
}
void connect(int x,int father,int son){
    node[x].fa=father;
    node[father].chi[son]=x;
}
void update(int x){
    node[x].size=node[node[x].chi[0]].size+node[node[x].chi[1]].size+node[x].times;
}
void rotate(int x){
    int y=node[x].fa;
    int mroot=node[y].fa;
    int mrootson=identify(y);
    int yson=identify(x);
    int b=node[x].chi[yson^1];
    connect(b,y,yson);connect(y,x,yson^1);connect(x,mroot,mrootson);
    update(y);update(x);
}
void splay(int at,int to){
    to=node[to].fa;
    while(node[at].fa!=to){
        int up=node[at].fa;
        if(node[up].fa==to) rotate(at);
        else if(identify(at)==identify(up)){
            rotate(up);rotate(at);
         }
        else{
            rotate(at);rotate(at);
        }
    }
}

int create_point(int v,int father){
    tot++;
    node[tot].fa=father;node[tot].v=v;
    node[tot].size=1;node[tot].times=1;
    return tot;
}
void insert(int v){
    tot_size++;
    if(root==0){
        root=1;create_point(v,0);
    }
    else{
        int now=root;
        while(1){
            node[now].size++;
            if(v==node[now].v){
                node[now].times++;
                splay(now,root);
                return ;
            }
            int nxt=(v<node[now].v)?0:1;
            if(!node[now].chi[nxt]){
                int p=create_point(v,now);node[now].chi[nxt]=p;
                splay(p,root);
                return ;
            }
            now=node[now].chi[nxt];
         }
    }
}
int find(int v){
    int now=root;
    while(1){
        if(node[now].v==v){
            splay(now,root);return now;
        }
        int nxt=(v<node[now].v)?0:1;
        if(!node[now].chi[nxt]) return 0;
        now=node[now].chi[nxt];
    }
}
void dele(int v){
    int pos=find(v);
    if(!pos) return ;
    tot_size--;
    if(node[pos].times>1){
        node[pos].times--;node[pos].size--;
    }
    else{
        if(!node[pos].chi[0] && !node[pos].chi[1])  root=0;
        else if(!node[pos].chi[0]){
            root=node[pos].chi[1];
            node[root].fa=0;
        }
        else{
            int left=node[pos].chi[0];
            while(node[left].chi[1]) left=node[left].chi[1];
            splay(left,node[pos].chi[0]);
            int right=node[pos].chi[1];
            connect(right,left,1);connect(left,0,1);
            update(left);
        }
    }
    return ;
}
int rank(int v){
    /*
    int ans=0,now=root;
    while(1){
        if(node[now].v==v){
            splay(now,root);
            return ans+node[node[now].chi[0]].size+1;
        }
        if(!now)    return 0;
        if(v<node[now].v) now=node[now].chi[0];
        else{
            ans+=node[node[now].chi[0]].size+node[now].times;
            now=node[now].chi[1];
        }
    }*/
    int pos=find(v);
    return node[node[pos].chi[0]].size+1;
}
int oprank(int x){
    int sum=0,now=root;
    if(x>tot_size)  return 0;
    while(1){
        int mleftsum=node[now].size-node[node[now].chi[1]].size;
        if(x>node[node[now].chi[0]].size && x<=mleftsum) break;
        if(x<mleftsum)  now=node[now].chi[0];
        else{
            x-=mleftsum;
            now=node[now].chi[1];
        }
    }
    splay(now,root);
    return node[now].v;
}
int upper(int v){
    int now=root,ans=INF;
    while(now){
        if(node[now].v>v && node[now].v<ans) ans=node[now].v;
        if(v<node[now].v) now=node[now].chi[0];
        else now=node[now].chi[1];
    }
    return ans;
}
int lower(int v){
    int now=root,ans=-INF;
    while(now){
        if(node[now].v<v && node[now].v>ans) ans=node[now].v;
        if(v>node[now].v) now=node[now].chi[1];
        else now=node[now].chi[0];
    }
    return ans;
}
int main(){
    n=read();
    for(int i=1;i<=n;i++){
         op=read();num=read();
         if(op==1)
            insert(num);
         else if(op==2)
            dele(num);
         else if(op==3)
            printf("%d\n",rank(num));
         else if(op==4)
            printf("%d\n",oprank(num));
         else if(op==5)
            printf("%d\n",lower(num));
         else
            printf("%d\n",upper(num));
    }
}
12-28 15:59