写小作业的时候重新复习了一下 Splay 树（伸展树）。

只支持插入、删除、查询第 k 小的元素（kth(1) 返回最小值）和查询节点数；不提供迭代器。重复元素只保存一份。

类型 T 需要支持 == 和 < 运算符，并且必须可拷贝构造（插入时会拷贝一份元素存入节点）。

/// Splay tree (self-adjusting binary search tree) holding unique values of type T.
///
/// Operations: insert, remove, kth (k-th smallest, 1-based) and size. Each
/// operation splays the deepest node it touches to the root, which gives the
/// standard amortized O(log n) cost per operation. No iterators are provided.
///
/// Requirements on T: `operator==`, `operator<`, and copy construction
/// (insert() stores a copy of its argument). Inserting a value that compares
/// equal to a stored one is a no-op.
///
/// The tree owns its nodes. Copying is disabled because a shallow copy of the
/// root pointer would make two trees free the same nodes.
template<class T>
class Splay {
private:
    struct node {
        T v;
        node *ch[2];  // ch[0] = left child, ch[1] = right child
        node *fa;     // parent; nullptr for the root
        int size;     // number of nodes in the subtree rooted here

        // Initializers are listed in declaration order (avoids -Wreorder).
        explicit node(const T &a) : v(a), ch{nullptr, nullptr}, fa(nullptr), size(1) {}

        // Make r the c-th child of this node; r may be nullptr.
        void setc(node *r, int c) {
            ch[c] = r;
            if (r != nullptr) r->fa = this;
        }

        // Side of the parent this node hangs on: 1 = right, 0 = left or no parent.
        int pl() const {
            return fa != nullptr && fa->ch[1] == this;
        }

        // Recompute size from the children's (already correct) sizes.
        void count() {
            size = 1;
            if (ch[0] != nullptr) size += ch[0]->size;
            if (ch[1] != nullptr) size += ch[1]->size;
        }
    };

    node *root;

    // Delete every node of the subtree rooted at r without recursion:
    // right-rotate while the current node has a left child, otherwise delete it
    // and continue with its right child. A recursive walk would overflow the
    // stack on the long chains a splay tree can form (e.g. after sorted inserts).
    static void release(node *r) {
        while (r != nullptr) {
            if (node *l = r->ch[0]) {
                r->ch[0] = l->ch[1];
                l->ch[1] = r;
                r = l;
            } else {
                node *next = r->ch[1];
                delete r;
                r = next;
            }
        }
    }

    // Rotate r above its parent, preserving in-order sequence. Only the parent's
    // size is recomputed here; r's own size is fixed by splay() at the end.
    void rotate(node *r) {
        node *f = r->fa;
        int c = r->pl();
        if (f == root) {
            r->fa = nullptr;
            root = r;
        } else {
            f->fa->setc(r, f->pl());
        }
        f->setc(r->ch[c ^ 1], c);  // r's inner subtree moves under f
        r->setc(f, c ^ 1);
        f->count();
    }

    // Splay r upward until its parent is tar (tar == nullptr: r becomes root).
    // tar must be an ancestor of r, or nullptr.
    void splay(node *r, node *tar = nullptr) {
        while (r->fa != tar) {
            node *f = r->fa;
            if (f->fa != tar) {
                // zig-zig rotates the parent first; zig-zag rotates r twice
                rotate(f->pl() == r->pl() ? f : r);
            }
            rotate(r);
        }
        r->count();
    }

    // Unlink and delete r, which must be a node of this tree.
    void rm(node *r);

public:
    Splay() : root(nullptr) {}
    ~Splay() { release(root); }

    Splay(const Splay &) = delete;
    Splay &operator=(const Splay &) = delete;

    /// Number of stored elements.
    int size() const;
    /// Remove the element equal to the argument, if present.
    void remove(const T &);
    /// Insert a copy of the argument unless an equal element is already stored.
    void insert(const T &);
    /// Pointer to the k-th smallest element (1-based), or nullptr when k is
    /// outside [1, size()]. The pointer stays valid until that element is
    /// removed or the tree is destroyed; the value must not be modified in a
    /// way that changes its ordering.
    T *kth(int);
};

template<class T>
int Splay<T>::size() const {
    return root == nullptr ? 0 : root->size;
}

template<class T>
void Splay<T>::rm(node *r) {
    splay(r);  // r is now the root
    node *left = r->ch[0];
    node *right = r->ch[1];
    if (left == nullptr) {
        root = right;
        if (right != nullptr) right->fa = nullptr;
    } else {
        // Bring the maximum of the left subtree directly under r; being the
        // maximum, it has no right child, so r's right subtree can hang there.
        node *h = left;
        while (h->ch[1] != nullptr) h = h->ch[1];
        splay(h, r);
        h->fa = nullptr;
        root = h;
        h->setc(right, 1);
        h->count();
    }
    delete r;  // also covers the single-node case the old code leaked
}

template<class T>
void Splay<T>::remove(const T &a) {
    node *r = root;
    node *last = nullptr;
    while (r != nullptr) {
        if (r->v == a) {
            rm(r);
            return;
        }
        last = r;
        r = r->ch[a < r->v ? 0 : 1];
    }
    // Not found: splay the last node on the search path so the walk is paid for.
    if (last != nullptr) splay(last);
}

template<class T>
void Splay<T>::insert(const T &a) {
    if (root == nullptr) {
        root = new node(a);
        return;
    }
    node *r = root;
    for (;;) {
        if (r->v == a) {  // already stored: no duplicates
            splay(r);
            return;
        }
        int c = a < r->v ? 0 : 1;
        if (r->ch[c] == nullptr) {
            node *n = new node(a);  // allocate only when actually inserting
            r->setc(n, c);
            splay(n);  // also recomputes the sizes along the insertion path
            return;
        }
        r = r->ch[c];
    }
}

template<class T>
T *Splay<T>::kth(int k) {
    if (k < 1 || k > size()) return nullptr;
    node *r = root;
    while (r != nullptr) {
        int l_size = r->ch[0] != nullptr ? r->ch[0]->size : 0;
        if (k <= l_size) {
            r = r->ch[0];
        } else if (k == l_size + 1) {
            splay(r);  // keep the amortized bound for repeated queries
            return &r->v;
        } else {
            k -= l_size + 1;
            r = r->ch[1];
        }
    }
    return nullptr;
}
05-11 15:34