找到一份比较好的板子,链接https://blog.csdn.net/crazy_ac/article/details/8034190
#include<cstdio>
#include<cstdlib>
const int inf = ~0u>>;
#define L ch[x][0]
#define R ch[x][1]
#define KT (ch[ ch[rt][1] ][0])
const int maxn = ;
int lim;
struct SplayTree {
int sz[maxn];
int ch[maxn][];
int pre[maxn];
int rt,top;
inline void up(int x){
sz[x] = cnt[x] + sz[ L ] + sz[ R ];
}
inline void Rotate(int x,int f){
int y=pre[x];
ch[y][!f] = ch[x][f];
pre[ ch[x][f] ] = y;
pre[x] = pre[y];
if(pre[x]) ch[ pre[y] ][ ch[pre[y]][] == y ] =x;
ch[x][f] = y;
pre[y] = x;
up(y);
}
inline void Splay(int x,int goal){//将x旋转到goal的下面
while(pre[x] != goal){
if(pre[pre[x]] == goal) Rotate(x , ch[pre[x]][] == x);
else {
int y=pre[x],z=pre[y];
int f = (ch[z][]==y);
if(ch[y][f] == x) Rotate(x,!f),Rotate(x,f);
else Rotate(y,f),Rotate(x,f);
}
}
up(x);
if(goal==) rt=x;
}
inline void RTO(int k,int goal){//将第k位数旋转到goal的下面
int x=rt;
while(sz[ L ] != k-) {
if(k < sz[ L ]+) x=L;
else {
k-=(sz[ L ]+);
x = R;
}
}
Splay(x,goal);
}
inline void vist(int x){
if(x){
printf("结点%2d : 左儿子 %2d 右儿子 %2d val:%2d sz=%d cnt:%d\n",x,L,R,val[x],sz[x],cnt[x]);
vist(L);
vist(R);
}
}
void debug() {
puts("");
vist(rt);
puts("");
}
inline void Newnode(int &x,int c,int f){
x=++top;
L = R = ;
pre[x] = f;
sz[x]=; cnt[x]=;
val[x] = c;
}
inline void init(){
ch[][]=ch[][]=pre[]=sz[]=;
rt=top=; cnt[]=;
}
inline void Insert(int &x,int key,int f){
if(!x) {
Newnode(x,key,f);
Splay(x,);//注意插入完成后splay
return ;
}
if(key==val[x]){
cnt[x]++;
sz[x]++;
Splay(x,);//注意插入完成后splay
return ;
}else if(key<val[x]) {
Insert(L,key,x);
} else {
Insert(R,key,x);
}
up(x);
}
void Del_root(){//删除根节点
int t=rt;
if(ch[rt][]) {
rt=ch[rt][];
RTO(,);
ch[rt][]=ch[t][];
if(ch[rt][]) pre[ch[rt][]]=rt;
}
else rt=ch[rt][];
pre[rt]=;
up(rt);
}
void findpre(int x,int key,int &ans){//找前驱节点
if(!x) return ;
if(val[x] <= key){
ans=x;
findpre(R,key,ans);
} else
findpre(L,key,ans);
}
void findsucc(int x,int key,int &ans){//找后继节点
if(!x) return ;
if(val[x]>=key) {
ans=x;
findsucc(L,key,ans);
} else
findsucc(R,key,ans);
}
inline int find_kth(int x,int k){ //第k小的数
if(k<sz[L]+) {
return find_kth(L,k);
}else if(k > sz[ L ] + cnt[x] )
return find_kth(R,k-sz[L]-cnt[x]);
else{
Splay(x,);
return val[x];
}
}
int find(int x,int key){
if(!x) return ;
else if(key < val[x]) return find(L,key);
else if(key > val[x]) return find(R,key);
else return x;
}
int getmin(int x){
while(L) x=L; return val[x];
}
int getmax(int x){
while(R) x=R; return val[x];
}
//确定key的排名
int getrank(int x,int key,int cur){//cur:当前已知比要求元素(key)小的数的个数
if(key == val[x])
return sz[L] + cur + ;
else if(key < val[x])
getrank(L,key,cur);
else
getrank(R,key,cur+sz[L]+cnt[rt]);
}
int get_lt(int x,int key){//小于key的数的个数 lt:less than
if(!x) return ;
if(val[x]>=key) return get_lt(L,key);
return cnt[x]+sz[L]+get_lt(R,key);
}
int get_mt(int x,int key){//大于key的数的个数 mt:more than
if(!x) return ;
if(val[x]<=key) return get_mt(R,key) ;
return cnt[x]+sz[R]+get_mt(L,key);
}
void del(int &x,int f){//删除小于lim的所有的数所在的节点
if(!x) return ;
if(val[x]>=lim){
del(L,x);
} else {
x=R;
pre[x]=f;
if(f==) rt=x;
del(x,f);
}
if(x) up(x);
}
inline void update(){
del(rt,);
}
int get_mt(int key) {
return get_mt(rt,key);
}
int get_lt(int key) {
return get_lt(rt,key);
}
void insert(int key) {
Insert(rt,key,);
}
void Delete(int key) {
int node=find(rt,key);
Splay(node,);
cnt[rt]--;
if(!cnt[rt])Del_root();
}
int kth(int k) {
return find_kth(rt,k);
}
int cnt[maxn];
int val[maxn];
int lim;
}spt;