数据结构--线段树

  线段树用于处理区间数据的更新与查询问题,不考虑往区间中增加与删除数据的,主要用于统计数据方面的需求,在更新与查询的时间复杂度都为logn级别。线段树不属于完全二叉树,但属于平衡二叉树。
数据结构--线段树_第1张图片
线段树事例
  • 数组为存储的实现代码如下:

/**
 * 
 * 功能描述:线段树
 * 
 * @version 2.0.0
 * @author zhiminchen
 */
public class SegmentTree {

    // 用于存储线段数的数据
    private E[]       data;
    // 用于存储原始数据
    private E[]       tree;
    // 用于抽象线段树的统计操作
    private Merger merger;

    /**
     * 
     * 功能描述:
     * 
     * @version 2.0.0
     * @author zhiminchen
     */
    static interface Merger {
        public E merge(E left,
                       E right);
    }

    /**
     * 线段树的构造方法,传入的是一个数组与merger对象
     * @param arr
     * @param merger
     */
    public SegmentTree(E[] arr, Merger merger) {
        this.merger = merger;
        data = (E[]) new Object[arr.length];
        // 将arr传入的数据复制到data数组中
        for (int i = 0; i < arr.length; i++) {
            data[i] = arr[i];
        }
        // 这里需要4倍的空间(tree是一个满二叉树的空间,需要考虑只有10个节点的情况)
        tree = (E[]) new Object[arr.length * 4];
        buildSegmentTree(0, 0, data.length - 1);
    }

    /**
     * 
     * 功能描述: 在treeIndex的位置创建表示区间【left, right】的线段树
     * 
     * @param treeIndex
     * @param left
     * @param right void
     * @version 2.0.0
     * @author zhiminchen
     */
    private void buildSegmentTree(int treeIndex,
                                  int left,
                                  int right) {
        // 递归终止条件
        if (left == right) {
            tree[treeIndex] = data[left];
            return;
        }
        // treeIndex的左孩子节点的值
        int leftIndex = getLeft(treeIndex);
        // treeIndex的右孩子节点的值
        int rightIndex = getRight(treeIndex);
        // 找到线段树中间的区间,递归构造区间树的左右孩子节点
        int mid = left + (right - left) / 2;
        //
        buildSegmentTree(leftIndex, left, mid);
        buildSegmentTree(rightIndex, mid + 1, right);
        // 这里贸给具体的业务逻辑处理
        tree[treeIndex] = merger.merge(tree[leftIndex], tree[rightIndex]);
    }

    /**
     * 
     * 功能描述:对线段树中的值进行更新
     * 
     * @param index
     * @param e void
     * @version 2.0.0
     * @author zhiminchen
     */
    public void set(int index,
                    E e) {
        if (index < 0 || index > data.length - 1) {
            throw new IllegalArgumentException("index is illegal");
        }
        data[index] = e;
        set(0, 0, data.length - 1, index, e);
    }

    /**
     * 
     * 功能描述: 在以treeIndex为根的线段树中 更新index的值为e
     * 
     * @param treeIndex
     * @param left
     * @param right
     * @param index
     * @param E void
     * @version 2.0.0
     * @author zhiminchen
     */
    private void set(int treeIndex,
                     int left,
                     int right,
                     int index,
                     E e) {
        // 递归退出条件
        if (left == right) {
            tree[treeIndex] = e;
            return;
        }
        int mid = left + (right - left) / 2;
        // treeIndex的左孩子节点的值
        int leftIndex = getLeft(treeIndex);
        // treeIndex的右孩子节点的值
        int rightIndex = getRight(treeIndex);

        if (index >= mid + 1) { // 更新index在右子树的情况
            set(rightIndex, mid + 1, right, index, e);
        } else {
            set(leftIndex, left, mid, index, e);
        }
        // 更新线段树的值
        tree[treeIndex] = merger.merge(tree[leftIndex], tree[rightIndex]);
    }

    /**
     * 
     * 功能描述: 查询某个区间的值
     * 
     * @param queryLeft
     * @param queryRight
     * @return E
     * @version 2.0.0
     * @author zhiminchen
     */
    public E query(int queryLeft,
                   int queryRight) {
        return query(0, 0, data.length - 1, queryLeft, queryRight);
    }

