【数据结构与算法】线段树篇二

之前我已经使用线段树完成了简单的题目,现在我要学习下线段树的高阶用法,支持区间更新与区间查询。

见:http://harryguo.me/2016/01/22/%E7%BA%BF%E6%AE%B5%E6%A0%91%E8%AE%B2%E8%A7%A3%E4%BA%8C/(HarryGuo)讲的通俗易懂!!

懒操作

懒操作是线段树的核心操作,这是线段树之所以是log级算法的原因。我们可以注意到,对于线段树上的一个节点,如果我们更新的区间已经完全覆盖了这个区间,那么我们就没必要把这个节点的子节点都更新了,如果我们询问到这个节点的时候,我们再将堆积在这里的信息往下传。
这样操作的复杂度很明显是 O(logn) 的。结构和单点更新类似,只是多了一个懒操作的标记,随之带来的变化是增加了 pushDown()函数,update()和query()函数里相应位置 调用pushDown()。

使用 hihocoder #1078 : 线段树的区间修改 练手:

#include <cstdio>
#define MAXN 100005

int data[MAXN];
struct Node
{
    int val;
    int lazy;
}tree[MAXN * 4];

void pushDown(int v, int len)
{
    if(tree[v].lazy != -1)                    //判断当前节点是否为懒处理节点
    {    
        tree[v << 1].lazy = tree[v].lazy;
        tree[v << 1 | 1].lazy = tree[v].lazy;
        tree[v << 1].val = (len + 1) / 2 * tree[v].lazy;
        tree[v << 1 | 1].val = len / 2 * tree[v].lazy;
        tree[v].lazy = -1;
    }
}

void pushUp(int v)
{
    tree[v].val = tree[v << 1].val + tree[v << 1 | 1].val;
}

void build(int left, int right, int v)
{
    if(left == right)
    {
        tree[v].val = data[left];
        tree[v].lazy = -1;
    }
    else
    {
        int m = (left + right) >> 1;
        build(left, m, v << 1);
        build(m + 1, right, v << 1 | 1);
        pushUp(v);
        tree[v].lazy = -1;
    }
}

void update(int a, int b, int p, int l, int r, int v)
{
    if(a <= l && r <= b)
    {
        tree[v].val = p * (r - l + 1);
        tree[v].lazy = p;
    }
    else
    {
        int m = (l + r) >> 1;
        pushDown(v, r - l + 1);
        if(m >= a)  update(a, b, p, l, m, v << 1);
        if(m < b)   update(a, b, p, m + 1, r, v << 1 | 1);
        pushUp(v);
    }
}

int query(int a, int b, int left, int right, int v)
{
    if(a <= left && right <= b)
        return tree[v].val;
    else
    {
        int sum = 0;
        int m = (left + right) >> 1;
        pushDown(v, right - left + 1);
        
        if(m >= a)  sum += query(a, b, left, m, v << 1);
        if(m < b)   sum += query(a, b, m + 1, right, v << 1 | 1);
        return sum;
    }
}

int main()
{
    int i, command, left, right, NewP, N, Q;
    scanf("%d", &N);
    if(!N)
        return 0;
    for(i = 1; i <= N; ++i)
        scanf("%d", &data[i]);
    build(1, N, 1);
    scanf("%d", &Q);
    while(Q--)
    {
        scanf("%d", &command);
        if(command)
        {
            scanf("%d %d %d", &left, &right, &NewP);
            update(left, right, NewP, 1, N, 1);
        }
        else
        {
            scanf("%d %d", &left, &right);
            printf("%d\n", query(left, right, 1, N, 1));
        }
    }
    
    return 0;
}


既然又做到了线段树,打算把 leetcode上面的相关题目也做掉加强记忆。


leetcode #303. Range Sum Query

简单的线段树,只需实现 query() 和 build() 函数。

class NumArray {
public:
    int mysum[1000000];
    int N;
    
    void pushUp(int v)
    {
        mysum[v] = mysum[v << 1] + mysum[v << 1 | 1];
    }
    
    void build(vector<int> &nums, int left, int right, int v)
    {
        if(left == right)
            mysum[v] = nums[left];
        else
        {
            int m = (left + right) >> 1;
            build(nums, left, m, v << 1);
            build(nums, m + 1, right, v << 1 | 1);
            pushUp(v);
        }
    }
    
