[区间问题] 树状数组与线段树归纳

一、树状数组

1. 定义

用一个数组C[maxn]模拟树状结构,来表示原数组A[maxn]的区间信息。各操作的复杂度为O(lgn)。

2. 解析

  • lowbit(x) 指一个数x的二进制中,最靠右的1的权值大小。如lowbit(6)=2^1, lowbit(7)=2^0,lowbit(8)=8。
  • 树状数组就是用一个C[maxn]数组来存放原数组A[maxn]的去区间信息。具体而言,C[x]存放了从A[(x-lowbit(x))...x]区间中各元素之和。 A[0]为0,不存放有效值。
  • [单点更新 区间查询] 求A[1...x]的区间和就需要求出它所对应的那几个C[j]区间之和,这几个区间通过lowbit向下枚举即可得出;而更改一个值A[i]时就需要更改所有包含了A[i]的C[j]区间,同样可以通过lowit线上枚举来获得。

3. 变式

  • [区间更新 单点查询]
    树状数组中存放 差分值 D[i] = A[i] - A[i - 1],这样区间更新只用更新区间两端,单点查值就是getSum。

  • [区间更新 区间查询]

    1. 维护两个数组
      • sum1[maxn]存放差值 D[1]、D[2]、D[3]... 的树状数组;
      • sum2[maxn]存放 0 * D[1]、1 * D[2]、2 * D[3]... 的树状数组
    2. 更新操作要同时更新 sum1[maxn]、sum2[maxn]
    3. 查询区间的和为:n * ∑D[i] - ∑[(i - 1) * D[i]] = n * sum1[i] - sum2[i]
      • 证明如下:

        A[1] + A[2] + ... + A[n]
        = (D[1]) + (D[1] + D[2]) + ... + (D[1] + D[2] + ... + D[n])
        = n * D[1] + (n - 1) * D[2] + ... + D[n]
        = n * (D[1] + D[2] + ... + D[n]) - (0 * D[1] + 1 * D[2] + ... + (n - 1) * D[n])

4. 模板

  • [单点更新 区间查询]
int A[100010], C[100010]; // 原数组 和 树状数组 A[1...n]有效 

int lowbit(int x){
    return x & -x;
} 

void add(int k, int v){ // 当 A[k] += v时 
    for(int i = k; i <= n; i += lowbit(i)) C[i] += v; 
}

int getSum(int k){ // 求 A[1]+...+A[k]时 
    int res = 0;
    for(int i = k; i; i -= lowbit(i)) res += C[i];
    return res;
}
  • [区间更新 单点查询]
int A[100010], C[100010];

int lowbit(int x){
    return x & -x;
}

void add(int k, int v){  // 不变
    for(int i = k; i <= n; i += lowbit(i)) C[i] += v;
}

int getSum(int k){  // 不变
    int res = 0;
    for(int i = k; i; i -= lowbit(i)) res += C[i];
    return res;
}

int main(){
    ......
    for(int i = 1; i <= n; i ++){  //从1开始放 
        cin >> A[i];
        add(i, A[i] - A[i - 1]);   // 树状数组放差值 
    }
    ......
    add(x, k);  // 在[x, y]区间加上k, 即让这个区间向上突出 
    add(y + 1, -k);
    ......
    int ans = getSum(q);  // 查询q位置的值
    ......
}
  • [区间更新 区间查询]
int A[100010];
int sum1[100010]; // 存放了差值D[1]、D[2]、D[3]...的树状数组 
int sum2[100010]; // 存放了0*D[1]、1*D[2]、2*D[3]...的树状数组 

int lowbit(int x){
    return x & -x;
}

void add(int k, int v){   // 加差值和sum2的同时还要加sum2
    for(int i = k; i <= n; i += lowbit(i)) sum1[i] += v, sum2[i] += (k - 1) * v;
}

int getSum(int k){  //  区间和为 n*∑D[i] - ∑[(i-1)*D[i]] = n*sum1[i] - sum2[i]
    int res = 0;
    for(int i = k; i; i -= lowbit(i)) res += k * sum1[i] - sum2[i];
    return res;
}

int main(){
    for(int i = 1; i <= n; i ++){  //从1开始放 
        cin >> A[i];
        add(i, A[i] - A[i - 1]);   // 树状数组放差值 
    }
    ......
    add(x, k);  //在[x, y]区间加上k 
    add(y + 1, -k);
    ......
    int ans = getSum(q) - getSum(p - 1);  //查询[p, q]位置的值
}

二、线段树

1. 定义

一种强大的能解决区间的修改、查询问题的方法。

2. 解析