    /**
     * 
     * 功能描述: 在以treeIndex为根的线段树中区间为【left, right】, 搜索【queryLeft, queryRight】区间的值
     * 
     * @param treeIndex
     * @param left
     * @param right
     * @param queryLeft
     * @param queryRight
     * @return E
     * @version 2.0.0
     * @author zhiminchen
     */
    private E query(int treeIndex,
                    int left,
                    int right,
                    int queryLeft,
                    int queryRight) {
        // 如果搜索的区间正好是treeIndex的根节点,则直接返回
        if (left == queryLeft && right == queryRight) {
            return tree[treeIndex];
        }

        int mid = left + (right - left) / 2;
        int leftTreeIndex = getLeft(treeIndex);
        int rightTreeIndex = getRight(treeIndex);
        // 区间的最小值比mid还大, 则在右子树进行查找。
        if (queryLeft >= mid + 1) {
            return query(rightTreeIndex, mid + 1, right, queryLeft, queryRight);
        } else if (queryRight <= mid) { // 在线段树的左子树查找
            return query(leftTreeIndex, left, mid, queryLeft, queryRight);
        } else {
            // 处理在左右子树都需要查找的情况
            E leftResult = query(leftTreeIndex, left, mid, queryLeft, mid);
            E rightResult = query(rightTreeIndex, mid + 1, right, mid + 1, queryRight);
            return merger.merge(leftResult, rightResult);
        }
    }

    public E get(int index) {
        if (index < 0 || index >= data.length) {
            throw new IllegalArgumentException("index is illegal");
        }
        return data[index];
    }

    /**
     * 
     * 功能描述: (完全二叉树的数据表示)得到左孩子节点的索引
     * 
     * @param index
     * @return int
     * @version 2.0.0
     * @author zhiminchen
     */
    public int getLeft(int index) {
        return 2 * index + 1;
    }

    /**
     * 
     * 功能描述:(完全二叉树的数据表示) 得到右孩子节点
     * 
     * @param index
     * @return int
     * @version 2.0.0
     * @author zhiminchen
     */
    public int getRight(int index) {
        return 2 * index + 2;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        for (int i = 0; i < tree.length; i++) {
            if (tree[i] != null) {
                sb.append(tree[i]);
            } else {
                sb.append("null");
            }

            if (i != tree.length - 1) {
                sb.append(", ");
            }
        }
        sb.append("]");
        return sb.toString();
    }

    public static void main(String[] args) {
        Integer[] data = new Integer[] {
                                         1,
                                         2,
                                         3,
                                         5,
                                         -8,
                                         9,
                                         11,
                                         2,
                                         33,
                                         45,
                                         12
        };

        SegmentTree segTree = new SegmentTree(data, new Merger() {

            public Integer merge(Integer left,
                                 Integer right) {
                return left + right;
            }

        });
        System.out.println(segTree.query(1, 8));
        segTree.set(5, 15);
        System.out.println(segTree.query(1, 8));

    }

}

  • 以树结构存储的实现如下:

/**
 * 
 * 功能描述:采用树的方式实现线段树
 * 
 * @version 2.0.0
 * @author zhiminchen
 */
public class SegmentTree2 {

    // 用于存储原始数据
    private E[]       data;

    // 用于抽象线段树的统计操作
    private Merger merger;

    // 树的根结点
    private Node   root;

    /**
     * 
     * 功能描述: node结点,用于存储数据
     * 
     * @version 2.0.0
     * @author zhiminchen
     */
    private static class Node {
        E       e;
        int     left;     // 左区间起始位置
        int     right;    // 右区间终止位置
        Node leftNode;
        Node rightNode;

        Node(E e, int left, int right) {
            this(e, left, right, null, null);
        }

        Node(E e, int left, int right, Node leftNode, Node rightNode) {
            this.e = e;
            this.left = left;
            this.right = right;
            this.leftNode = leftNode;
            this.rightNode = rightNode;
        }
    }

    public SegmentTree2(E[] arr, Merger merger) {
        this.merger = merger;
        data = (E[]) new Object[arr.length];
        // 将arr传入的数据复制到data数组中
        for (int i = 0; i < arr.length; i++) {
            data[i] = arr[i];
        }
        // 调用构造树的方法
        root = buildSegmentTree(0, data.length - 1);
    }

