概念
树状数组,又称二叉索引树,是一种代码简单、应用广泛、unbelievable的数据结构
它大概长这样:
实现
观察上图,显然:
[begin{aligned}
c_1 &= a_1
\
c_2 &= a_1+a_2
\
c_3 &= a_3
\
c_4 &= a_1+a_2+a_3+a_4
\
c_5 &= a_5
\
c_6 &= a_5+a_6
\
c_7 &= a_7
\
c_8 &= a_1+a_2+a_3+a_4+a_5+a_6+a_7+a_8
end{aligned}
]
不难发现(其中 (k) 为 (i) 的二进制中从最低位到高位连续零的长度):
[c_i=sum_{j=i-2^k+1}^{i}a_i
]
(其实 (c_i) 就是一个特殊的前缀和数组
我们有了 (c_i) 的通项公式后,还会发现一个问题:如何求出每个 (i) 所对应的 (k)
Answer:找 (i) 最低位的 (1) 即可
我们引入一个求某个数最低位的函数: lowbit(x)
inline int lowbit(int x) {
return x & (~x + 1);
}
(原理可自行感性推导
单点修改
由 (c_i) 的通项公式,我们可以推出每个包含 (a_i) 的 (c_j) 为 (c_{i+2^k},c_{(i+2^k) +2^k},...)
其中 (2^k) 为 lowbit(i)
写成代码,就是这样:
inline void update(int x, int k) {
for (; x <= n; x += lowbit(x))
c[x] += k;
}
区间查询
由 (c_i) 的通项公式,我们同样可以推出 (a_i) 的前缀和为 (c_i + c_{i-2^{k1}},c_{(i-2^{k1})-2_{k2}},...)
写成代码,就是这样:
inline int query(int x) {
int res = 0;
for (; x; x -= lowbit(x))
res += c[x];
return res;
}
这样区间 ([l,r]) 的和就是 query(r)-query(l-1)
P3374 【模板】树状数组 1
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 7;
int a[N], c[N];
int n, m;
inline int lowbit(int x) {
return x & (~x + 1);
}
inline void update(int x, int k) {
for (; x <= n; x += lowbit(x))
c[x] += k;
}
inline int query(int x) {
int res = 0;
for (; x; x -= lowbit(x))
res += c[x];
return res;
}
signed main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) {
scanf("%d", a + i);
update(i, a[i]);
}
for (int opt, x, k, l, r; m; --m) {
scanf("%d", &opt);
if (opt == 1) {
scanf("%d%d", &x, &k);
update(x, k);
}
else {
scanf("%d%d", &l, &r);
printf("%dn", query(r) - query(l - 1));
}
}
return 0;
}
扩展:O(n) 建树
由于每一个节点的值是所有与自己直接相连的儿子的值的,因此可以倒着考虑贡献,即每次确定完儿子的值后,用自己的值更新自己的直接父亲
inline void init() {
for (int i = 1; i <= n; ++i) {
c[i] += a[i];
if(i + lowbit(i) <= n)
c[i + lowbit(i)] += c[i];
}
}
扩展:差分树状数组
前置芝士:差分
差分树状数组,即将普通树状数组中的前缀和数组改为差分数组
区间修改
由差分区间修改的性质,我们容易写出代码
inline void update(int l,int r,int val) {
for(;l<=n;l+=lowbit(l))
c[l]+=val;
for(++r;r<=n;r+=lowbit(r))
c[r]-=val;
}
单点查询
由差分单点查询的性质,我们容易写出代码
inline int query(int pos) {
int res = 0;
for (; pos; pos -= lowbit(pos))
res += c[pos];
return res;
}
P3368 【模板】树状数组 2
#include <bits/stdc++.h>
using namespace std;
const int N = 1e7 + 7;
int c[N];
int n, m;
inline int lowbit(int x) {
return x & (~x + 1);
}
inline void update(int pos, int val) {
for (; pos <= n; pos += lowbit(pos))
c[pos] += val;
}
inline int query(int pos) {
int res = 0;
for (; pos; pos -= lowbit(pos))
res += c[pos];
return res;
}
signed main() {
scanf("%d%d", &n, &m);
for (int i = 1, val; i <= n; ++i) {
scanf("%d", &val);
update(i, val);
update(i + 1, -val);
}
for (int opt, pos, val, l, r; m; --m) {
scanf("%d", &opt);
if (opt == 1) {
scanf("%d%d%d", &l, &r, &val);
update(l, val);
update(r + 1, -val);
}
else {
scanf("%d", &pos);
printf("%dn", query(pos));
}
}
return 0;
}
应用
P4868 Preprefix sum
Solution:求:
[sum_{i=1}^k S_i
]
Description:化简:
[begin{aligned}
&sum_{i=1}^k S_i
\
=&sum_{i=1}^k (k-i+1) times a_i
\
=&sum_{i=1}^k (k+1) times a_i - sum_{i=1}^k i times a_i
\
=&(k+1) times sum_{i=1}^k a_i - sum_{i=1}^k i times a_i
end{aligned}
]
对于 $sum_{i=1}^k a_i $ 与 $ sum_{i=1}^k i times a_i$ ,我们可以用树状数组维护
Code:
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;
ll a1[N], a2[N], c1[N], c2[N];
int n, m;
inline int lowbit(int x) {
return x & (~x + 1);
}
inline void update1(int x, ll y) {
for (; x <= n; x += lowbit(x))
c1[x] += y;
}
inline void update2(int x, ll y) {
for (; x <= n; x += lowbit(x))
c2[x] += y;
}
inline ll query1(int x) {
ll res = 0;
for (; x; x -= lowbit(x))
res += c1[x];
return res;
}
inline ll query2(int x) {
ll res = 0;
for (; x; x -= lowbit(x))
res += c2[x];
return res;
}
signed main() {
ios::sync_with_stdio(0),
cin.tie(0), cout.tie(0);
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
cin >> a1[i];
update1(i, a1[i]);
a2[i] = i * a1[i];
update2(i, a2[i]);
}
for (string str; m; --m) {
cin >> str;
if (str == "Query") {
int x;
cin >> x;
cout << (x + 1)*query1(x) - query2(x) << 'n';
}
else {
ll x, y;
cin >> x >> y;
update1(x, y - a1[x]), a1[x] = y;
update2(x, x * y - a2[x]), a2[x] = x * y;
}
}
return 0;
}
P1908 逆序对
我们将原数列离散化,然后将每个数一次插入树状数组中,每次插入时记录一下前面有多少个比它大的数即可
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e5 + 7;
int a[N], b[N], c[N];
ll ans;
int n, m;
inline int lowbit(int x) {
return x & (~x + 1);
}
inline void update(int x, int k) {
for (; x <= n; x += lowbit(x))
c[x] += k;
}
inline int query(int x) {
int res = 0;
for (; x; x -= lowbit(x))
res += c[x];
return res;
}
signed main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", a + i);
b[i] = a[i];
}
sort(b + 1, b + 1 + n);
int cnt = unique(b + 1, b + 1 + n) - b - 1;
for (int i = 1; i <= n; ++i)
a[i] = lower_bound(b + 1, b + 1 + cnt, a[i]) - b; // 离散化
for (int i = 1; i <= n ; ++i) {
update(a[i], 1);
ans += i - query(a[i]);
}
printf("%lld", ans);
return 0;
}