线段树解题模版

线段树

概念:线段树就是一棵二叉树,每个节点代表一个区间,主要用于解决区间类问题。每个节点的属性根据需要可以去自定义,比如节点的属性可以是区间和、区间最大/最小值。。

线段树解题模版_第1张图片

一、线段树节点定义

每个node有区间的左右端点,以及左右孩子

public class SegmentTreeNode{
  	int start,end,val;	//val根据需要定义,比如我定义为区间最大值就是max
  	SegmentTreeNode left,right;
  	public SegmentTreeNode(int start, int end, int val){
      	this.start = start;
      	this.end = end;
      	this.val = val;
      	this.left = this.right = null;
    }
}

二、线段树的构建、修改、查询

1、构建

自上而下,分治法,递归调用。

对于区间[m1,m2],mid = (m1+m2)/2,其左儿子区间是[m1,mid],右儿子区间是(mid+1,m2)

//线段树的构建,以求区间最大值为例,返回根节点
public SegmentTreeNode build(int start, int end, int[] A) {
    if (start > end) {
        return null;
    }
    if (start == end) {
        return new SegmentTreeNode(start, end, A[start]);
    }
    //先new根区间,根区间最大值暂时为A[start],不能为其他乱七八糟的值比如-1这种...
    SegmentTreeNode root = new SegmentTreeNode(start, end, A[start]);
    if (start != end) {
        int mid = (start + end) / 2;
        root.left = build(start, mid, A);
        root.right = build(mid + 1, end, A);
    }
    //改root的val
    if (root.left != null && root.left.max > root.max) {
        root.max = root.left.max;
    }
    if (root.right != null && root.right.max > root.max) {
        root.max = root.right.max;
    }
    return root;
}

2、修改

递归调用,一路向下找到最小区间,触底反弹的时候才去修改node。比如数组[6,3,5,1,9],我要修改1位置上的3,那就是一路向下先找到3,然后返回途中修改 ,时间复杂度logN

线段树解题模版_第2张图片

//线段树的修改,
public void modify(SegmentTreeNode root, int index, int value) {
    if (root.start == index && root.end == index) {
        root.max = value;
        return;
    }
    int mid = (root.start + root.end) / 2;
    //看index在左区间还是右区间
    if (root.start <= index && index <= mid) {
        modify(root.left, index, value);
    }
    if (mid < index && index <= root.end) {
        modify(root.right, index, value);
    }
    //最后改下根
    root.max = Math.max(root.left.max, root.right.max);
}

3、查询

比如上面例子,找[0,3],先找[0,2],再找[3,3]。找[0,2],直接返回,找[3-3],就需要走到底

//线段树的查询
public int query(SegmentTreeNode root, int start, int end) {
    if (start == root.start && end == root.end) {
        return root.max;
    }
    int mid = (root.start + root.end) / 2;
    int left_max = Integer.MIN_VALUE;
    int right_max = Integer.MIN_VALUE;

    //求左边最大值
    //如果给定查询范围起点在左子树
    if (start <= mid) {
        //但是终点在右子树(横跨左子树和右子树),那么左边最大值就在start到mid之间查询
        if (mid < end) {
            left_max = query(root.left, start, mid);
        } else {    //如果只在左子树
            left_max = query(root.left, start, end);
        }
    }
    //求右边最大值
    if (mid < end) {
        //横跨左右子树的情况,起点为mid+1
        if (start <= mid) {
            right_max = query(root.right, mid + 1, end);
        } else {    //如果只在右子树
            right_max = query(root.right, start, end);
        }
    }
    return Math.max(left_max, right_max);
}

三、线段树性质

对于区间[m1,m2],mid = (m1+m2)/2,其左儿子区间是[m1,mid],右儿子区间是(mid+1,m2)

线段树解题模版_第3张图片

四、题目练习

1、LintCode 206. Interval Sum

区间求和

给定一个整数数组(下标由 0 到 n-1,其中 n 表示数组的规模),以及一个查询列表。每一个查询列表有两个整数 [start, end] 。 对于每个查询,计算出数组中从下标 start 到 end 之间的数的总和

输入: 数组 :[1,2,7,8,5], 查询:[(0,4),(1,2),(2,4)]
输出: [23,9,20]

思路

  • 暴力,枚举O(nm),n为数组长度,m为查询次数
  • 线段树/树状数组,O(mlogn)
  • 前缀和数组O(n+m),这个题没有涉及到修改,可以用