    /**
     * 
     * 功能描述: 构区间[left,right]的线段树
     * 
     * @param left
     * @param right
     * @return Node
     * @version 2.0.0
     * @author zhiminchen
     */
    private Node buildSegmentTree(int left,
                                     int right) {
        // 递归终止条件
        if (left == right) {
            return new Node(data[left], left, right);
        }
        // 找到线段树中间的区间,递归构造区间树的左右孩子节点
        int mid = left + (right - left) / 2;
        // 构建树的左结点
        Node leftNode = buildSegmentTree(left, mid);
        // 构建树的右结点
        Node rightNode = buildSegmentTree(mid + 1, right);
        // 求结点的值
        E e = merger.merge(leftNode.e, rightNode.e);
        // 返回结点
        Node node = new Node(e, left, right, leftNode, rightNode);
        return node;
    }

    /**
     * 功能描述:对线段树中的值进行更新
     * 
     * @param index
     * @param e void
     * @version 2.0.0
     * @author zhiminchen
     */
    public void set(int index,
                    E e) {
        if (index < 0 || index > data.length - 1) {
            throw new IllegalArgumentException("index is illegal");
        }
        data[index] = e;
        set(root, 0, data.length - 1, index, e);
    }

    /**
     * 
     * 功能描述: 在以treeIndex为根的线段树中 更新index的值为e
     * 
     * @param left
     * @param right
     * @param index
     * @param E void
     * @version 2.0.0
     * @author zhiminchen
     */
    private void set(Node node,
                     int left,
                     int right,
                     int index,
                     E e) {
        // 递归退出条件
        if (left == right && node.left == left && node.right == right) {
            node.e = e;
            return;
        }

        int mid = left + (right - left) / 2;

        if (index >= mid + 1) { // 更新index在右子树的情况
            set(node.rightNode, mid + 1, right, index, e);
        } else { //更新左子树的情况
            set(node.leftNode, left, mid, index, e);
        }
        // 更新线段树的值
        node.e = merger.merge(node.leftNode.e, node.rightNode.e);
    }

    /**
     * 
     * 功能描述:查询【queryLeft, queryRight】区间的值 
     * 
     * @param queryLeft
     * @param queryRight
     * @return E
     * @version 2.0.0
     * @author zhiminchen
     */
    public E query(int queryLeft,
                   int queryRight) {
        return query(root, queryLeft, queryRight);
    }

    /**
     * 
     * 功能描述: 在以treeIndex为根的线段树中区间为【left, right】, 搜索【queryLeft, queryRight】区间的值
     * 
     * @param treeIndex
     * @param left
     * @param right
     * @param queryLeft
     * @param queryRight
     * @return E
     * @version 2.0.0
     * @author zhiminchen
     */
    private E query(Node node,
                    int queryLeft,
                    int queryRight) {
        // 如果搜索的区间正好是treeIndex的根节点,则直接返回
        if (node.left == queryLeft && node.right == queryRight) {
            return node.e;
        }
        int mid = node.left + (node.right - node.left) / 2;
        // 区间的最小值比mid还大, 则在右子树进行查找。
        if (queryLeft >= mid + 1) {
            return query(node.rightNode, queryLeft, queryRight);
        } else if (queryRight <= mid) { // 在线段树的左子树查找
            return query(node.leftNode, queryLeft, queryRight);
        } else {
            // 处理在左右子树都需要查找的情况
            E leftResult = query(node.leftNode, queryLeft, mid);
            E rightResult = query(node.rightNode, mid + 1, queryRight);
            return merger.merge(leftResult, rightResult);
        }
    }

    /**
     * 
     * 功能描述:
     * 
     * @version 2.0.0
     * @author zhiminchen
     */
    static interface Merger {
        public E merge(E left,
                       E right);
    }

    public static void main(String[] args) {
        Integer[] data = new Integer[] {
                                         1,
                                         2,
                                         3,
                                         5,
                                         -8,
                                         9,
                                         11,
                                         2,
                                         33,
                                         45,
                                         12
        };

        SegmentTree2 segTree = new SegmentTree2(data, new Merger() {
            // 用于求合统计
            public Integer merge(Integer left,
                                 Integer right) {
                return left + right;
            }
        });

        System.out.println(segTree.query(1, 8));
        segTree.set(5, 15);
        System.out.println(segTree.query(1, 8));
    }

}

你可能感兴趣的:(数据结构--线段树)