网上已经有了一些AVL树的C++实现,我也学习了几篇,比如我参考的这篇博客:https://www.cnblogs.com/maybe2030/p/4732377.html 。这篇文章对AVL思想的描述没什么问题,但实现中存在一些bug,主要出在树高度的维护和旋转部分。我读了这位作者的实现,然后自己重新实现了一份;在验证过程中,我的代码通过了测试,而参考博客的代码没有通过。测试方法是:把0到len-1共len个整数随机打乱(main.cpp中len = 102400,即0~102399),逐个插入AVL树,每次插入后都调用isAVL()判断树是否仍然平衡;插入结束后再逐个删除,每次删除后同样调用isAVL()判断删除是否导致树失衡。这份代码旨在理清算法,销毁树的内存回收函数和查找函数我都没有写。欢迎大家测试代码的正确性,如果发现bug请留言,我会继续调试(我认为有bug的概率很低,毕竟这个测试还是相当严格的)。
//avltree.h
#ifndef AVLTREE_H_INCLUDED
#define AVLTREE_H_INCLUDED
#include <vector>
using std::vector;
using std::max;
// Node of the AVL tree. `height` caches the height of the subtree rooted at
// this node: a node built from a key (a leaf) starts at height 1.
template<class T>
class BNode
{
public:
    BNode *left,*right;   // child subtrees; NULL when absent
    int height;           // cached height of the subtree rooted here
    T val;                // key stored in this node
    // Empty node: no children, height 0, value-initialized key. val()
    // value-initializes, so this works for any default-constructible T
    // (0 for arithmetic types, empty for std::string, ...).
    BNode():left(NULL),right(NULL),height(0),val() {}
    // Leaf node holding a copy of _val, with height 1.
    BNode(const T &_val):left(NULL),right(NULL),height(1),val(_val) {}
};
template<class T>
class AVLTree
{
private:
BNode<T> *root;
void traversal(BNode<T> *rt,vector<T>&vec);
BNode<T>* AVL_insert(BNode<T> *rt,T x);
BNode<T>* AVL_delete(BNode<T> *rt,T x);
int get_height(BNode<T> *rt);
BNode<T>* rotate_LL(BNode<T> *rt);
BNode<T>* rotate_LR(BNode<T> *rt);
BNode<T>* rotate_RL(BNode<T> *rt);
BNode<T>* rotate_RR(BNode<T> *rt);
// just for testing
bool isAVL(BNode<T> *rt);
public:
AVLTree<T>():root(NULL) {}
//此处应该实现销毁内存操作
~AVLTree<T>() {}
vector<T> traversal();
void insert(T x);
void erase(T x);
bool isAVL();
};
template<class T>
bool AVLTree<T>::isAVL(BNode<T> *rt)
{
if(!rt) return true;
int left_height = rt->left?rt->left->height:0;
int right_height = rt->right?rt->right->height:0;
if(abs(left_height - right_height) <= 1 && isAVL(rt->left) && isAVL(rt->right)) return true;
return false;
}
// Public test hook: validates the entire tree, starting from the root.
template<class T>
bool AVLTree<T>::isAVL()
{
    BNode<T> *start = root;
    const bool valid = isAVL(start);
    return valid;
}
// Single right rotation for a left-left imbalance at rt: the left child is
// lifted above rt, and its former right subtree becomes rt's left subtree.
// Returns the new root of this subtree.
template<class T>
BNode<T>* AVLTree<T>::rotate_LL(BNode<T> *rt)
{
    BNode<T> *pivot = rt->left;
    BNode<T> *moved = pivot->right;
    pivot->right = rt;
    rt->left = moved;
    // rt now sits below pivot, so its height must be refreshed first.
    rt->height = 1 + max(get_height(rt->left), get_height(rt->right));
    pivot->height = 1 + max(get_height(pivot->left), get_height(pivot->right));
    return pivot;
}
// Single left rotation for a right-right imbalance at rt: the right child is
// lifted above rt, and its former left subtree becomes rt's right subtree.
// Returns the new root of this subtree.
template<class T>
BNode<T>* AVLTree<T>::rotate_RR(BNode<T> *rt)
{
    BNode<T> *pivot = rt->right;
    BNode<T> *moved = pivot->left;
    pivot->left = rt;
    rt->right = moved;
    // rt now sits below pivot, so its height must be refreshed first.
    rt->height = 1 + max(get_height(rt->left), get_height(rt->right));
    pivot->height = 1 + max(get_height(pivot->left), get_height(pivot->right));
    return pivot;
}
// Left-right case: first rotate the left child leftwards (rotate_RR) so the
// imbalance becomes left-left, then rotate rt rightwards (rotate_LL).
// Returns the new root of this subtree.
template<class T>
BNode<T>* AVLTree<T>::rotate_LR(BNode<T> *rt)
{
    BNode<T> *straightened = rotate_RR(rt->left);
    rt->left = straightened;
    return rotate_LL(rt);
}
// Right-left case: first rotate the right child rightwards (rotate_LL) so the
// imbalance becomes right-right, then rotate rt leftwards (rotate_RR).
// Returns the new root of this subtree.
template<class T>
BNode<T>* AVLTree<T>::rotate_RL(BNode<T> *rt)
{
    BNode<T> *straightened = rotate_LL(rt->right);
    rt->right = straightened;
    return rotate_RR(rt);
}
// Height of the subtree rooted at rt: its cached value, or 0 when empty.
template<class T>
int AVLTree<T>::get_height(BNode<T> *rt)
{
    return rt ? rt->height : 0;
}
// Inserts x into the tree. AVL_insert may rotate at the top, so the root
// pointer is replaced with the subtree root it returns.
template<class T>
void AVLTree<T>::insert(T x)
{
    BNode<T> *new_root = AVL_insert(root,x);
    root = new_root;
}
// Removes x from the tree. AVL_delete may rotate at the top, so the root
// pointer is replaced with the subtree root it returns.
template<class T>
void AVLTree<T>::erase(T x)
{
    BNode<T> *new_root = AVL_delete(root,x);
    root = new_root;
}
// Inserts x into the subtree rooted at rt and returns the (possibly new) root
// of that subtree; the caller stores the returned pointer. A key equal to an
// existing one is ignored. After the recursive insert, the cached height is
// refreshed. If the insertion side became two levels taller, the matching
// single or double rotation restores balance.
template<class T>
BNode<T> *AVLTree<T>::AVL_insert(BNode<T> *rt,T x)
{
    // Empty position: the key becomes a new leaf (height 1). No write to
    // `root` here -- insert() assigns the returned pointer itself.
    if(!rt) return new BNode<T>(x);
    if(x == rt->val) return rt;
    if(x < rt->val)
    {
        rt->left = AVL_insert(rt->left,x);
        const int left_height = get_height(rt->left);
        const int right_height = get_height(rt->right);
        rt->height = max(left_height,right_height) + 1;
        if(left_height - right_height == 2)
        {
            // x went into the left child's left subtree -> single rotation;
            // into its right subtree -> double rotation.
            if(x < rt->left->val) rt = rotate_LL(rt);
            else rt = rotate_LR(rt);
        }
    }
    else
    {
        // Same recursion as the left side; the base case above creates the
        // leaf when the right child is empty.
        rt->right = AVL_insert(rt->right,x);
        const int left_height = get_height(rt->left);
        const int right_height = get_height(rt->right);
        rt->height = max(left_height,right_height) + 1;
        if(right_height - left_height == 2)
        {
            // x went into the right child's left subtree -> double rotation;
            // into its right subtree -> single rotation.
            if(x < rt->right->val) rt = rotate_RL(rt);
            else rt = rotate_RR(rt);
        }
    }
    return rt;
}
// Removes x from the subtree rooted at rt (if present) and returns the
// (possibly new) root of that subtree. Nodes on the search path get their
// cached heights refreshed. When a node ends up two levels heavier on one
// side, a single rotation is applied if the heavy child leans the same way
// or is even; otherwise a double rotation is applied.
template<class T>
BNode<T>* AVLTree<T>::AVL_delete(BNode<T> *rt,T x)
{
    if(rt == NULL) return NULL;

    if(x < rt->val)
    {
        // Removing from the left can only leave the right side heavier.
        rt->left = AVL_delete(rt->left,x);
        const int lh = get_height(rt->left);
        const int rh = get_height(rt->right);
        rt->height = max(lh,rh) + 1;
        if(rh - lh == 2)
        {
            BNode<T> *heavy = rt->right;
            rt = (get_height(heavy->left) <= get_height(heavy->right))
                 ? rotate_RR(rt) : rotate_RL(rt);
        }
        return rt;
    }

    if(!(x == rt->val))
    {
        // x is greater: removing from the right can only leave the left
        // side heavier.
        rt->right = AVL_delete(rt->right,x);
        const int lh = get_height(rt->left);
        const int rh = get_height(rt->right);
        rt->height = max(lh,rh) + 1;
        if(lh - rh == 2)
        {
            BNode<T> *heavy = rt->left;
            rt = (get_height(heavy->right) <= get_height(heavy->left))
                 ? rotate_LL(rt) : rotate_LR(rt);
        }
        return rt;
    }

    // rt holds x. With at most one child, that child (or NULL) replaces rt.
    if(rt->left == NULL || rt->right == NULL)
    {
        BNode<T> *child = rt->left ? rt->left : rt->right;
        delete rt;
        return child;
    }

    // Two children: copy in the in-order successor (leftmost key of the right
    // subtree), then delete that successor from the right subtree.
    BNode<T> *succ = rt->right;
    while(succ->left) succ = succ->left;
    rt->val = succ->val;
    rt->right = AVL_delete(rt->right,succ->val);
    const int lh = get_height(rt->left);
    const int rh = get_height(rt->right);
    rt->height = max(lh,rh) + 1;
    if(lh - rh == 2)
    {
        BNode<T> *heavy = rt->left;
        rt = (get_height(heavy->right) <= get_height(heavy->left))
             ? rotate_LL(rt) : rotate_LR(rt);
    }
    return rt;
}
// Returns the keys of the whole tree in in-order (left, node, right) sequence.
template<class T>
vector<T> AVLTree<T>::traversal()
{
    vector<T> keys;
    traversal(root,keys);
    return keys;
}
// Appends the keys of the subtree rooted at rt to vec in in-order
// (left, node, right) sequence. Iterative: an explicit stack holds the
// ancestors whose keys are still pending.
template<class T>
void AVLTree<T>::traversal(BNode<T> *rt,vector<T> &vec)
{
    vector<BNode<T>*> pending;
    BNode<T> *cur = rt;
    while(cur || !pending.empty())
    {
        // Descend as far left as possible, remembering the path.
        while(cur)
        {
            pending.push_back(cur);
            cur = cur->left;
        }
        cur = pending.back();
        pending.pop_back();
        vec.push_back(cur->val);
        cur = cur->right;
    }
}
#endif // AVLTREE_H_INCLUDED
//main.cpp
#include <algorithm>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>
#include "avltree.h"
using namespace std;
int main()
{
//获取乱序的数组a
int len = 102400;
int* a = new int[len];
srand(time(0));
for(int i = 0; i < len; i++) a[i] = i;
for(int i = 0; i < len; i++)
{
int r = rand()%len;
swap(a[r],a[i]);
}
AVLTree<int> avl;
//逐个插入,每插入一次检验一次算法正确性
for(int i = 0; i < len; i++)
{
avl.insert(a[i]);
if(!avl.isAVL())
{
cout << "not AVL tree after insert" << endl;
}
}
//vector<int> vec = avl.traversal();
//for(int i = 0; i < vec.size(); i++) cout << vec[i] << " ";
cout << endl;
//删除删除,每删除一次验证一次算法正确性
for(int i = 0; i < len; i++)
{
avl.erase(a[i]);
if(!avl.isAVL())
{
cout << "not AVL tree after delete" << endl;
}
}
delete[] a;
return 0;
}