首先定义Node和树

		//线段树Node
    class SegmentTreeNode{
        int start;
        int end;
        long sum;
        SegmentTreeNode left, right;

        public SegmentTreeNode(int start, int end) {
            this.start = start;
            this.end = end;
            sum = 0;
            left = right = null;
        }
    }

    class  SegmentTree{
        private int size;   //区间
        private SegmentTreeNode root;

        public SegmentTree(int[] A) {
            size = A.length;
            root = buildTree(A,0, size - 1);
        }

        private SegmentTreeNode buildTree(int[] A, int start, int end) {
            SegmentTreeNode node = new SegmentTreeNode(start, end);
            //递归出口,叶子节点
            if (start == end) {
                node.sum = A[start];
                return node;
            }
            //不是出口,递归建立左子树和右子树
            int mid = (start + end) / 2;
            node.left = buildTree(A, start, mid);
            node.right = buildTree(A, mid + 1, end);
            //别忘记维护当前节点的sum
            node.sum = node.left.sum + node.right.sum;
            return node;
        }

        //查询对外界接口
        public long querySum(int start, int end) {
            return querySum(root, start, end);
        }

        //重载方法.在node节点下查询原数组start到end区间内的和
        private long querySum(SegmentTreeNode node, int start, int end) {
            //递归出口
            if (node.start == start && node.end == end) {
                return node.sum;
            }
            int mid = (node.start + node.end) / 2;
            long leftSum = 0, rightSum = 0;
            //左边区间
            if (start <= mid) {
                //如果不是跨区间
                if (end <= mid) {
                    leftSum = querySum(node.left, start, end);
                } else {
                    leftSum = querySum(node.left, start, mid);
                }
                //可以合并为一行 !!!!
                // leftSum = querySum(node.left, start, Math.min(mid, end));
            }
            //要考虑右半区间,也就是start-end与右半区间有交集
            if (end >= mid + 1) {
                //如果不是跨区间
                if (start >= mid + 1) {
                    rightSum = querySum(node.right, start, end);
                } else {
                    rightSum = querySum(node.right, mid + 1, end);
                }
                // 可以合并为一句
                // rightSum = querySum(node.right, Math.max(mid + 1, start), end);
            }
            return leftSum + rightSum;
        }
    }

实现方法

		public class Interval {
        int start, end;

        Interval(int start, int end) {
            this.start = start;
            this.end = end;
        }
    }

    public List<Long> intervalSum(int[] A, List<Interval> queries) {
        List<Long> res = new ArrayList<>();
        SegmentTree segmentTree = new SegmentTree(A);
        for (Interval query : queries) {
            long sum = segmentTree.querySum(query.start, query.end);
            res.add(sum);
        }
        return res;
    }

2、LintCode 207. Interval Sum

在类的构造函数中给一个整数数组, 实现两个方法 query(start, end)modify(index, value):

  • 对于 query(start, end), 返回数组中下标 startend
  • 对于 modify(index, value), 修改数组中下标为 index 上的数为 value.

比206多了modify,无法使用前缀和数组,暴力O(nm),线段树/树状数组O(mlogn),n为数组长度,m为操作次数。

线段树类中提供三个方法

  • 构造器传入int[] A
  • querySum(int start, int end)
  • modify(int index, int val)
public class Solution {
        private SegmentTree segmentTree;

        public Solution(int[] A) {
            if (A == null || A.length == 0) {
                return;
            }
            segmentTree = new SegmentTree(A);
        }


        public long query(int start, int end) {
            return segmentTree.querySum(start, end);
        }

