对于给定区间, 支持更新和查询操作 :
如下图所示数组A, 以求和为例, 根节点A[0-7]存放的就是A[0-3]节点和A[4-7]节点之和, 下面的每个节点存放的值都是该节点对应左右孩子节点的和, 这样就用数组构建出了一个线段树,
如下图所示, 线段树查找步骤如下:
线段树更新的方法也很简单, 更新对应位置的值之后, 包含该位置的区间的值也都要进行更新
线段树完整代码实现如下 :
package tree.segment;
/**
* 使用数组实现线段树
* @author 七夜雪
*
* @param
*/
public class SegmentTree<E> {
private Merger<E> merger;
private E[] tree;
private E[] data;
@SuppressWarnings("unchecked")
public SegmentTree (E[] arr, Merger<E> merger){
this.merger = merger;
// java中无法直接使用new E[arr.length];这种方式创建泛型数组
data = (E[])new Object[arr.length];
for (int i = 0; i < arr.length; i++) {
data[i] = arr[i];
}
// 对于有n个元素的区间, 使用数组实现线段树的话, 需要4n的空间来存储
tree = (E[])new Object[arr.length * 4];
buildSegmentTree(0, 0, data.length - 1);
}
/**
* 在treeIndex的位置, 创建表示区间[l, r]的线段树
* 递归算法
* @param treeIndex
* @param l
* @param r
*/
private void buildSegmentTree(int treeIndex, int l, int r){
// 递归到底的情况
if (l == r) {
tree[treeIndex] = data[l];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l + (r - l) / 2;
buildSegmentTree(leftTreeIndex, l, mid);
buildSegmentTree(rightTreeIndex, mid + 1, r);
// 根据具体场景自定义merge方法
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
/**
* 计算index节点左孩子的位置
* @param index
* @return
*/
private int leftChild(int index){
return 2 * index + 1;
}
/**
* 计算index节点左孩子的位置
* @param index
* @return
*/
private int rightChild(int index){
return 2 * index + 2;
}
/**
* 查询QueryL~QueryR之间的区间
* @param queryL
* @param queryR
* @return
*/
public E query(int queryL, int queryR){
if (queryL < 0 || queryL >=data.hashCode() ||
queryR < 0 || queryR >= data.length ||
queryL > queryR) {
throw new IllegalArgumentException("无效的区间[" + queryL + ", " + queryR + "]");
}
return query(0, 0, data.length - 1 , queryL, queryR);
}
/**
* 从treeIndex节点开始, 在l~r的范围内查找QueryL~QueryR之间的区间
* @param treeIndex
* @param queryL
* @param queryR
* @return
*/
private E query(int treeIndex, int l, int r, int queryL, int queryR){
// 递归终结条件, 左右边界相同时, 表示找到了对应的区间
if (l == queryL && r == queryR) {
return tree[treeIndex];
}
int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 要查找的区间右边界小于mid时, 说明只需要到左子树进行查找即可
if (queryR <= mid) {
return query(leftTreeIndex, l, mid, queryL, queryR);
// 要查找的区间左边界大于mid时, 说明只需要到右子树进行查找即可
} else if (queryL > mid){
return query(rightTreeIndex, mid + 1, r, queryL, queryR);
// queryL <=mid < queryR这种情况需要对左右子树分别进行查找
} else { // queryL <=mid < queryR
return merger.merge(query(leftTreeIndex, l, mid, queryL, mid), query(rightTreeIndex, mid + 1, r, mid + 1, queryR));
}
}
/**
* 更新位置index的值
* @param index
* @param value
*/
public void set(int index, E value){
if(index < 0 || index >= data.length)
throw new IllegalArgumentException("下标越界");
data[index] = value;
set(0, 0, data.length - 1, index, value);
}
/**
* 在以treeIndex为根的线段树中更新index的值为e
* 递归算法
* @param treeIndex
* @param l
* @param r
* @param index
*/
private void set(int treeIndex, int l, int r, int index, E value){
// 递归终止条件
if (l == r) {
tree[treeIndex] = value;
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l + (r - l) / 2;
if (index <= mid) {
set(leftTreeIndex, l, mid, index, value);
} else { // index > mid
set(rightTreeIndex, mid + 1, r, index, value);
}
// 因为所有包含index区间的值都要更新, 所以需要对treeIndex节点进行一次merge操作
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
// size
public int getSize(){
return data.length;
}
// get
public E get(int index){
if (index < 0 || index >=data.length) {
throw new IllegalArgumentException("无效的位置 : " + index);
}
return data[index];
}
@Override
public String toString() {
StringBuilder res = new StringBuilder();
res.append("SegmentTree [");
for (int i = 0; i < tree.length; i++) {
if (tree[i] != null) {
res.append(tree[i]);
} else {
res.append("null");
}
if (i != tree.length -1) {
res.append(", ");
}
}
res.append("]");
return res.toString();
}
}
使用的merger融合器代码如下 :
package tree.segment;
/**
* 融合器
* 用于将两个元素融合成一个元素
* 配合线段树的合并操作使用
* @FunctionalInterface这个注解是jdk8中函数式接口声明, 加不加不影响
* @author 七夜雪
*
*/
@FunctionalInterface
public interface Merger<E> {
E merge(E a, E b);
}
使用Junit进行简单测试的代码如下 :
package tree.segment;
import org.junit.Test ;
public class SegmentTreeTest {
@Test
public void testBuild(){
Integer[] nums = {2, 3, 4, -1 , -2, 3};
// jdk8的lambda表达式写法
SegmentTree<Integer> segment = new SegmentTree<>(nums, (a, b) -> a + b);
System.out.println(segment) ;
System.out.println(segment.query(1, 3)) ;
}
@Test
public void testBuildSet(){
Integer[] nums = {2, 3, 4, -1 , -2, 3};
// jdk8的lambda表达式写法
SegmentTree<Integer> segment = new SegmentTree<>(nums, (a, b) -> a + b);
System.out.println(segment) ;
segment.set(3, 1);
segment.set(4, 2);
System.out.println(segment) ;
}
}