最优二叉搜索树算法 java实现

最优二叉搜索树

定义

数据集: S = [ x 1 ,   x 2 ,   …   ,   x n ] S=[x_1, ~x_2,~\dots~, ~x_n] S=[x1, x2,  , xn]

最优二叉搜索树算法 java实现_第1张图片

截图来自:北大公开课 算法设计与分析 最优二叉搜索树算法 ,下同

存取概率分布: P = [ a 0 ,   b 1 ,   a 1 ,   b 2 ,   …   ,   a i ,   b i + 1 ,   …   ,   b n ,   a n ] P=[a_0, ~b_1, ~a_1, ~b_2, ~\dots~, ~a_i, ~b_{i+1}, ~\dots~, ~b_n, ~a_n] P=[a0, b1, a1, b2,  , ai, bi+1,  , bn, an]

其中 a i a_i ai是结点匹配到的概率, b i b_i bi是结点不被匹配到的概率

结点 x i x_i xi在T中的深度是 d ( x i ) ,   i = 1 , 2 , … , n d(x_i), ~i=1,2,\dots,n d(xi), i=1,2,,n

空隙 L i L_i Li的深度为 d ( L j ) ,   j = 0 , 1 , … , n d(L_j), ~j=0,1,\dots,n d(Lj), j=0,1,,n

树的深度从0开始计算,根结点的深度为0

对于结点 x i x_i xi,其比较次数为深度+1(即 d ( x i ) + 1 d(x_i)+1 d(xi)+1),而空隙 L i L_i Li的比较次数恰好为深度的值 d ( L j ) d(L_j) d(Lj)

平均比较次数 t t t
t = ∑ i = 1 n b i ( d ( x i ) + 1 ) + ∑ j = 0 n a j d ( L j ) t=\sum^n_{i=1}b_i(d(x_i)+1)+\sum^n_{j=0}a_jd(L_j) t=i=1nbi(d(xi)+1)+j=0najd(Lj)


构建动态规划

步骤
  1. 子问题边界限定
  2. 如何将该问题归结为更小的子问题
  3. 优化函数的递推方程及初始值
  4. 计算顺序
  5. 是否需要标记函数
  6. 时间复杂度

子问题划分

子问题边界: ( i , j ) (i,j) (i,j)

数据集: S [ i , j ] = [ x i ,   x i + 1 ,   …   ,   x j ] S[i,j]=[x_i,~x_{i+1},~\dots~,~x_j] S[i,j]=[xi, xi+1,  , xj]

存储概率分布: P [ i , j ] = [ a i − 1 ,   b i ,   a i ,   b i + 1 ,   … ,   b j ,   a j ] P[i,j]=[a_{i-1},\ b_i,\ a_i,\ b_{i+1},\ \dots,\ b_j,\ a_j] P[i,j]=[ai1, bi, ai, bi+1, , bj, aj]

对一棵树进行划分,假设以 x i x_i xi作为根结点,则其左右子树是两个子问题,子问题的划分分别为:

  1. 左子树: S [ i ,   k − 1 ] ,   P [ i ,   k − 1 ] S[i,\ k-1],\ P[i,\ k-1] S[i, k1], P[i, k1]
  2. 右子树: S [ k + 1 ,   j ] ,   P [ k + 1 ,   j ] S[k+1,\ j],\ P[k+1,\ j] S[k+1, j], P[k+1, j]

举例:

最优二叉搜索树算法 java实现_第2张图片


子问题的概率之和
子问题界定为 S [ i , j ] S[i,j] S[i,j] P [ i , j ] P[i,j] P[i,j],令 w [ i   j ] = ∑ p = i − 1 j a p + ∑ q = l j b q w[i\ j]=\displaystyle\sum^j_{p=i-1}a_p + \sum^j_{q=l} b_q w[i j]=p=i1jap+q=ljbq P [ i , j ] P[i,j] P[i,j]中所有概率(数据和空隙)之和


优化函数的递推方程

