
线段树(Segment Tree)有时也被称为区间树(Interval Tree),它实际上是一棵二叉树,树中的每一个节点表示一个区间[a, b]:左儿子的区间是[a, (a+b)/2],右儿子的区间是[(a+b)/2+1, b]。







201. Segment Tree Build


/**
 * Definition of SegmentTreeNode:
 * public class SegmentTreeNode {
 *     public int start, end;
 *     public SegmentTreeNode left, right;
 *     public SegmentTreeNode(int start, int end) {
 *         this.start = start;
 *         this.end = end;
 *         this.left = this.right = null;
 *     }
 * }
 */
public class Solution {
     *@param start, end: Denote an segment / interval
     *@return: The root of Segment Tree
    public SegmentTreeNode build(int start, int end) {
        if (start <= end) {
            SegmentTreeNode node = new SegmentTreeNode(start, end);
            if (start == end) {
                return node;
            node.left = build(start, (start + end) / 2);
            node.right = build( (start + end) / 2 + 1, end);
            return node;
        return null;

439. Segment Tree Build II


/**
 * Definition of SegmentTreeNode:
 * public class SegmentTreeNode {
 *     public int start, end, max;
 *     public SegmentTreeNode left, right;
 *     public SegmentTreeNode(int start, int end, int max) {
 *         this.start = start;
 *         this.end = end;
 *         this.max = max;
 *         this.left = this.right = null;
 *     }
 * }
 */
public class Solution {
     *@param A: a list of integer
     *@return: The root of Segment Tree
    public SegmentTreeNode buildHelper(int[] A, int start, int end) {
        if (start <= end) {
            SegmentTreeNode node = new SegmentTreeNode(start, end, A[start]);
            if (start == end) {
                return node;
            node.left = buildHelper(A, start, (start + end) / 2);
            node.right = buildHelper(A, (start + end) / 2 + 1, end);
            node.max = Math.max(node.left.max, node.right.max);
            return node;
        return null;
    public SegmentTreeNode build(int[] A) {
        return buildHelper(A, 0, A.length - 1);

202. Segment Tree Query

对线段树进行查询,要求找到给定闭区间[start, end]内的最大值。

/**
 * Definition of SegmentTreeNode:
 * public class SegmentTreeNode {
 *     public int start, end, max;
 *     public SegmentTreeNode left, right;
 *     public SegmentTreeNode(int start, int end, int max) {
 *         this.start = start;
 *         this.end = end;
 *         this.max = max;
 *         this.left = this.right = null;
 *     }
 * }
 */
public class Solution {
     *@param root, start, end: The root of segment tree and 
     *                         an segment / interval
     *@return: The maximum number in the interval [start, end]
    public int query(SegmentTreeNode root, int start, int end) {
        // 查询区间在当前节点的范围之内
        if (start <= root.start && root.end <= end) {
            return root.max;
        int mid = (root.start + root.end) / 2;
        int ans = Integer.MIN_VALUE;
        // 查询区间和左子树有交集
        if (start <= mid) {
            ans = Math.max(ans, query(root.left, start, end));
        // 查询区间和右子树有交集
        if (end >= mid + 1) {
            ans = Math.max(ans, query(root.right, start, end));
        return ans;

247. Segment Tree Query II


/**
 * Definition of SegmentTreeNode:
 * public class SegmentTreeNode {
 *     public int start, end, count;
 *     public SegmentTreeNode left, right;
 *     public SegmentTreeNode(int start, int end, int count) {
 *         this.start = start;
 *         this.end = end;
 *         this.count = count;
 *         this.left = this.right = null;
 *     }
 * }
 */
public class Solution {
     *@param root, start, end: The root of segment tree and 
     *                         an segment / interval
     *@return: The count number in the interval [start, end]
    public int helper(SegmentTreeNode root, int start, int end) {
        // 查询区间在当前节点的范围之内
        if (start <= root.start && root.end <= end) {
            return root.count;
        int mid = (root.start + root.end) / 2;
        int leftSum = 0, rightSum = 0;
        // 查询区间和左子树有交集
        if (start <= mid) {
            if (end <= mid) { // 包含
                leftSum = query(root.left, start, end);
            } else { // 分裂
                leftSum = query(root.left, start, mid);
        // 查询区间和右子树有交集
        if (end >= mid + 1) {
            if (start >= mid + 1) { // 包含
                rightSum = query(root.right, start, end);
            } else { // 分裂
                rightSum = query(root.right, mid + 1, end);
        return leftSum + rightSum;
    public int query(SegmentTreeNode root, int start, int end) {
        if (root == null || start > root.end || end < root.start) {
            return 0;
        return helper(root, start, end);

203. Segment Tree Modify


/**
 * Definition of SegmentTreeNode:
 * public class SegmentTreeNode {
 *     public int start, end, max;
 *     public SegmentTreeNode left, right;
 *     public SegmentTreeNode(int start, int end, int max) {
 *         this.start = start;
 *         this.end = end;
 *         this.max = max;
 *         this.left = this.right = null;
 *     }
 * }
 */
public class Solution {
     *@param root, index, value: The root of segment tree and 
     *@ change the node's value with [index, index] to the new given value
     *@return: void
    public void modify(SegmentTreeNode root, int index, int value) {
        // Find the leaf node that needs modifying
        if (root.start == index && root.end == index) {
            root.max = value;
        int mid = (root.start + root.end) / 2;
        if (index <= mid) { // target leaf node is in the left
            modify(root.left, index, value);
            root.max = Math.max(root.left.max, root.right.max);
        } else { // target leaf node is in the right
            modify(root.right, index, value);
            root.max = Math.max(root.left.max, root.right.max);

205. Interval Minimum Number


/**
 * Definition of Interval:
 * public class Interval {
 *     int start, end;
 *     Interval(int start, int end) {
 *         this.start = start;
 *         this.end = end;
 *     }
 * }
 */
class SegmentTree {
    public int start, end, min;
    public SegmentTree left, right;
    public SegmentTree(int start, int end, int min) {
        this.start = start;
        this.end = end;
        this.min = min;
        this.left = this.right = null;
public class Solution {
     *@param A, queries: Given an integer array and an query list
     *@return: The result list
    public SegmentTree build(int[] A, int start, int end) {
        if (start <= end) {
            SegmentTree node = new SegmentTree(start, end, A[start]);
            if (start == end) {
                return node;
            node.left = build(A, start, (start + end) / 2);
            node.right = build(A, (start + end) / 2 + 1, end);
            node.min = Math.min(node.left.min, node.right.min);
            return node;
        return null;
    public int query(SegmentTree root, int start, int end) {
        if (start <= root.start && root.end <= end) {
            return root.min;
        int mid = (root.start + root.end) / 2;
        int ans = Integer.MAX_VALUE;
        if (start <= mid) {
            ans = Math.min(ans, query(root.left, start, end));
        if (end > mid) {
            ans = Math.min(ans, query(root.right, start, end));
        return ans;
    public ArrayList intervalMinNumber(int[] A, 
                                                ArrayList queries) {
        ArrayList res = new ArrayList();
        SegmentTree root = build(A, 0, A.length - 1);
        for (Interval in: queries) {
            res.add(query(root, in.start, in.end));
        return res;

206. Interval Sum


/**
 * Definition of Interval:
 * public class Interval {
 *     int start, end;
 *     Interval(int start, int end) {
 *         this.start = start;
 *         this.end = end;
 *     }
 * }
 */
class SegmentTree {
    public int start, end;
    public long sum;
    public SegmentTree left, right;
    public SegmentTree(int start, int end, long sum) {
        this.start = start;
        this.end = end;
        this.sum = sum;
        this.left = this.right = null;
public class Solution {
     *@param A, queries: Given an integer array and an query list
     *@return: The result list
    public SegmentTree build(int[] A, int start, int end) {
        if (start <= end) {
            SegmentTree node = new SegmentTree(start, end, Long.valueOf(A[start]));
            if (start == end) {
                return node;
            node.left = build(A, start, (start + end) / 2);
            node.right = build(A, (start + end) / 2 + 1, end);
            node.sum = node.left.sum + node.right.sum;
            return node;
        return null;
    public long query(SegmentTree root, int start, int end) {
        if (start <= root.start && root.end <= end) {
            return root.sum;
        int mid = (root.start + root.end) / 2;
        long leftSum = 0L, rightSum = 0L;
        if (start <= mid) {
            if (end <= mid) {
                leftSum = query(root.left, start, end);
            } else {
                leftSum = query(root.left, start, mid);
        if (end > mid) {
            if (start <= mid) {
                rightSum = query(root.right, mid + 1, end);
            } else {
                rightSum = query(root.right, start, end);
        return leftSum + rightSum;
    public ArrayList intervalSum(int[] A, 
                                       ArrayList queries) {
        ArrayList res = new ArrayList();
        SegmentTree root = build(A, 0, A.length - 1);
        for (Interval in: queries) {
            res.add(query(root, in.start, in.end));
        return res;

207. Interval Sum II


/** Segment tree node storing the sum of its index range [start, end] as a long. */
class SegmentTree {
    public int start, end;
    public long sum;
    public SegmentTree left, right;

    public SegmentTree(int start, int end, long sum) {
        this.start = start;
        this.end = end;
        this.sum = sum;
        this.left = this.right = null;
    }
}

public class Solution {
    /** Root of the sum tree over the array given to the constructor; null for an empty array. */
    private SegmentTree root;

    /**
     * Builds a sum segment tree for the slice A[start..end] (both ends inclusive).
     *
     * @return the root, or {@code null} when start > end
     */
    public SegmentTree build(int[] A, int start, int end) {
        if (start > end) {
            return null;
        }
        if (start == end) {
            return new SegmentTree(start, end, A[start]); // int widens to long
        }
        // Indices are non-negative, so the unsigned shift is an overflow-safe (start + end) / 2.
        int mid = (start + end) >>> 1;
        SegmentTree left = build(A, start, mid);
        SegmentTree right = build(A, mid + 1, end);
        SegmentTree node = new SegmentTree(start, end, left.sum + right.sum);
        node.left = left;
        node.right = right;
        return node;
    }

    /**
     * Returns the sum over the closed index interval [start, end] within the given subtree.
     *
     * @return the sum, or 0 when root is null or does not overlap [start, end]
     */
    public long queryHelper(SegmentTree root, int start, int end) {
        if (root == null || start > end || start > root.end || end < root.start) {
            return 0L;
        }
        if (start <= root.start && root.end <= end) {
            return root.sum;
        }
        return queryHelper(root.left, start, end) + queryHelper(root.right, start, end);
    }

    /**
     * @param A an integer array
     */
    public Solution(int[] A) {
        root = build(A, 0, A.length - 1);
    }

    /**
     * @param start first index (inclusive)
     * @param end   last index (inclusive)
     * @return the sum of the elements from start to end
     */
    public long query(int start, int end) {
        return queryHelper(root, start, end);
    }

    /**
     * Sets position index to value inside the given subtree and refreshes ancestor sums.
     * The call is a no-op when index lies outside the subtree.
     */
    public void modifyHelper(SegmentTree root, int index, int value) {
        if (root == null || index < root.start || index > root.end) {
            return;
        }
        if (root.start == root.end) {
            root.sum = value;
            return; // leaf reached; it has no children to descend into
        }
        int mid = (root.start + root.end) >>> 1; // same split as build()
        if (index <= mid) {
            modifyHelper(root.left, index, value);
        } else {
            modifyHelper(root.right, index, value);
        }
        root.sum = root.left.sum + root.right.sum;
    }

    /**
     * @param index position to change
     * @param value new value for A[index]
     */
    public void modify(int index, int value) {
        modifyHelper(root, index, value);
    }
}

248. Count of Smaller Number




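但是我还有更快的方法来做这道题:由于每个数字的值都在0~10000之间,完全可以用计数排序(counting sort)的思路来统计。先用一轮循环(O(n))统计每个数字出现的次数,存入计数数组;再一次性求出前缀和,使 res[v] 等于数组中严格小于 v 的元素个数(注意累加必须从下标0开始,否则值为0的元素会被漏掉);最后每个查询直接查表,O(1)。总的时间复杂度是 O(n + m + K),其中 n 是数组长度,m 是查询个数,K = 10001 是取值范围的大小,也就是线性的: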

    public ArrayList countOfSmallerNumber(int[] A, int[] queries) {
        int[] countArr = new int[10001];
        for (int i = 0; i < A.length; i++) {
        int[] res = new int[10001];
        int sum = 0;
        for (int i = 1; i < countArr.length; i++) {
            res[i] = sum;
            sum += countArr[i];
        ArrayList list = new ArrayList();
        for (int i = 0; i < queries.length; i++) {
        return list;

249. Count of Smaller Number before itself


要查询当前数字左边共有多少个数比当前数字小,其实也就是统计下标从0到k-1的数里面有几个是小于A[k]的,相当于对前k个数字变相求和。那么可以先按照值域[0, 10000]建立一棵线段树,所有节点的count初始都记为0。查询的时候就可以边更新边查询了:从左往右扫描数组,对当前数字A[k],先查询线段树中区间[0, A[k]-1]的count之和(A[k]为0时结果直接为0),再把A[k]对应的叶子节点的count自增1。由于每次只会把已经访问过的数字更新进线段树,这样就保证了查询是正确的。

举个栗子,比如数组是[1,2,9,8,5]。那如果对8进行查询,看左边有多少个数字小于8的话,其实就统计区间[0, 7]的count sum。而且要保证8右边的数字不出现在区间内。怎么样做到在query查询的时候,8右边的数字不出现在区间统计里面呢?我们可以通过从左到右扫描数组的方式进行。扫描数组的过程中,只有访问到了该数字,才把对应的线段树更新。这样一来,8左边比它大的数字并不会出现在线段树查询的区间中,故自然不会影响到结果。二来,8右边的比它小的数字由于还没被访问到,所以也不会影响查询结果。这一来二去就保证了用线段树的查询是正确的:

/** Segment tree node counting how many seen values fall in the value range [start, end]. */
class SegmentTree {
    public int start, end;
    public int count;
    public SegmentTree left, right;

    public SegmentTree(int start, int end, int count) {
        this.start = start;
        this.end = end;
        this.count = count;
        this.left = this.right = null;
    }
}

public class Solution {
    /** Largest value the problem allows in A; the tree covers the value range [0, MAX_NUMBER]. */
    private static final int MAX_NUMBER = 10000;

    /**
     * Builds a count tree over the value range [start, end] with every count set to 0.
     *
     * @param A     unused (all counts start at 0); kept for signature compatibility
     * @param start smallest value covered (inclusive)
     * @param end   largest value covered (inclusive)
     * @return the root, or {@code null} when start > end
     */
    public SegmentTree build(int[] A, int start, int end) {
        if (start > end) {
            return null;
        }
        SegmentTree node = new SegmentTree(start, end, 0);
        if (start == end) {
            return node;
        }
        int mid = (start + end) >>> 1; // overflow-safe (start + end) / 2 for non-negative bounds
        node.left = build(A, start, mid);
        node.right = build(A, mid + 1, end);
        return node;
    }

    /**
     * Returns how many recorded values lie in the closed value interval [start, end].
     *
     * @return the count, or 0 when the interval is empty or does not overlap the tree
     */
    public int query(SegmentTree root, int start, int end) {
        if (root == null || start > end || start > root.end || end < root.start) {
            return 0;
        }
        if (start <= root.start && root.end <= end) {
            return root.count;
        }
        return query(root.left, start, end) + query(root.right, start, end);
    }

    /**
     * Adds value to the count of leaf [index, index] and refreshes ancestor counts.
     * The call is a no-op when index lies outside the tree.
     */
    public void modify(SegmentTree root, int index, int value) {
        if (root == null || index < root.start || index > root.end) {
            return;
        }
        if (root.start == root.end) {
            root.count += value;
            return; // leaf reached; it has no children to descend into
        }
        int mid = (root.start + root.end) >>> 1; // same split as build()
        if (index <= mid) {
            modify(root.left, index, value);
        } else {
            modify(root.right, index, value);
        }
        root.count = root.left.count + root.right.count;
    }

    /**
     * For each element, counts how many earlier elements are strictly smaller than it.
     *
     * @param A an integer array whose values lie in [0, 10000]
     * @return a list whose i-th entry is the number of j < i with A[j] < A[i]
     * @throws IllegalArgumentException if a value lies outside [0, 10000]
     */
    public ArrayList<Integer> countOfSmallerNumberII(int[] A) {
        ArrayList<Integer> res = new ArrayList<>(A.length);
        if (A.length == 0) {
            return res;
        }
        SegmentTree root = build(A, 0, MAX_NUMBER);
        for (int value : A) {
            if (value < 0 || value > MAX_NUMBER) {
                throw new IllegalArgumentException(
                        "value out of range [0, " + MAX_NUMBER + "]: " + value);
            }
            // Only values already scanned are in the tree. For value == 0 the range
            // [0, -1] is empty and query() returns 0.
            res.add(query(root, 0, value - 1));
            modify(root, value, 1);
        }
        return res;
    }
}

扫描线则是另外一种类型的区间求解的方法。一般遇到如下的几种类型的题目就适用扫描线Sweep Line来解题。扫描线就是模拟用一根线去扫描区间,然后获取每个时间点的信息来解题。

a. 告诉你一些飞机的起飞和降落时间,问你最多能有几架飞机在同时在天上飞。

b. 给你一些会议的开始和结束时间,问你至少需要多少个会议室才能满足需求。

c. 告诉你火车的起始和到达时间,问你需要多少铁轨才能容纳所有的火车。

391. Number of Airplanes in the Sky

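告诉你飞机的起飞和降落时间,问你最多有多少架飞机同时在天上飞。这是典型的扫描线问题。题目给定的是类似于[2, 4], [1, 10]这样的区间,分别代表起飞和降落时间,但仅仅用区间这样的数据结构是无法直接解题的,得把区间拆分成点。而我们其实只需要记录起点和终点就行了:把所有点按时间排序后依次扫描,遇到起飞点计数加1,遇到降落点计数减1,扫描过程中计数的最大值就是答案。若同一时刻既有降落又有起飞,先处理降落,即认为同一时刻降落和起飞的两架飞机不同时在天上。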


/**
 * Definition of Interval:
 * public class Interval {
 *     int start, end;
 *     Interval(int start, int end) {
 *         this.start = start;
 *         this.end = end;
 *     }
 * }
 */
class Point {
    public int time;
    public int flag;
    public Point(int time, int flag) {
        this.time = time;
        this.flag = flag;
class PointComparator implements Comparator {
    public int compare(Point a, Point b) {
        if (a.time == b.time) {
            return a.flag - b.flag;
        return a.time - b.time;
class Solution {
     * @param intervals: An interval array
     * @return: Count of airplanes are in the sky.
    public int countOfAirplanes(List airplanes) { 
        List list = new ArrayList();
        for (Interval in: airplanes) {
            list.add(new Point(in.start, 1));
            list.add(new Point(in.end, 0));
        Collections.sort(list, new PointComparator());
        int count = 0, res = 0;
        for (Point p: list) {
            if (p.flag == 1) {
            } else {
            res = Math.max(res, count);
        return res;