        public void modify(int index, int value) {
            segmentTree.modify(index, value);
            return;
        }

    }

    class SegmentTreeNode{
        public int start, end;
        public long sum;
        public SegmentTreeNode left, right;

        public SegmentTreeNode(int start, int end) {
            this.start = start;
            this.end = end;
            sum = 0;
            left = right = null;
        }
    }

    class SegmentTree{
        public SegmentTreeNode root;
        public int size;

        public SegmentTree(int[] A) {
            size = A.length;
            root = buildTree(A, 0, size - 1);
        }

        private SegmentTreeNode buildTree(int[] A, int start, int end) {
            SegmentTreeNode node = new SegmentTreeNode(start, end);
            //递归出口
            if (start == end) {
                node.sum = A[start];
                return node;
            }
            //不是出口则递归建所有子树
            int mid = (start + end) / 2;
            node.left = buildTree(A, start, mid);
            node.right = buildTree(A, mid + 1, end);
            node.sum = node.left.sum + node.right.sum;
            return node;
        }

        //公开接口
        public long querySum(int start, int end) {
            return querySum(root, start, end);
        }

        //公开接口
        public void modify(int index, int val) {
            modify(root, index, val);
        }

        private long querySum(SegmentTreeNode node, int start, int end) {
            if (node.start == start && node.end == end) {
                return node.sum;
            }
            int mid = (node.start + node.end) / 2;  //这边不是start和end 是node的区间
            long leftSum = 0, rightSum = 0;
            if (start <= mid) {
                leftSum = querySum(node.left, start, Math.min(end, mid));
            }
            if (end >= mid + 1) {
                rightSum = querySum(node.right, Math.max(start, mid + 1), end);
            }
            return leftSum + rightSum;
        }

        private void modify(SegmentTreeNode node, int index, int val) {
            //递归出口:到达这个叶子节点,并修改它的值
            if (node.start == node.end && node.end == index) {
                node.sum = val;
                return;
            }
            //递归:分为在左子树和右子树两种情况,不用求mid
            if (node.left.end >= index) {
                modify(node.left, index, val);
            } else {
                modify(node.right, index, val);
            }
            //最后改下根
            node.sum = node.left.sum + node.right.sum;
        }
    }

3、LintCode 248. Count of Smaller Number

统计比给定整数小的数的个数

给定一个整数数组 (下标由 0 到 n-1,其中 n 表示数组的规模,数值范围由 0 到 10000),以及一个查询列表。对于每一个查询,将会给你一个整数,请你返回该数组中小于给定整数的元素的数量。

输入: array =[1,2,7,8,5] queries =[1,8,5]
输出:[0,4,2]

时间复杂度:

  • 暴力求O(nm),n为数组长度,m为查询次数
  • 线段树/树状数组O(mlogk),m为查询次数,k为数组最大值
  • 二分,先排序,nlogn,然后查询比某个数小,只要得到它的位置即可,O(nlogn+mlogn)
  • 前缀和数组,线性,O(k+n+m),本题较好的方式,但是扩展较为困难

线段树思路:

  • 数组内元素范围在0~10000,用数组B[i]代表i这个值出现了多少次,那么查询比x小的元素只要计算B的前缀和B[0]+B[1]+…+B[x-1],那就是查询B数组的某一个区间和,查询时间复杂度为logk(k为数组最大值)。这个题用前缀和也是非常的方便,但是遇到follow up就不行了。
public class Solution {

        public List<Integer> countOfSmallerNumber(int[] A, int[] queries) {
            int[] B = new int[10001];
            for (int i : A) {
                B[i]++;
            }
            //建立线段树,大小为10001
            SegmentTree tree = new SegmentTree(10001);
            for (int i = 0; i < 10001; i++) {
                tree.modify(i, B[i]);   //i位置修改为B[i],表示这个数出现了多少次
            }
            List<Integer> res = new ArrayList<>();
            for (int i : queries) {
                if (i == 0) {
                    res.add(0); //没有数比0小,都是正数
                } else {
                    res.add(tree.querySum(0, i - 1));
                } 
            }
            return res;
        }
    }

    class SegmentTreeNode{
        public int sum;
        public  int start, end;
        public  SegmentTreeNode left, right;

        public SegmentTreeNode(int start, int end) {
            this.start = start;
            this.end = end;
            sum = 0;
            left = right = null;
        }
    }

    class SegmentTree{
        private int size;
        private SegmentTreeNode root;

        public SegmentTree(int size) {
            this.size = size;
            root = buildTree(0, size - 1);
        }

        //初始化得到的是全0的树
        private SegmentTreeNode buildTree(int start, int end) {
            SegmentTreeNode node = new SegmentTreeNode(start, end);
            if (start == end) {
                return node;
            }
            int mid = (start + end) / 2;
            node.left = buildTree(start, mid);
            node.right = buildTree(mid + 1, end);
            return node;
        }

        public int querySum(int start, int end) {
            return querySum(root, start, end);
        }

        //在node节点的子树下,查询[start,end]区间内维护的和
        public int querySum(SegmentTreeNode node, int start, int end) {
            if (node.start == start && node.end == end) {
                return node.sum;
            }
            int leftSum = 0, rightSum = 0;
            int mid = (node.start + node.end) / 2;
            if (start <= mid) {
                leftSum = querySum(node.left, start, Math.min(end, mid));
            }
            if (end >= mid + 1) {
                rightSum = querySum(node.right, Math.max(start, mid + 1), end);
            }
            return leftSum + rightSum;
        }

        public void modify(int index, int val) {
            modify(root, index, val);
        }

        private void modify(SegmentTreeNode node, int index, int val) {
            if (node.start == node.end && node.end == index) {   //可以省略node.end == index
                node.sum = val;
                return;
            }
            if (node.left.end >= index) {
                modify(node.left, index, val);
            } else {
                modify(node.right, index, val);
            }
            //维护当前节点sum
            node.sum = node.left.sum + node.right.sum;
        }
    }