m [ i , j ] m[i,j] m[i,j]是相对于输入 S [ i , j ] S[i,j] S[i,j] P [ i , j ] P[i,j] P[i,j]的最优二叉搜索树的最少平均比较次数

递推方程:
m [ i ,   j ] = min ⁡ a ≤ k ≤ j { m [ i ,   k − 1 ] + m [ k + 1 ,   j ] + w [ i , j ] } m[i,\ j] = \min \limits_{a\le k \le j}\{ m[i,\ k-1] + m[k+1,\ j] + w[i,j] \} m[i, j]=akjmin{m[i, k1]+m[k+1, j]+w[i,j]}
即:左边子问题的最优比较次数,加上右边子问题的最优比较次数,再加上 i i i j j j的所有概率的总和 w [ i , j ] w[i,j] w[i,j]

最优二叉搜索树算法 java实现_第3张图片

为什么要加上 w [ i , j ] w[i, j] w[i,j]?因为在对左右子问题进行求解后,并上新的根结点 x k x_k xk时,左右子树的结点的深度都会增加1(如右图)。而此时,对子树的每个结点的影响,就是每个子树的比较次数都会增加1,那么就相当于概率多乘了一次,所以后面都加上了对应的 w w w。而每个结点的概率都多加了 w w w,再加上新增的根节点的概率 b k b_k bk,那么,合起来就相当于加了一个 w [ i , j ] w[i, j] w[i,j]

具体原因和公式推导可见下方


方程分解:

m [ i , j ] k m[i,j]_k m[i,j]k以第 k k k个结点为根的相对于输入 S [ i , j ] S[i,j] S[i,j] P [ i , j ] P[i,j] P[i,j]的平均比较次数
m [ i ,   j ] k = m [ i ,   k − 1 ] + m [ k + 1 ,   j ] + w [ i , j ] m[i,\ j]_k = m[i,\ k-1] + m[k+1,\ j] + w[i,j] m[i, j]k=m[i, k1]+m[k+1, j]+w[i,j]
总的最优平均比较次数为:
m [ i ,   j ] = min ⁡ { m [ i ,   j ] k   ∣   i ≤ k ≤ j } m[i,\ j]=\min\{ m[i,\ j]_k \ | \ i \le k \le j\} m[i, j]=min{m[i, j]k  ikj}

也就是说,要求总的最优平均比较次数,需要将这一层的每一个结点都拿去当根结点试一试


递推方程公式证明:

假设令第 k k k个结点为根结点,其平均比较次数为:

m [ i ,   j ] k = ( m [ i ,   k − 1 ] + w [ i ,   k − 1 ] ) + ( m [ k + 1 ,   j ] + w [ k + 1 ,   j ] ) + 1 × b k = ( m [ i ,   k − 1 ] + m [ k + 1 ,   j ] ) + ( w [ i ,   k − 1 ] + b k + w [ k + 1 ,   j ] ) = m [ i ,   k − 1 ] + m [ k + 1 ,   j ] + w [ i ,   j ] \begin{aligned} m[i,\ j]_k &= (m[i,\ k-1] + w[i,\ k-1]) + (m[k+1,\ j] + w[k+1,\ j]) + 1\times b_k \\ &= (m[i,\ k-1] + m[k+1,\ j]) + (w[i,\ k-1]+b_k+w[k+1,\ j]) \\ &= m[i,\ k-1]+m[k+1,\ j] + w[i,\ j] \end{aligned} m[i, j]k=(m[i, k1]+w[i, k1])+(m[k+1, j]+w[k+1, j])+1×bk=(m[i, k1]+m[k+1, j])+(w[i, k1]+bk+w[k+1, j])=m[i, k1]+m[k+1, j]+w[i, j]

注意第一行中, ( m [ i , k − 1 ] + w [ i , k − 1 ] ) + ( m [ k + 1 , j ] + w [ k + 1 , j ] ) (m[i, k-1] + w[i, k-1]) + (m[k+1, j] + w[k+1, j]) (m[i,k1]+w[i,k1])+(m[k+1,j]+w[k+1,j])都加了 w w w,这是由于左右子树的深度都增加了1带来的影响。而最后的 1 × b k 1\times b_k 1×bk就代表新的根结点,乘1是比较次数为1