    int query(int a, int b, int left, int right, int v)
    {
        if(a <= left && right <= b)
            return mysum[v];
        else
        {
            int sum = 0;
            int m = (left + right) >> 1;
            if(m >= a) sum += query(a, b, left, m, v << 1);
            if(m < b) sum += query(a, b, m + 1, right, v << 1 | 1);
            return sum;
        }
    }
    
    NumArray(vector<int> &nums) {
        N = nums.size();
        if(N)
            build(nums, 0, N - 1, 1);
    }

    int sumRange(int i, int j) {
        if(N)
            return query(i, j, 0, N - 1, 1);
        return 0;
    }
};


// Your NumArray object will be instantiated and called as such:
// NumArray numArray(nums);
// numArray.sumRange(0, 1);
// numArray.sumRange(1, 2);

这道题需要注意的是以前我们的数据下标从 1 ~ N,此处为 0 ~ N - 1,还要注意由于我们使用 v << 1 和 v << 1 | 1计算儿子下标,所以 v 的初始值不能从0开始,一定要从 1 开始。

然后是 leetcode #307 Range Sum Query - Mutable

只需要在上一题的基础上增加 update() 函数,完整代码如下:

class NumArray {
public:
    int mysum[1000000];
    int N;
    
    void pushUp(int v)
    {
        mysum[v] = mysum[v << 1] + mysum[v << 1 | 1];
    }
    
    void build(vector<int> &nums, int left, int right, int v)
    {
        if(left == right)
            mysum[v] = nums[left];
        else
        {
            int m = (left + right) >> 1;
            build(nums, left, m, v << 1);
            build(nums, m + 1, right, v << 1 | 1);
            pushUp(v);
        }
    }
    
    int query(int a, int b, int left, int right, int v)
    {
        if(a <= left && right <= b)
            return mysum[v];
        else
        {
            int sum = 0;
            int m = (left + right) >> 1;
            if(m >= a) sum += query(a, b, left, m, v << 1);
            if(m < b) sum += query(a, b, m + 1, right, v << 1 | 1);
            return sum;
        }
    }
    
    
    NumArray(vector<int> &nums) {
        N = nums.size();
        if(N)
            build(nums, 0, N - 1, 1);
    }

    int sumRange(int i, int j) {
        if(N)
            return query(i, j, 0, N - 1, 1);
        return 0;
    }
    
    void myupdate(int a, int b, int left, int right, int v)
    {
        if(left == right)
            mysum[v] = b;
        else
        {
            int m = (left + right) >> 1;
            if(a <= m) 
                myupdate(a, b, left, m, v << 1);
            else
                myupdate(a, b, m + 1, right, v << 1 | 1);
            pushUp(v);
        }
    }
    
    void update(int i, int val) {
        myupdate(i, val, 0, N - 1, 1);
    }
};

由于我们的方法用到了递归,题解函数不能递归,无奈只能另写可以递归的函数以备调用。



还有一个二维的题目:leetcode #304. Range Sum Query 2D - Immutable

这道题本来是打算就是简单的对一维线段树的简单重复使用,但是这样就没有意义了,不(gu)小(yi)心(de)点开该题目的 editorial solution ,瞟了一眼,发现其解法很巧妙地利用了:

sum(y1, x1, y2, x2) = sum(0, 0, y2, x2) - sum(0, 0, y2, x1 - 1) - sum(0, 0, y1 - 1, x2) + sum(0, 0, y1- 1, x1 - 1)

使用该公式完成预处理和查询。代码如下:

class NumMatrix {
public:
    int mysum[1000][1000];
    int rowNum, colNum;
    NumMatrix(vector<vector<int>> &matrix) {
        
        int i, j;
        rowNum = matrix.size();
        
        if(rowNum)
        {   
            colNum = matrix[0].size();               //此处一定注意:如果矩阵为空,这句话会导致程序崩溃
            for(i = 1; i <= rowNum; ++i)
                for(j = 1; j <= colNum; ++j)
                    mysum[i][j] = mysum[i - 1][j] + mysum[i][j - 1] + matrix[i - 1][j - 1] - mysum[i - 1][j - 1];
        }
    }

    int sumRegion(int row1, int col1, int row2, int col2) {
        if(row2 < rowNum && col2 < colNum)
            return mysum[row2 + 1][col2 + 1] - mysum[row2 + 1][col1] - mysum[row1][col2 + 1] + mysum[row1][col1];
        else
            return 0;
    }
};



你可能感兴趣的:(线段树)