4、LintCode 249. Count of Smaller Number before itself

统计前面比自己小的数的个数

给定一个整数数组(下标由 0 到 n-1, n 表示数组的规模,取值范围由 0 到10000)。对于数组中的每个 ai 元素,请计算 ai 前的数中比它小的元素的数量。

输入:
[1,2,7,8,5]
输出:
[0,1,2,3,2]

时间复杂度:

  • 暴力,O(n2)
  • 树状数组/线段树,O(nlogk),n为数组长度,k为数组最大值

思路:

  • 数组内范围为0~10000,假设数组B,B[i]表示数组A当前元素之前有多少个i(或者说B[i]表示A中有多少个i,只不过它是实时变化的)。查询比x小的数的个数相当于求B的x-1前缀和,B[0]+B[1]+…+B[x-1]

A=[1,2,7,8,5]

B=[0,0,0,0,0,0,0,0,0] 初始,这里B得开9,因为0~8一共9位

B=[0,1,0,0,0,0,0,0,0] B[1]++,统计比A中第二个元素(2)小的个数,B[0]+B[1]

B=[0,1,1,0,0,0,0,0,0] B[2]++,A中第三个元素为7,计算B[0]+…+B[6]

B=[0,1,1,0,0,0,0,1,1] B[7]++,A中第四个元素为8,计算B[0]+…+B[7]

B=[0,1,1,0,0,1,0,1,1] B[8]++,A中第五个元素为5,计算B[0]+…+B[4]

public class Solution {

    public List<Integer> countOfSmallerNumberII(int[] A) {
        List<Integer> res = new ArrayList<>();
        SegmentTree tree = new SegmentTree(10001);
        int[] B = new int[10001];
        for (int i : A) {
            if (i == 0) {
                res.add(0);
            } else {
                res.add(tree.querySum(0, i - 1));
            }
            //更新B
            B[i]++;
            tree.modify(i, B[i]);
        }
        return res;
    }
}

class SegmentTreeNode{
    public int sum;
    public  int start, end;
    public SegmentTreeNode left, right;

    public SegmentTreeNode(int start, int end) {
        this.start = start;
        this.end = end;
        sum = 0;
        left = right = null;
    }
}

class SegmentTree{
    private int size;
    private SegmentTreeNode root;

    public SegmentTree(int size) {
        this.size = size;
        root = buildTree(0, size - 1);
    }

    //初始化得到的是全0的树
    private SegmentTreeNode buildTree(int start, int end) {
        SegmentTreeNode node = new SegmentTreeNode(start, end);
        if (start == end) {
            return node;
        }
        int mid = (start + end) / 2;
        node.left = buildTree(start, mid);
        node.right = buildTree(mid + 1, end);
        return node;
    }

    public int querySum(int start, int end) {
        return querySum(root, start, end);
    }

    //在node节点的子树下,查询[start,end]区间内维护的和
    public int querySum(SegmentTreeNode node, int start, int end) {
        if (node.start == start && node.end == end) {
            return node.sum;
        }
        int leftSum = 0, rightSum = 0;
        int mid = (node.start + node.end) / 2;
        if (start <= mid) {
            leftSum = querySum(node.left, start, Math.min(end, mid));
        }
        if (end >= mid + 1) {
            rightSum = querySum(node.right, Math.max(start, mid + 1), end);
        }
        return leftSum + rightSum;
    }

    public void modify(int index, int val) {
        modify(root, index, val);
    }

    private void modify(SegmentTreeNode node, int index, int val) {
        if (node.start == node.end && node.end == index) {   //可以省略node.end == index
            node.sum = val;
            return;
        }
        if (node.left.end >= index) {
            modify(node.left, index, val);
        } else {
            modify(node.right, index, val);
        }
        //维护当前节点sum
        node.sum = node.left.sum + node.right.sum;
    }
}

你可能感兴趣的:(算法)