树状数组Binary Indexed Trees详解与Java实现

WiKi

树状数组是由Peter Fenwick在1994年提出的,所以又称为Fenwick Tree。数组的区间求和的复杂度是O(n),树状数组可以将数组区间求和的复杂度降低到O(lg n)。这对于长数组的高频率区间求和的应用场景来讲,可以提高效率。

参考

树状数组(Binary Indexed Trees)
搞懂树状数组

详解

这里从上面的参考中总结我的思路。
树状数组通过树形结构对原始数组进行预处理,树的每个节点存储了原始数组的某些区间和,这样,原始数组的区间和就可以通过树形结构将算法复杂度降低到O(lg n)。
例如,下图a数组是原始数组,而c数组就是对a数组进行预处理后的树行结构。
树状数组Binary Indexed Trees详解与Java实现_第1张图片

从a数组如何获得c数组呢?这里要借助lowBit(int K)函数。lowBit(int K)函数保留k的二进制最低位1的值。例如,1110保留最低位1即0010。

private static int lowBit(int k){
        return k&-k;
    }

从a数组和c数组满足下面的累加公式。即,c[i]等于从a数组从i开始向左累加lowBit(i)个值。
树状数组Binary Indexed Trees详解与Java实现_第2张图片
例如:
c[1]等于a数组从1开始向左累加lowBit(1)=1个,即c[1]=a[1];
c[2]等于a数组从2开始向左累加lowBit(2)=2个,即c[2]=a[2]+a[1];
……
c[7]等于a数组从2开始向左累加lowBit(7)=1个,即c[2]=a[7];
c[8]等于a数组从2开始向左累加lowBit(8)=8个,即c[2]=a[8]+a[7]+…a[1];
那么,构造好了c数组,求和就变得简单了。
这里写图片描述
直到求和下标为1即累加到了a[1]。
树状数组Binary Indexed Trees详解与Java实现_第3张图片
代码实现:

/**
     * 计算1~index范围内和
     * index一直减去lowBit(index),直到index为0
     * */
    public int sum(int index){
        if (index<1&&index>length) {
            throw new IllegalArgumentException("Out of Range!");
        }
        int sum=0;
        while (index>0) {
            sum+=tree[index];
            index-=lowBit(index);
        }
        return sum;
    }

到这里,求和的方式已经介绍完毕。
两点注意:
1、实际上并没有一个原始数组a一直存在,否则就多耗费了一倍的存储空间。
2、数组一般从1开始算有效位,0位置为无效位,这样不容易混淆。

Java实现

这里实现了BinaryIndexedTree的类。

/**
 * 树状数组的Java版本,created by 曹艳丰  2016.07.09
 * 原理参考:http://www.hawstein.com/posts/binary-indexed-trees.html
 * 或:https://www.topcoder.com/community/data-science/data-science-tutorials/binary-indexed-trees/
 * */
public class BinaryIndexedTree {
    public int length;
    private int[] tree;
    /**
     * 为了统一下标,所以tree[0]不被使用,数组有效范围1~length。
     * */
    public BinaryIndexedTree(int length){
        this.length=length;
        tree=new int[length+1];
    }
    /**
     * 计算1~index范围内和
     * index一直减去lowBit(index),直到index为0
     * */
    public int sum(int index){
        if (index<1&&index>length) {
            throw new IllegalArgumentException("Out of Range!");
        }
        int sum=0;
        while (index>0) {
            sum+=tree[index];
            index-=lowBit(index);
        }
        return sum;
    }
    /**
     * 计算start~end范围内和
     * */
    public int  sum(int start,int end) {
        return sum(end)-sum(start-1);
    }
    /**
     * index一直加上lowBit(index),直到index为length。这些位置的值都加上value
     * */
    public void put(int index,int value){
        if (index<1&&index>length) {
            throw new IllegalArgumentException("Out of Range!");
        }
        while (index<=length) {
            tree[index]+=value;
            index+=lowBit(index);
        }
    }
    /**
     * index一直减去lowBit(index),直到index为length。这些位置的值都加上value
     * */
    public int get(int index){
        if (index<1&&index>length) {
            throw new IllegalArgumentException("Out of Range!");
        }
        int sum=tree[index];
        int z=index-lowBit(index);
        index--;
        while (index!=z) {
            sum-=tree[index];
            index-=lowBit(index);
        }
        return sum;
    }
    /**
     * 保留k的二进制最低位1的值。例如,1110保留最低位1即0010.
     * */
    private static int lowBit(int k){
        return k&-k;
    }
}

并进行了测试

import java.util.Random;

public class Main {

    public static void main(String[] args) {
        // TODO Auto-generated method stub
        int length=15;
        BinaryIndexedTree bTree=new BinaryIndexedTree(length);
        Random random=new Random();
        //随机放满数据
        for (int i = 1; i <= bTree.length; i++) {
            bTree.put(i, random.nextInt(100));
        }
        //取出每一位
        for (int i = 1; i <= bTree.length; i++) {
            int value=bTree.get(i);
            System.out.printf("%3d",value);
            System.out.print("  ");
        }
        System.out.println();
        //计算0~i的和
        for (int i = 1; i <= bTree.length; i++) {
            int sum=bTree.sum(i);

            System.out.printf("%3d",sum);
            System.out.print("  ");
        }
        System.out.println();
        //计算start~end的和
        System.out.printf("%3d",bTree.sum(2,4));
    }

}

你可能感兴趣的:(《算法导论》学习)