树状数组(BIT)—— 一篇就够了
前言、内容梗概
本文旨在讲解:
- 树状数组的原理(起源,原理,模板代码与需要注意的一些知识点)
- 树状数组的优势,缺点,与比较(eg:线段树)
- 树状数组的经典例题及其技巧(普通离散化,二分查找离散化)
Fenwick便开始思考树状数组的经典例题及其技巧
模板题:单点修改,区间查询
思路:
非常简单,只需要套模板即可。
代码:
// 上述模板部分省略
using ll = long long;
const int maxn = 1e6+50;
ll f[maxn];
int main(){
ios::sync_with_stdio(0);
cin.tie(0);
int n; cin >> n;
int q; cin >> q;
for (int i = 1; i <= n; ++ i) cin >> f[i];
BIT<ll> bit(f, n);
for (int i = 0; i < q; ++ i){
int type; cin >> type;
if (type == 1){
int i, x;
cin >> i >> x;
bit.add(i, (ll) x);
}else {
int l, r;
cin >> l >> r;
cout << bit.sum(l, r) << '\n';
}
}
return 0;
}
模板题:区间修改,区间查询
思路:
该模板题则难上许多,需要对问题分析建模。
我们需要考虑如何建模表示 \(tree\) 数组。
首先,设更新操作为:在 \([l, r]\) 上增加 \(x\)。我们考虑如何建模维护新的区间前缀和 \(c^{\prime}[i]\)。
下面分情况讨论:
- \(i < l\)
这种情况下,不需要任何处理, \(c^{\prime}[i] = c[i]\)
- \(l <= i <= r\)
这种情况下,\(c^{\prime}[i] = c[i] + (i - l + 1) \cdot x\)
- \(i > r\)
这种情况下,\(c^{\prime}[i]=c[i] + (r-l+1)\cdot x\)
因此如下图所示,我们可以设两个 BIT,那么\(c^{\prime}[i] = \mathrm{sum(bit_1,i)+sum(bit_2,i) \cdot i}\),对于区间修改等价于:
- 在 \(bit_1\) 的 \(l\) 位置加上 \(-x(l-1)\),在 \(bit_1\) 的 \(r\) 位置加上 \(rx\)。
- 在 \(bit_2\) 的 \(l\) 位置加上 \(x\) 的 \(r\) 位置加上 \(-x\)。
代码
#include <bits/stdc++.h>
using namespace std;
// 模板代码省略
// 这里做的是单点查询,但是实现的为区间查询
using ll = long long;
ll get_sum(BIT<ll> &a, BIT<ll> &b, int l, int r){
auto sum1 = a.sum(r) * r + b.sum(r);
auto sum2 = a.sum(l - 1) * (l - 1) + b.sum(l - 1);
return sum1 - sum2;
}
int n, q;
const int maxn = 1e6 + 50;
ll f[maxn];
int main(){
// ios::sync_with_stdio(0);
// cin.tie(0);
cin >> n >> q;
BIT<ll> bit1, bit2;
for (int i = 1; i <= n; ++ i) cin >> f[i];
bit1.init(n), bit2.init(f, n);
for (int i = 0; i < q; ++ i){
int type; cin >> type;
if (type == 1){
int l, r, x;
cin >> l >> r >> x;
bit2.add(l, (ll) -1 * (l - 1) * x), bit2.add(r + 1, (ll) r * x);
bit1.add(l, (ll) x), bit1.add(r + 1, (ll) -1 * x);
}else {
int i; cin >> i;
cout << get_sum(bit1, bit2, i, i) << '\n';
}
}
return 0;
}
逆序对 简单版
思路
BIT 求解逆序对是非常方便的,在初学时我没有想到过 BIT 还能用于求解逆序对。在这里我借逆序对来引出一个小技巧:离散化
BIT 求逆序对的方法非常简单,逆序对指:i < j and a[i] > a[j]
,统计逆序对实际上就是统计在该元素 a[i]
之前有多少元素大于他。
我们可以初始化一个大小为 \(maxn\) 的空 BIT(全为0)。随后:
- 我们顺序访问数组中的每个元素
a[i]
,计算区间[1, a[i]]
的和,更新答案ans = i - sum([1, a[i]])
- 然后,我们更新 BIT 中坐标
a[i]
的值,tree[a[i]] <- tree[a[i]] + 1
举个例子:
eg: [2,1,3,4]
BIT: 0, 0, 0, 0
>2, sum(2) = 0, ans += 0 - sum(2) -> ans = 0
BIT: 0, 1, 0, 0
>1, sum(1) = 0, ans += 1 - sum(1) -> ans = 1
BIT: 1, 1, 0, 0
>3, sum(3) = 2, ans += 2 - sum(3) -> ans = 1
BIT: 1, 1, 1, 0
>4, sum(4) = 3, ans += 3 - sum(4) -> ans = 1
实际上,便是借助 BIT 高效计算前缀和的性质实现了快速打标记,先统计在我之前有多少个标记(这些都是合法对),再将自己所在位置的标记加 \(1\)。
因此,很容易写出这段代码:
代码一
// 仅保留核心代码
int reversePairs(vector<int>& nums) {
int n = nums.size();
if (n == 0) return 0;
int mx = *max_element(nums.begin(), nums.end());
BIT<int> bit(mx); // 因为最大只到最大值的位置
int ans(0);
for (int i = 0; i < n; ++ i){
ans += (i - bit.sum(nums[i]));
bit.add(nums[i], 1);
}
return ans;
}
但是这个代码有非常严重的问题,首先假如 mx = 1e9
就会出现段错误;或者假如 nums[i] < 0
则会出现访问越界的问题,但是实际上题目中说明了:数组最多只有 50000个元素,也就是我们需要想办法将坐标离散化,保留其大小顺序即可。
代码二
#define lb lower_bound
#define all(x) x.begin(), x.end()
const int maxn = 5e4 + 50;
struct node{
int v, id;
}f[maxn]; // 离散化结构体
int arr[maxn];
bool cmp(const node&a, const node &b){
return a.v < b.v;
}
class Solution {
public:
int reversePairs(vector<int>& nums) {
int n = nums.size();
if (n == 0) return 0;
BIT<int> bit(n);
for (int i = 1; i <= n; ++ i){
f[i].v = nums[i - 1], f[i].id = i; // 赋值用于排序
}
sort(f + 1, f + 1 + n, cmp);
int cnt = 1, i = 1;
while (i <= n){
/* 用于去重,当有相同元素时其对应的 cnt 应该相同 */
if (f[i].v == f[i - 1].v || i == 1) arr[f[i].id] = cnt;
else arr[f[i].id] = ++cnt;
++ i;
}
int ans = 0;
for (int i = 0; i < n; ++ i){
int pos = arr[i + 1];
ans += i - bit.sum(pos);
bit.add(pos, 1);
}
return ans;
}
};
上面的方法是离散化操作的一种方式,有一点复杂,需要注意的细节比较多。
实际上,该方法便是通过保留每个元素的所在位置,并将其排序,排序后自己在第 \(i\) 个则将其值 arr[id] = i
离散化为 \(i\) 。这样既可以避免负数,过大的数造成的访问或者内存错误,也充分的保留了各元素之间的大小关系。
离散化的复杂度为 \(\mathcal{O(\log n)}\) ,实际上也就是排序的复杂度。
可以发现,结构体方法对于空间要求较大,且在去重方面需要下功夫,稍后我们会讲解另一种离散化方法,你也可以试试用后文的离散化方法再次解决这题。
逆序对加强版: 翻转对
思路
可以看到这题与逆序对的区别在于,翻转对的定义是:i < j
且 a[i] > 2*a[j]
。其大小关系发生了变化,不再是原来单纯的大小关系,而存在值的变化。
我们可以思考下能否用结构体进行离散化,简单思考后发现:假如第 i
个元素离散化之后的编号为 id1
,则我们无法确定编号为 2 * id1
所对应元素的 val
值之间的关系。可能出现如下情况:
id1 = 1, val = 2
2 * id1 = 2, val' = 3
所以,我们需要思考一个新的方法来进行离散化。需要注意的是,我们的关键点在于:如何快速的询问一个元素在一个数组中是第几大的元素。比如,在数组中快速询问某个值的两倍是第几大的。
实际上,稍微有基础的话答案便非常清晰:二分查找,我们可以首先将数组进行排序,利用 \(lower_{bound}\) 快速找到第一个大于等于该元素所对应的位置,用代码来说的话:pos = lower_bound(nums.begin(), nums.end(), x) - nums.begin() + 1
。
eg: nums = [3, 2, 4, 7]
farr = sort(nums) -> farr = [2, 3, 4, 7]
pos(4) = lower_bound(..., 4) - farr.begin() + 1 = 3
便可以快速找到 4 的编号为 3 (1-index)
但是,有一个问题需要注意:
eg: nums = [3, 2, 5, 7]
farr = sort(nums) -> farr = [2, 3, 5, 7]
pos(4) = lower_bound(...,4) - farr.beign() + 1 = 3
但实际上,5 > 4,这次询问错误了!!!
为什么会出现询问错误的情况呢?(因此我们需要找到的是最后一个小于等于元素 x
的对应位置,而二分查找是大于等于 x
的第一个元素,当原数组中不存在 x
时,便会出现询问出错的情况。)
有多种方法可以解决这个问题,但是最为方便的还是直接将需要查询的元素全部加进去,也就是 2 * x
全部添加到数组中,从而保证一定存在该元素,又因为 lower_bound
的性质,我们无需去重。
代码
using vi = vector<int>;
using vl = vector<ll>;
#define complete_unique(x) (x.erase(unique(x.begin(), x.end()), x.end()))
#define lb lower_bound
class Solution {
public:
int reversePairs(vector<int>& nums) {
vl tarr;
for (auto &e: nums){
tarr.push_back(e);
tarr.push_back(2ll * e); // 直接把需要离散化的对应元素加入
}
sort(tarr.begin(), tarr.end());
int n = nums.size();
BIT<int> bit(2 * n); // 注意,因为加入了两倍的元素,所以对应也要开大一点
int res = 0;
for (int i = 0; i < n; ++ i){
res += i - bit.sum(lb(tarr.begin(), tarr.end(), 2ll * nums[i]) - tarr.begin() + 1);
bit.add(lb(tarr.begin(), tarr.end(), nums[i]) - tarr.begin() + 1, 1);
}
return res;
}
};
二维BIT:区间查询,单点修改
思路
二维 BIT 实际上就是套娃,一层层套即可。
其复杂度为 \(\mathcal{O(\log n \times \log m)}\) ,\(n,m\)分别为每个维度 BIT 的个数,这里不再赘述。
代码
#include <bits/stdc++.h>
using namespace std;
// 模板代码省略
using ll = long long;
int n, m, q;
const int maxn = 5e3 + 50;
BIT<ll> f[maxn]; // 二维BIT
void add(int i, int j, ll x){
while (i <= n){
f[i].add(j, x);
i += lowbit(i);
}
}
ll sum(int i, int j){
ll res(0);
while (i > 0){
res += f[i].sum(j);
i -= lowbit(i);
}
return res;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0);
cin >> n >> m;
for (int i = 1; i <= n; ++ i) f[i] = BIT<ll>(m);
int type;
while (cin >> type){
if (type == 1){
int x, y, k; cin >> x >> y >> k;
add(x, y, (ll) k);
}else {
int a, b, c, d; cin >> a >> b >> c >> d;
cout << sum(c, d) - sum(c, b - 1) - sum(a - 1, d) + sum(a - 1, b - 1) << '\n';
}
}
return 0;
}
后记
这是我耗时最长的一篇博客,也是我花费心血最多的一次,也希望自己能好好掌握 BIT
附上参考链接: