• 欢迎光临~

# 树状数组学习笔记

## 实现

\[\begin{aligned} c_1 &= a_1 \\ c_2 &= a_1+a_2 \\ c_3 &= a_3 \\ c_4 &= a_1+a_2+a_3+a_4 \\ c_5 &= a_5 \\ c_6 &= a_5+a_6 \\ c_7 &= a_7 \\ c_8 &= a_1+a_2+a_3+a_4+a_5+a_6+a_7+a_8 \end{aligned}\]

\[c_i=\sum_{j=i-2^k+1}^{i}a_j\]

其中 \(k\) 为 \(i\) 的二进制表示末尾连续 \(0\) 的个数。

（其实 \(c_i\) 就是一个特殊的前缀和数组）

``````cpp
inline int lowbit(int x) {
return x & (~x + 1);
}
``````

（原理可自行感性推导）

#### 单点修改

由 \(c_i\) 的通项公式，我们可以推出每个包含 \(a_i\) 的 \(c_j\) 为 \(c_i, c_{i+2^k}, c_{(i+2^k)+2^{k'}}, \dots\)（每一步加上的 \(2^k\) 都是当前下标的 lowbit）。

``````cpp
inline void update(int x, int k) {
for (; x <= n; x += lowbit(x))
c[x] += k;
}
``````

#### 区间查询

由 \(c_i\) 的通项公式，我们同样可以推出 \(a_i\) 的前缀和为 \(c_i + c_{i-2^{k_1}} + c_{(i-2^{k_1})-2^{k_2}} + \cdots\)（每一步减去的都是当前下标的 lowbit，直到下标变为 \(0\)）。

``````cpp
inline int query(int x) {
int res = 0;

for (; x; x -= lowbit(x))
res += c[x];

return res;
}
``````

#### P3374 【模板】树状数组 1

``````cpp
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 7;

int a[N], c[N];

int n, m;

// Returns the value of the lowest set bit of x (0 when x is 0).
inline int lowbit(int x) {
    // -x equals ~x + 1 under two's complement, so the AND keeps only the
    // lowest 1 bit of x.
    return x & -x;
}

// Point update: adds k to c[x] and to every later tree node reached by
// repeatedly adding lowbit(x), stopping once the index exceeds the global n.
// x must be >= 1: lowbit(0) is 0, so an index of 0 would never advance.
inline void update(int x, int k) {
    while (x <= n) {
        c[x] += k;
        x += lowbit(x);
    }
}

// Prefix query: sums c[x], c[x - lowbit(x)], ... until the index reaches 0
// and returns the total (0 when x is 0).
inline int query(int x) {
    int total = 0;

    while (x != 0) {
        total += c[x];
        x -= lowbit(x);
    }

    return total;
}

// Driver for P3374 (point add, range sum).
// Reads n and m, inserts the n initial values into the tree one by one, then
// serves m operations:
//   opt == 1: "x k"  -> add k to position x
//   otherwise: "l r" -> print the sum over [l, r]
signed main() {
    scanf("%d%d", &n, &m);

    // Build the tree by point-updating each initial value.
    for (int i = 1; i <= n; ++i) {
        scanf("%d", a + i);
        update(i, a[i]);
    }

    for (int opt, x, k, l, r; m; --m) {
        scanf("%d", &opt);

        if (opt == 1) {
            scanf("%d%d", &x, &k);
            update(x, k);
        } else {
            scanf("%d%d", &l, &r);
            // Range sum = prefix(r) - prefix(l - 1). The format string must end
            // in "\n"; the previous "%dn" printed a literal 'n' after every
            // answer and no line break.
            printf("%d\n", query(r) - query(l - 1));
        }
    }

    return 0;
}
``````

## 扩展：O(n) 建树

``````cpp
inline void init() {
for (int i = 1; i <= n; ++i) {
c[i] += a[i];
if(i + lowbit(i) <= n)
c[i + lowbit(i)] += c[i];
}
}
``````

## 扩展：差分树状数组

#### 区间修改

``````cpp
inline void update(int l,int r,int val) {
for(;l<=n;l+=lowbit(l))
c[l]+=val;
for(++r;r<=n;r+=lowbit(r))
c[r]-=val;
}
``````

#### 单点查询

``````cpp
inline int query(int pos) {
int res = 0;

for (; pos; pos -= lowbit(pos))
res += c[pos];

return res;
}
``````

#### P3368 【模板】树状数组 2

``````cpp
#include <bits/stdc++.h>
using namespace std;
const int N = 1e7 + 7;

int c[N];

int n, m;

// Isolates the lowest set bit of x; yields 0 for x == 0.
inline int lowbit(int x) {
    const int complement_plus_one = ~x + 1;  // bitwise negate, then add one
    return x & complement_plus_one;
}

// Adds val at index pos of the tree: touches c[pos] and every following node
// obtained by adding lowbit(pos), while the index stays within the global n.
// A pos greater than n is a no-op; pos must be >= 1 or the loop cannot advance.
inline void update(int pos, int val) {
    while (pos <= n) {
        c[pos] += val;
        pos += lowbit(pos);
    }
}

// Returns the tree's prefix sum over indices 1..pos. main() in this program
// stores a difference array in the tree, so that prefix sum is the current
// value of element pos.
inline int query(int pos) {
    int acc = 0;

    while (pos != 0) {
        acc += c[pos];
        pos -= lowbit(pos);
    }

    return acc;
}

// Driver for P3368 (range add, point query) using a difference array.
// Reads n and m, loads the n initial values, then serves m operations:
//   opt == 1: "l r val" -> add val to every element in [l, r]
//   otherwise: "pos"    -> print the current value of element pos
signed main() {
    scanf("%d%d", &n, &m);

    // Store each initial value as +val at i and -val at i + 1, so the prefix
    // sum at i reproduces the value. update() ignores the i + 1 > n case.
    for (int i = 1, val; i <= n; ++i) {
        scanf("%d", &val);
        update(i, val);
        update(i + 1, -val);
    }

    for (int opt, pos, val, l, r; m; --m) {
        scanf("%d", &opt);

        if (opt == 1) {
            scanf("%d%d%d", &l, &r, &val);
            // Range add on the difference array: +val enters at l, leaves after r.
            update(l, val);
            update(r + 1, -val);
        } else {
            scanf("%d", &pos);
            // The format string must end in "\n"; the previous "%dn" printed a
            // literal 'n' after every answer and no line break.
            printf("%d\n", query(pos));
        }
    }

    return 0;
}
``````

## 应用

#### P4868 Preprefix sum

Description：求：

\[\sum_{i=1}^k S_i\]

其中 \(S_i=\sum_{j=1}^{i}a_j\) 为 \(a\) 的前缀和。

Solution：化简：

\[\begin{aligned} &\sum_{i=1}^k S_i \\ =&\sum_{i=1}^k (k-i+1) \times a_i \\ =&\sum_{i=1}^k (k+1) \times a_i - \sum_{i=1}^k i \times a_i \\ =&(k+1) \times \sum_{i=1}^k a_i - \sum_{i=1}^k i \times a_i \end{aligned}\]

Code：

``````cpp
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e5 + 7;

ll a1[N], a2[N], c1[N], c2[N];

int n, m;

// Lowest set bit of x, e.g. 12 (0b1100) -> 4; returns 0 for x == 0.
inline int lowbit(int x) {
    // Two's-complement negation (~x + 1) written as unary minus.
    return x & -x;
}

// Point update on tree c1: adds y at index x and at each following node
// reached by adding lowbit(x), while the index does not exceed the global n.
// x must be >= 1, otherwise the index never advances.
inline void update1(int x, ll y) {
    while (x <= n) {
        c1[x] += y;
        x += lowbit(x);
    }
}

// Point update on tree c2: adds y at index x and at each following node
// reached by adding lowbit(x), while the index does not exceed the global n.
// x must be >= 1, otherwise the index never advances.
inline void update2(int x, ll y) {
    while (x <= n) {
        c2[x] += y;
        x += lowbit(x);
    }
}

// Prefix sum over indices 1..x of tree c1 (0 when x is 0).
inline ll query1(int x) {
    ll sum = 0;

    while (x != 0) {
        sum += c1[x];
        x -= lowbit(x);
    }

    return sum;
}

// Prefix sum over indices 1..x of tree c2 (0 when x is 0).
inline ll query2(int x) {
    ll sum = 0;

    while (x != 0) {
        sum += c2[x];
        x -= lowbit(x);
    }

    return sum;
}

// Driver for P4868 (prefix sum of prefix sums with point assignment).
// Tree c1 holds a_i and tree c2 holds i * a_i, so the answer for k is
// (k + 1) * sum(a_1..a_k) - sum(i * a_i for i in 1..k).
// Commands: "Query x" prints that value for k = x; any other word is read as
// "<word> x y" and sets a_x to y.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cin >> n >> m;

    for (int i = 1; i <= n; ++i) {
        cin >> a1[i];
        update1(i, a1[i]);
        a2[i] = i * a1[i];
        update2(i, a2[i]);
    }

    for (string str; m; --m) {
        cin >> str;

        if (str == "Query") {
            int x;
            cin >> x;
            // Terminate each answer with a real newline; the previous 'n'
            // character literal printed the letter n instead.
            cout << (x + 1) * query1(x) - query2(x) << '\n';
        } else {
            ll x, y;
            cin >> x >> y;
            // Assignment is applied as a delta against the stored values,
            // which are then refreshed for the next modification.
            update1(x, y - a1[x]), a1[x] = y;
            update2(x, x * y - a2[x]), a2[x] = x * y;
        }
    }

    return 0;
}
``````

#### P1908 逆序对

``````cpp
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e5 + 7;

int a[N], b[N], c[N];

ll ans;
int n, m;

// Extracts the least-significant 1 bit of x as a value (0 if x has no set bit).
inline int lowbit(int x) {
    const int negated = ~x + 1;  // two's-complement form of -x
    return negated & x;
}

// Adds k at index x of the tree: updates c[x] and each following node reached
// by adding lowbit(x), while the index stays within the global n.
// x must be >= 1, otherwise the index never advances.
inline void update(int x, int k) {
    while (x <= n) {
        c[x] += k;
        x += lowbit(x);
    }
}

// Returns the sum of the values stored at indices 1..x (0 when x is 0).
inline int query(int x) {
    int count = 0;

    while (x != 0) {
        count += c[x];
        x -= lowbit(x);
    }

    return count;
}

// Driver for P1908 (count inversions).
// Reads n values, maps them to 1-based ranks, then scans left to right:
// after inserting element i, query(rank) counts inserted elements with rank
// <= the current one, so i - query(rank) is the number of earlier elements
// that are strictly larger.
signed main() {
    scanf("%d", &n);

    // Read the sequence and keep a copy in b for sorting.
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &a[i]);
        b[i] = a[i];
    }

    // Discretize: sort and deduplicate b, then replace each a[i] by its
    // 1-based position among the distinct values.
    int *const first = b + 1;
    sort(first, first + n);
    const int cnt = unique(first, first + n) - first;

    for (int i = 1; i <= n; ++i) {
        a[i] = lower_bound(first, first + cnt, a[i]) - b;
    }

    for (int i = 1; i <= n; ++i) {
        const int rank = a[i];
        update(rank, 1);
        ans += i - query(rank);
    }

    printf("%lld", ans);

    return 0;
}
``````