初始值: m [ i ,   i − 1 ] = 0 m[i,\ i-1]=0 m[i, i1]=0对应子问题为空的情况。

如左子树为空时,对应于 S [ 1 , 0 ] S[1,0] S[1,0] m [ 1 , 0 ] = 0 m[1,0]=0 m[1,0]=0的情况

m [ i ,   j ] m[i,\ j] m[i, j]中, i , j i,j i,j代表一个结点的区间,而在递推公式中, m [ i ,   j ] k = m [ i ,   k − 1 ] + m [ k + 1 ,   j ] + w [ i , j ] m[i,\ j]_k = m[i,\ k-1] + m[k+1,\ j] + w[i,j] m[i, j]k=m[i, k1]+m[k+1, j]+w[i,j]

公式计算实例:北大公开课 算法设计与分析 最优二叉搜索树算法 19:00左右


w的递推公式

w [ i ] [ j ] = ∑ p = i j a [ p ] + ∑ q = i j b [ q ] w[i][j] = \sum^j_{p=i}a[p] + \sum^j_{q=i}b[q] w[i][j]=p=ija[p]+q=ijb[q]

用于编程时,可采用:
w [ i ] [ j ] = { a [ j ] , i = j + 1 w [ i ] [ j − 1 ] + b [ j ] + a [ j ] , i ≤ j w[i][j] = \begin{cases} a[j], & i=j+1 \\ w[i][j-1] + b[j] + a[j], & i \le j \end{cases} w[i][j]={a[j],w[i][j1]+b[j]+a[j],i=j+1ij
其中,当 i = j i=j i=j时, w [ i ] [ j − 1 ] = 0 w[i][j-1]=0 w[i][j1]=0,因此有 w [ i ] [ j ] = a [ j ] + b [ j ] w[i][j] = a[j]+b[j] w[i][j]=a[j]+b[j]


时间复杂度估计

i , j i,j i,j的所有组合有 O ( n 2 ) O(n^2) O(n2)种,每种要对不同的 k k k进行计算, k = O ( n ) k=O(n) k=O(n),每次计算为常数时间

时间复杂度: T ( n ) = O ( n 3 ) T(n)=O(n^3) T(n)=O(n3)

空间复杂度: S ( n ) = O ( n 2 ) S(n)=O(n^2) S(n)=O(n2),为 m m m数组的大小


代码实现

/**
 * @author 寒洲
 * @description 参考视频:https://www.bilibili.com/video/BV1Ls411W7PB?p=46
 * @date 2021-05-20
 */
public class OptimalBinarySearchTree {

    /**
     * @param a 结点不被匹配的概率,下标为 [0, n),n为数据量(结点数)
     * @param b 结点被匹配的概率,下标为 [1, n]
     * @param m i到j的最优平均检索次数,也就是平均最优搜索值
     * @param s 保存i到j的最优根节点,就是确定i到j把哪一个结点作为根结点最优
     * @param w 区间[i, j]之间所有结点和空隙的命中概率。公式为:w[i][j] = a[i-1]+b[i]+...+b[j]+a[j]
     */
    public static void optimalBinarySearchTree(float[] a, float[] b, float[][] m, int[][] s, float[][] w) {
        // 数据个数,即结点个数
        int n = a.length - 1;
        // 赋初始值
        for (int i = 0; i <= n; i++) {
            /*
            这个初始值之后会被用到。
            这里只有a[i],是因为大多数情况下a都会比b多一个元素,
            比如w[i][j] = a[i-1] + b[i] + a[i] + ... + b[j] + a[j],
            可以看到除了第一个a[i-1],其余都可以在一个for循环中相加得到
            此外,我们也可以发现w[i][j]对应a[i-1],所以下方,下标也是错开的
            */
            w[i + 1][i] = a[i];
            m[i + 1][i] = 0;
        }

        // r表示[i,j]区间的长度
        for (int r = 0; r < n; r++) {
            /*
                下方的for循环中,r表示[i,j]区间的长度,i,j关系对应如下:
                [1,1] [2,2] ... [n,n]
                [1,2] [2,3] ... [n-1,n]
                [1,3] [2,4] ... [n-2,n]
                ...
             */
            for (int i = 1; i <= n - r; i++) {
                int j = i + r;
                /*
                w的递推公式为:w[i][j] = a[i-1]+b[i]+...+b[j]+a[j] = w[i][j-1] + b[j] + a[j]
                当r=0时,
                    j=i,则w[i][i-1]的值为0,其初始值已经在上方给出
                    所以当r=0时,w[i][j]=w[i][i]=a[i-1]+b[i]+a[i]
                当r=1时,
                    w[i][j] = w[i][i+1] = w[i][i] + a[i+1] + b[i+1]
                                        = a[i-1] + b[i] + a[i] +  b[i+1] + a[i+1]
                    也就是在w[i][i]的基础上加上新增的a和b,
                    这样子就把[i,i+1]这个区间的概率都加进来了
                 */
                w[i][j] = w[i][j - 1] + a[j] + b[j];
                // 这里先计算出一个值来,后续可以用于找出最小值。不然的话,数组元素默认值为0,不好比较
                m[i][j] = m[i][i - 1] + m[i + 1][j];
                s[i][j] = i;

                // 上面的操作是针对[i][j]的,下面的操作都是发送在[i, j]区间内的

                // k表示令第k个结点作为根节点。
                // 在for循环中,要求以k为根节点的最优比较次数
                for (int k = i + 1; k <= j; k++) {
                    /*
                    m[i][j] = min{ m[i][k-1] + m[k+1][j] + w[i][j] } (i <= k <= j)
                    每次都要+w[i][j]是因为:
                        因为在对左右子问题进行求解后,并上新的根结点x_k时,
                        左右子树的结点的深度都会增加1。而此时,对子树的每个结点的影响,
                        就是每个子树的比较次数都会增加1,那么就相当于概率加了一次。
                        而每个结点的概率都加1,再加上新增的根节点的概率,合起来就相当于加了一个w[i][j]

                    上式又可推导出
                    m[i][j] = w[i][j] + min(m[i][k-1] + m[k+1][j])  (1 <= k <= j)
                     */
                    // 下方先求出min(m[i][k-1] + m[k+1][j]),得到一个min值
                    float min = m[i][k - 1] + m[k + 1][j];
                    if (min < m[i][j]) {
                        m[i][j] = min;
                        // 保存最优的根节点方案
                        s[i][j] = k;
                    }
                    // min值再加上w,得到m[i][j]
                    m[i][j] += w[i][j];
                }
            }
        }
    }

    public static void optimalBinarySearchTree(float[] a, float[] b) {
        if (a.length == b.length + 1) {
            float[] b2 = new float[b.length + 1];
            System.arraycopy(b, 0, b2, 1, b.length);
            b = b2;
        }
        int n = b.length;
        float[][] m = new float[n + 1][n + 1];
        int[][] s = new int[n + 1][n + 1];
        float[][] w = new float[n + 1][n + 1];
        optimalBinarySearchTree(a, b, m, s, w);
        display(m);
        display(s);
        display(w);
    }


    public static void main(String[] args) {
        float[] b = {0.1F, 0.3F, 0.1F, 0.2F, 0.1F};
        float[] a = {0.04F, 0.02F, 0.02F, 0.05F, 0.06F, 0.01F};
        optimalBinarySearchTree(a, b);
    }

    public static <T> void display(float[][] arr) {
        for (float[] ts : arr) {
            for (float t : ts) {
                System.out.print(t + "\t");
            }
            System.out.println();
        }
        System.out.println();
    }

    public static <T> void display(int[][] arr) {
        for (int[] ts : arr) {
            for (int t : ts) {
                System.out.print(t + " ");
            }
            System.out.println();
        }
        System.out.println();
    }
}


你可能感兴趣的:(算法题解,算法,算法导论,动态规划,二叉树,最优二叉搜索树)