数据集: S = [ x 1 , x 2 , … , x n ] S=[x_1, ~x_2,~\dots~, ~x_n] S=[x1, x2, … , xn]
截图来自:北大公开课 算法设计与分析 最优二叉搜索树算法 ,下同
存取概率分布: P = [ a 0 , b 1 , a 1 , b 2 , … , a i , b i + 1 , … , b n , a n ] P=[a_0, ~b_1, ~a_1, ~b_2, ~\dots~, ~a_i, ~b_{i+1}, ~\dots~, ~b_n, ~a_n] P=[a0, b1, a1, b2, … , ai, bi+1, … , bn, an]
其中 a i a_i ai是结点匹配到的概率, b i b_i bi是结点不被匹配到的概率
结点 x i x_i xi在T中的深度是 d ( x i ) , i = 1 , 2 , … , n d(x_i), ~i=1,2,\dots,n d(xi), i=1,2,…,n
空隙 L i L_i Li的深度为 d ( L j ) , j = 0 , 1 , … , n d(L_j), ~j=0,1,\dots,n d(Lj), j=0,1,…,n
树的深度从0开始计算,根结点的深度为0
对于结点 x i x_i xi,其比较次数为深度+1(即 d ( x i ) + 1 d(x_i)+1 d(xi)+1),而空隙 L i L_i Li的比较次数恰好为深度的值 d ( L j ) d(L_j) d(Lj)
平均比较次数 t t t:
t = ∑ i = 1 n b i ( d ( x i ) + 1 ) + ∑ j = 0 n a j d ( L j ) t=\sum^n_{i=1}b_i(d(x_i)+1)+\sum^n_{j=0}a_jd(L_j) t=i=1∑nbi(d(xi)+1)+j=0∑najd(Lj)
子问题边界: ( i , j ) (i,j) (i,j)
数据集: S [ i , j ] = [ x i , x i + 1 , … , x j ] S[i,j]=[x_i,~x_{i+1},~\dots~,~x_j] S[i,j]=[xi, xi+1, … , xj]
存储概率分布: P [ i , j ] = [ a i − 1 , b i , a i , b i + 1 , … , b j , a j ] P[i,j]=[a_{i-1},\ b_i,\ a_i,\ b_{i+1},\ \dots,\ b_j,\ a_j] P[i,j]=[ai−1, bi, ai, bi+1, …, bj, aj]
对一棵树进行划分,假设以 x i x_i xi作为根结点,则其左右子树是两个子问题,子问题的划分分别为:
举例:
子问题的概率之和:
子问题界定为 S [ i , j ] S[i,j] S[i,j]和 P [ i , j ] P[i,j] P[i,j],令 w [ i j ] = ∑ p = i − 1 j a p + ∑ q = l j b q w[i\ j]=\displaystyle\sum^j_{p=i-1}a_p + \sum^j_{q=l} b_q w[i j]=p=i−1∑jap+q=l∑jbq为 P [ i , j ] P[i,j] P[i,j]中所有概率(数据和空隙)之和
设 m [ i , j ] m[i,j] m[i,j]是相对于输入 S [ i , j ] S[i,j] S[i,j]和 P [ i , j ] P[i,j] P[i,j]的最优二叉搜索树的最少平均比较次数
递推方程:
m [ i , j ] = min a ≤ k ≤ j { m [ i , k − 1 ] + m [ k + 1 , j ] + w [ i , j ] } m[i,\ j] = \min \limits_{a\le k \le j}\{ m[i,\ k-1] + m[k+1,\ j] + w[i,j] \} m[i, j]=a≤k≤jmin{m[i, k−1]+m[k+1, j]+w[i,j]}
即:左边子问题的最优比较次数,加上右边子问题的最优比较次数,再加上 i i i到 j j j的所有概率的总和 w [ i , j ] w[i,j] w[i,j]
为什么要加上 w [ i , j ] w[i, j] w[i,j]?因为在对左右子问题进行求解后,并上新的根结点 x k x_k xk时,左右子树的结点的深度都会增加1(如右图)。而此时,对子树的每个结点的影响,就是每个子树的比较次数都会增加1,那么就相当于概率多乘了一次,所以后面都加上了对应的 w w w。而每个结点的概率都多加了 w w w,再加上新增的根节点的概率 b k b_k bk,那么,合起来就相当于加了一个 w [ i , j ] w[i, j] w[i,j]
具体原因和公式推导可见下方
方程分解:
设 m [ i , j ] k m[i,j]_k m[i,j]k是以第 k k k个结点为根的相对于输入 S [ i , j ] S[i,j] S[i,j]和 P [ i , j ] P[i,j] P[i,j]的平均比较次数
m [ i , j ] k = m [ i , k − 1 ] + m [ k + 1 , j ] + w [ i , j ] m[i,\ j]_k = m[i,\ k-1] + m[k+1,\ j] + w[i,j] m[i, j]k=m[i, k−1]+m[k+1, j]+w[i,j]
总的最优平均比较次数为:
m [ i , j ] = min { m [ i , j ] k ∣ i ≤ k ≤ j } m[i,\ j]=\min\{ m[i,\ j]_k \ | \ i \le k \le j\} m[i, j]=min{m[i, j]k ∣ i≤k≤j}
也就是说,要求总的最优平均比较次数,需要将这一层的每一个结点都拿去当根结点试一试
递推方程公式证明:
假设令第 k k k个结点为根结点,其平均比较次数为:
m [ i , j ] k = ( m [ i , k − 1 ] + w [ i , k − 1 ] ) + ( m [ k + 1 , j ] + w [ k + 1 , j ] ) + 1 × b k = ( m [ i , k − 1 ] + m [ k + 1 , j ] ) + ( w [ i , k − 1 ] + b k + w [ k + 1 , j ] ) = m [ i , k − 1 ] + m [ k + 1 , j ] + w [ i , j ] \begin{aligned} m[i,\ j]_k &= (m[i,\ k-1] + w[i,\ k-1]) + (m[k+1,\ j] + w[k+1,\ j]) + 1\times b_k \\ &= (m[i,\ k-1] + m[k+1,\ j]) + (w[i,\ k-1]+b_k+w[k+1,\ j]) \\ &= m[i,\ k-1]+m[k+1,\ j] + w[i,\ j] \end{aligned} m[i, j]k=(m[i, k−1]+w[i, k−1])+(m[k+1, j]+w[k+1, j])+1×bk=(m[i, k−1]+m[k+1, j])+(w[i, k−1]+bk+w[k+1, j])=m[i, k−1]+m[k+1, j]+w[i, j]
注意第一行中, ( m [ i , k − 1 ] + w [ i , k − 1 ] ) + ( m [ k + 1 , j ] + w [ k + 1 , j ] ) (m[i, k-1] + w[i, k-1]) + (m[k+1, j] + w[k+1, j]) (m[i,k−1]+w[i,k−1])+(m[k+1,j]+w[k+1,j])都加了 w w w,这是由于左右子树的深度都增加了1带来的影响。而最后的 1 × b k 1\times b_k 1×bk就代表新的根结点,乘1是比较次数为1
初始值: m [ i , i − 1 ] = 0 m[i,\ i-1]=0 m[i, i−1]=0对应子问题为空的情况。
如左子树为空时,对应于 S [ 1 , 0 ] S[1,0] S[1,0], m [ 1 , 0 ] = 0 m[1,0]=0 m[1,0]=0的情况
m [ i , j ] m[i,\ j] m[i, j]中, i , j i,j i,j代表一个结点的区间,而在递推公式中, m [ i , j ] k = m [ i , k − 1 ] + m [ k + 1 , j ] + w [ i , j ] m[i,\ j]_k = m[i,\ k-1] + m[k+1,\ j] + w[i,j] m[i, j]k=m[i, k−1]+m[k+1, j]+w[i,j],
公式计算实例:北大公开课 算法设计与分析 最优二叉搜索树算法 19:00左右
w [ i ] [ j ] = ∑ p = i j a [ p ] + ∑ q = i j b [ q ] w[i][j] = \sum^j_{p=i}a[p] + \sum^j_{q=i}b[q] w[i][j]=p=i∑ja[p]+q=i∑jb[q]
用于编程时,可采用:
w [ i ] [ j ] = { a [ j ] , i = j + 1 w [ i ] [ j − 1 ] + b [ j ] + a [ j ] , i ≤ j w[i][j] = \begin{cases} a[j], & i=j+1 \\ w[i][j-1] + b[j] + a[j], & i \le j \end{cases} w[i][j]={a[j],w[i][j−1]+b[j]+a[j],i=j+1i≤j
其中,当 i = j i=j i=j时, w [ i ] [ j − 1 ] = 0 w[i][j-1]=0 w[i][j−1]=0,因此有 w [ i ] [ j ] = a [ j ] + b [ j ] w[i][j] = a[j]+b[j] w[i][j]=a[j]+b[j]
i , j i,j i,j的所有组合有 O ( n 2 ) O(n^2) O(n2)种,每种要对不同的 k k k进行计算, k = O ( n ) k=O(n) k=O(n),每次计算为常数时间
时间复杂度: T ( n ) = O ( n 3 ) T(n)=O(n^3) T(n)=O(n3)
空间复杂度: S ( n ) = O ( n 2 ) S(n)=O(n^2) S(n)=O(n2),为 m m m数组的大小
/**
* @author 寒洲
* @description 参考视频:https://www.bilibili.com/video/BV1Ls411W7PB?p=46
* @date 2021-05-20
*/
public class OptimalBinarySearchTree {
/**
* @param a 结点不被匹配的概率,下标为 [0, n),n为数据量(结点数)
* @param b 结点被匹配的概率,下标为 [1, n]
* @param m i到j的最优平均检索次数,也就是平均最优搜索值
* @param s 保存i到j的最优根节点,就是确定i到j把哪一个结点作为根结点最优
* @param w 区间[i, j]之间所有结点和空隙的命中概率。公式为:w[i][j] = a[i-1]+b[i]+...+b[j]+a[j]
*/
public static void optimalBinarySearchTree(float[] a, float[] b, float[][] m, int[][] s, float[][] w) {
// 数据个数,即结点个数
int n = a.length - 1;
// 赋初始值
for (int i = 0; i <= n; i++) {
/*
这个初始值之后会被用到。
这里只有a[i],是因为大多数情况下a都会比b多一个元素,
比如w[i][j] = a[i-1] + b[i] + a[i] + ... + b[j] + a[j],
可以看到除了第一个a[i-1],其余都可以在一个for循环中相加得到
此外,我们也可以发现w[i][j]对应a[i-1],所以下方,下标也是错开的
*/
w[i + 1][i] = a[i];
m[i + 1][i] = 0;
}
// r表示[i,j]区间的长度
for (int r = 0; r < n; r++) {
/*
下方的for循环中,r表示[i,j]区间的长度,i,j关系对应如下:
[1,1] [2,2] ... [n,n]
[1,2] [2,3] ... [n-1,n]
[1,3] [2,4] ... [n-2,n]
...
*/
for (int i = 1; i <= n - r; i++) {
int j = i + r;
/*
w的递推公式为:w[i][j] = a[i-1]+b[i]+...+b[j]+a[j] = w[i][j-1] + b[j] + a[j]
当r=0时,
j=i,则w[i][i-1]的值为0,其初始值已经在上方给出
所以当r=0时,w[i][j]=w[i][i]=a[i-1]+b[i]+a[i]
当r=1时,
w[i][j] = w[i][i+1] = w[i][i] + a[i+1] + b[i+1]
= a[i-1] + b[i] + a[i] + b[i+1] + a[i+1]
也就是在w[i][i]的基础上加上新增的a和b,
这样子就把[i,i+1]这个区间的概率都加进来了
*/
w[i][j] = w[i][j - 1] + a[j] + b[j];
// 这里先计算出一个值来,后续可以用于找出最小值。不然的话,数组元素默认值为0,不好比较
m[i][j] = m[i][i - 1] + m[i + 1][j];
s[i][j] = i;
// 上面的操作是针对[i][j]的,下面的操作都是发送在[i, j]区间内的
// k表示令第k个结点作为根节点。
// 在for循环中,要求以k为根节点的最优比较次数
for (int k = i + 1; k <= j; k++) {
/*
m[i][j] = min{ m[i][k-1] + m[k+1][j] + w[i][j] } (i <= k <= j)
每次都要+w[i][j]是因为:
因为在对左右子问题进行求解后,并上新的根结点x_k时,
左右子树的结点的深度都会增加1。而此时,对子树的每个结点的影响,
就是每个子树的比较次数都会增加1,那么就相当于概率加了一次。
而每个结点的概率都加1,再加上新增的根节点的概率,合起来就相当于加了一个w[i][j]
上式又可推导出
m[i][j] = w[i][j] + min(m[i][k-1] + m[k+1][j]) (1 <= k <= j)
*/
// 下方先求出min(m[i][k-1] + m[k+1][j]),得到一个min值
float min = m[i][k - 1] + m[k + 1][j];
if (min < m[i][j]) {
m[i][j] = min;
// 保存最优的根节点方案
s[i][j] = k;
}
// min值再加上w,得到m[i][j]
m[i][j] += w[i][j];
}
}
}
}
public static void optimalBinarySearchTree(float[] a, float[] b) {
if (a.length == b.length + 1) {
float[] b2 = new float[b.length + 1];
System.arraycopy(b, 0, b2, 1, b.length);
b = b2;
}
int n = b.length;
float[][] m = new float[n + 1][n + 1];
int[][] s = new int[n + 1][n + 1];
float[][] w = new float[n + 1][n + 1];
optimalBinarySearchTree(a, b, m, s, w);
display(m);
display(s);
display(w);
}
public static void main(String[] args) {
float[] b = {0.1F, 0.3F, 0.1F, 0.2F, 0.1F};
float[] a = {0.04F, 0.02F, 0.02F, 0.05F, 0.06F, 0.01F};
optimalBinarySearchTree(a, b);
}
public static <T> void display(float[][] arr) {
for (float[] ts : arr) {
for (float t : ts) {
System.out.print(t + "\t");
}
System.out.println();
}
System.out.println();
}
public static <T> void display(int[][] arr) {
for (int[] ts : arr) {
for (int t : ts) {
System.out.print(t + " ");
}
System.out.println();
}
System.out.println();
}
}