主席树笔记
- By BigYellowDog
- 前置芝士:线段树、前缀和、最好还有平衡树。
导入
- 主席树是什么?它可以干嘛?为什么要用它?
- 主席树这个名字跟其功能没有关系。名字来源据说是hjt大牛发明的,根据其网名命名的。
- 主席树真正的名字叫可持久化线段树。
- 假设现在有这么一个问题:现有一个序列,q次询问。每次问区间[l, r]里的最大值。
- 很简单啊,线段树 / st表就好。
- 那改一下,每次询问区间[l, r]里的次大值。
- 很简单啊,线段树就好,只不过多维护一个值而已。
- 那再改一下,每次询问区间[l, r]里的第k大值。
- ... ...
- emmm反正我是写不出来了。
那么主席树就是解决求区间第k大此类问题的。
当然,区间第k大还可以用整体二分。但主席树易理解,码量也不大。所以是一种性价比很高的算法。
原理
- 现在有一个序列3 5 4 1 2,请找出区间[1, 4]的第2大。
ok现在来模拟主席树算法过程。
建出一个空树(树0),如下图。
棕色数字表示节点的值。
- (1~5)这个节点表示1 <= 值 <= 5的个数
- (1~3)这个节点表示1 <= 值 <= 3的个数
... ...注意这个一颗权值线段树。点的左右端点表示权值的边界。
- 插入序列中第一个数3,图就变成了(树1)
- 插入序列中第二个数5,图就变成了(树2)
- 一直插完序列中第4个数1后,图就变成了(树4)
- 到这里打住,我们已经插入了前4个数,那么就可以查询区间[1, 4]的第k大了。
- 首先进入(1~5)节点,发现其左儿子的子树个数2 <= k(2),于是进入(1~3)节点
- 发现(1~3)节点的左儿子子树个数1 < k(2),于是进入(3~3)节点,k更新为2 - 1 = 1
- 走到头了,于是返回节点的左端点3。于是区间[1, 4]的第2小就是3
- 为什么返回左端点?因为这是一颗权值线段树。
- 上述在权值线段树里找第k大的过程就是平衡树找第k大的过程
- 那现在,请找出区间[2, 4]里的第2大值。
- 很容易啊,我们取出刚刚建出的树4和树1。拿树4 减 树1即可得到一个新树。
- 减,即拿每个对应节点相减。
- 在这个新树上找第2大即可。
- 这就是一个前缀和啊。
- 那利用主席树解决区间第k大的方法就出来了。
- 依次插入序列中的数,每插入一个树建立一个新树。
- 若查询区间[l, r],则拿出树r和树(l - 1)相减得到新树,在新树上找第k大。
实现
- 原理就是这样啦,但是如果每插入一个树就建立一颗新树,那岂不是空间爆炸?
- 是的,原理很简单,关键就是实现,这也是发明者的精妙之处。
- 首先可以发现一个性质,每插入一个数,相对于上一棵树来说,只会有从根节点到叶节点的一条链上的点的值发生了变化。其它都是不变的!
- 如图,这是上述序列插入3后的图,当插入4后只有蓝色路径上的节点的值会++
- 那我们每建立一颗新树就不用重新建了,而是值改变的点就重建,没改变的连到上一个树上去即可。那么上述的图可以变成
- 带 ' 的表示是插入4后的,不带 ' 的是插入3后的树。
- 这样建树的空间开销就大大减少了。
代码
- 搬来一道模板题
cin >> n >> m;
for(int i = 1; i <= n; i++)
a[i] = read(), b[++cnt] = a[i];
sort(b + 1, b + 1 + cnt);
cnt = unique(b + 1, b + 1 + cnt) - b - 1;
/**
* 首先输入序列中的每一个数,因为权值很大,我们又要按照权值建树,那么就先离散化下
* a是原序列,b是离散化后的数组,cnt是离散化的不同权值个数
*/
r[0] = build(1, cnt); //建一个空树, r[0]表示第0棵树的根节点
int build(int l, int r) //建树函数,各位都懂
{
int p = ++dex, mid = l + r >> 1;
if(l == r) return p;
t[p].l = build(l, mid);
t[p].r = build(mid + 1, r);
return p;
}
for(int i = 1; i <= n; i++) //这里就是每插入一个数建一个树的过程了
r[i] = upd(r[i - 1], 1, cnt, find(a[i]));
int upd(int las, int l, int r, int val)
{
int p = ++dex, mid = l + r >> 1;
t[p].l = t[las].l, t[p].r = t[las].r; //首先都给它连上上一棵树的节点
t[p].sum = t[las].sum + 1;
if(l == r) return p;
if(val <= mid) t[p].l = upd(t[las].l, l, mid, val); //说明左子树发生了改变
else t[p].r = upd(t[las].r, mid + 1, r, val); //说明右子树发生了改变
return p;
}
for(int i = 1; i <= m; i++)
{
int ll = read(), rr = read(), rank = read();
printf("%d\n", b[ask(r[ll - 1], r[rr], 1, cnt, rank)]);
}
int ask(int u, int v, int l, int r, int rank) //取出了树u和树v
{
if(l == r) return l;
int size = t[t[v].l].sum - t[t[u].l].sum, mid = l + r >> 1; //得到新树的左儿子子树个数
if(rank <= size) return ask(t[u].l, t[v].l, l, mid, rank);
else return ask(t[u].r, t[v].r, mid + 1, r, rank - size);
}
- 最后是完整代码:链接