与树状数组相比,线段树代码量更大,但是适用范围更广。

  • k为结点编号 根节点为1, 左孩子为 k2, 右孩子为 k2+1
    每个函数都包含参数k

  • 除建树外共4种情形:单点查询 单点更新 区间查询 区间更新(要用到懒标记f)
    一个题目一般只会用到两种

  • 线段树一般保持 左闭右闭 区间,这样代码会好写一点

  • 注意:

    • 无论是更新还是查询,都要down懒标记
    • 一些线段树的变形主要需要特别注意 状态归并 和 懒标记的更新

3. 模板

#include <iostream>
using namespace std;
using ll = long long;
const int maxn = 5e5 + 5;
struct node {
    int l, r, f = 0;
    ll w;
    node() {}
    node(int l, int r, ll w) : l(l), r(r), w(w) {}
} Tree[4 * maxn];
int n, m;
ll a[maxn];

void down(int k) {
    int & v = Tree[k].f;
    Tree[k * 2].w += v * (Tree[k * 2].r - Tree[k * 2].l + 1);
    Tree[k * 2].f += v;
    Tree[k * 2 + 1].w += v * (Tree[k * 2 + 1].r - Tree[k * 2 + 1].l + 1);
    Tree[k * 2 + 1].f += v;
    v = 0;
}

void build(int l, int r, int k) {  // 建树
    if(l == r) {
        Tree[k] = node(l, r, a[l]);
        return;
    }
    int mid = (l + r) / 2;
    build(l, mid, k * 2);
    build(mid + 1, r, k * 2 + 1);
    Tree[k] = node(l, r, Tree[k * 2].w + Tree[k * 2 + 1].w);
}

void update(int x, int v, int k) {  // 单点更新
    if(Tree[k].l == Tree[k].r) {
        Tree[k].w += v;
        return;
    }
    int mid = (Tree[k].l + Tree[k].r) / 2;
    if(x <= mid) update(x, v, k * 2);
    else update(x, v, k * 2 + 1);
    Tree[k].w = Tree[k * 2].w + Tree[k * 2 + 1].w;
}

ll query(int l, int r, int k) {  // 区间查值
    if(l <= Tree[k].l && r >= Tree[k].r) return Tree[k].w;
    if(Tree[k].f) down(k); // 查询时也要记得更新懒标记下移!!!
    int mid = (Tree[k].l + Tree[k].r) / 2;
    ll res = 0;
    if(l <= mid) res += query(l, r, k * 2);
    if(r > mid) res += query(l, r, k * 2 + 1);
    return res;
}

void update(int l, int r, int v, int k) {  // 区间更新
    if(l <= Tree[k].l && r >= Tree[k].r) {
        Tree[k].w += v * (Tree[k].r - Tree[k].l + 1);  // 注意是多个值更新
        Tree[k].f += v;  // 懒标记: 预记儿子们的变化
        return;
    }
    if(Tree[k].f) down(k);  // 懒标记下传
    int mid = (Tree[k].l + Tree[k].r) / 2;
    if(l <= mid) update(l, r, v, k * 2);
    if(r > mid) update(l, r, v, k * 2 + 1);
    Tree[k].w = Tree[k * 2].w + Tree[k * 2 + 1].w;
}

ll query(int x, int k) {  // 单点查值
    if(Tree[k].l == Tree[k].r) return Tree[k].w;
    if(Tree[k].f) down(k); // 查询时也要记得更新懒标记下移!!!
    int mid = (Tree[k].l + Tree[k].r) / 2;
    if(x <= mid) return query(x, k * 2);
    else return query(x, k * 2 + 1);
}

int main() {
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n >> m;
    for(int i = 1; i <= n; i ++) cin >> a[i];
    build(1, n, 1);
    while(m --) {
        int f, x, y, k;
        // 单点更新 + 区间查值
        // cin >> f;
        // if(f == 1) cin >> x >> y, update(x, y, 1);
        // else cin >> x >> y, cout << query(x, y, 1) << endl;

        // 区间更新 + 单点查值
        // cin >> f;
        // if(f == 1) cin >> x >> y >> k, update(x, y, k, 1);
        // else cin >> x, cout << query(x, 1) << endl;

        // 区间更新 + 区间查值
        cin >> f;
        if(f == 1) cin >> x >> y >> k, update(x, y, k, 1);
        else cin >> x >> y, cout << query(x, y, 1) << endl;
    }
}
Re:从零开始的归纳时间 文章被收录于专栏

在刷题的同时,常常会发现一些具有相似性和模板题,但是经常会混淆或是记不起来 于是乎就想着归纳归纳吧~ 这样脑袋好受一些~ 若有参考了其他大神的博客,文末都将予以备注

全部评论

相关推荐

05-22 12:44
已编辑
门头沟学院 golang
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务