对于一个给定的概率集合,我们希望构造一棵期望搜索代价最小的二叉搜索树,这样的树称为最优二叉搜索树。
最近复习了一下算法导论的动态规划,发现最优二叉搜索树的实现方法与矩阵链连乘问题极其相似。
根据动态规划自底向上(bottom-up)的思想,我们从子序列长度 length = 1 的情况逐步求解到 length = n 的情况。对于每一个长度 i,起始位置 j 从 1 开始,j + i - 1 即为当前子序列的结束下标 end,j 的取值范围为 1 -> n - i + 1(因为长度每增加 1,可遍历的子序列数量就会减少 1)。打个比方,如果二叉搜索树的关键字序列为<1,4,5,6>,第一次遍历(长度为 1)会分别求出 cost[1,1]、cost[2,2]、cost[3,3] 和 cost[4,4];第二次遍历(长度为 2)仍然从 index = 1 开始,但此时 end index = 1 + 2 - 1 = 2,也就是第二次遍历会求出 cost[1,2]、cost[2,3] 和 cost[3,4]……按照这种遍历方式,最终求出 cost[1,n],其中 n 是关键字序列的长度。同时可以使用一个 n * n 的数组 root 记录由关键字 i -> j 构成的最优子树的根节点,以便之后重建整棵树。
/**
 * Optimal binary search tree (as in CLRS "Introduction to Algorithms"), solved three ways:
 * plain top-down recursion, memoized recursion, and bottom-up dynamic programming.
 *
 * <p>Index conventions shared by every method: {@code realNodes[k]} (1 <= k <= n) is the
 * probability of searching for real key k, with {@code realNodes[0]} an unused placeholder;
 * {@code fakeNodes[k]} (0 <= k <= n) is the probability of the dummy key (leaf) d_k.
 */
public class OptimalBinarySearchTree {
    /** Real-key probabilities; index 0 is a placeholder so that key k lives at index k. */
    private static final double[] intlNodes = {0, 0.15, 0.1, 0.05, 0.1, 0.2};
    /** Dummy-key (leaf) probabilities d_0..d_5. */
    private static final double[] leafNodes = {0.05, 0.1, 0.05, 0.05, 0.05, 0.1};
    /** {@code intlNodes.length}: the number of real keys plus one, because of the placeholder. */
    private static final int length = intlNodes.length;

    /**
     * Brute-force, top-down recursion. Exponential time; kept as a reference implementation.
     *
     * @param i first key index of the range (1-based)
     * @param j last key index of the range; {@code j == i - 1} denotes an empty range
     * @param realNodes real-key probabilities
     * @param fakeNodes dummy-key probabilities
     * @return the minimum expected search cost e[i][j] of keys i..j
     */
    private static double solution(int i, int j, double[] realNodes, double[] fakeNodes) {
        // Trivial case: the subtree holds no real key, only the dummy key d_j.
        if (j == i - 1) {
            return fakeNodes[j];
        }
        // Hanging everything under a new root pushes every node one level deeper,
        // which adds exactly w(i, j) to the cost of the two subtrees.
        double weight = rangeWeight(i, j, realNodes, fakeNodes);
        double minExpect = Double.MAX_VALUE;
        for (int r = i; r <= j; r++) {
            minExpect = Math.min(minExpect, solution(i, r - 1, realNodes, fakeNodes)
                    + solution(r + 1, j, realNodes, fakeNodes) + weight);
        }
        return minExpect;
    }

    /**
     * Top-down recursion with memoization.
     *
     * @param i first key index of the range (1-based)
     * @param j last key index of the range; {@code j == i - 1} denotes an empty range
     * @param realNodes real-key probabilities
     * @param fakeNodes dummy-key probabilities
     * @param memo cache indexed as {@code memo[i][j]}; needs at least n + 2 rows of n + 1 entries
     *     for a top-level call on keys 1..n. Entries {@code <= 0} (e.g. a fill of -1) are treated
     *     as "not computed yet".
     * @return the minimum expected search cost e[i][j] of keys i..j
     */
    private static double solutionWithMemo(int i, int j, double[] realNodes, double[] fakeNodes, double[][] memo) {
        // Costs are sums of probabilities, so a positive entry is a cached result.
        // NOTE: a subproblem whose true cost is exactly 0 (all probabilities zero) is never
        // recognised as cached and is recomputed on every visit.
        if (memo[i][j] > 0) {
            return memo[i][j];
        }
        if (j == i - 1) {
            return fakeNodes[j];
        }
        double weight = rangeWeight(i, j, realNodes, fakeNodes);
        double cost = Double.MAX_VALUE;
        for (int r = i; r <= j; r++) {
            cost = Math.min(cost, solutionWithMemo(i, r - 1, realNodes, fakeNodes, memo)
                    + solutionWithMemo(r + 1, j, realNodes, fakeNodes, memo) + weight);
        }
        memo[i][j] = cost;
        return cost;
    }

    /**
     * Returns w(i, j): the total probability of keys i..j and dummy keys d_{i-1}..d_j.
     * Summation order is q[i-1], then p[k], q[k] for k = i..j.
     */
    private static double rangeWeight(int i, int j, double[] realNodes, double[] fakeNodes) {
        double weight = fakeNodes[i - 1];
        for (int k = i; k <= j; k++) {
            weight += realNodes[k];
            weight += fakeNodes[k];
        }
        return weight;
    }

    /**
     * Bottom-up dynamic programming, O(n^3) time.
     *
     * @param realNodes real-key probabilities, at least {@code length + 1} entries
     * @param fakeNodes dummy-key probabilities, at least {@code length + 1} entries
     * @param length number of real keys n (note: shadows the static field, which is n + 1)
     * @return the minimum expected search cost e[1][n]
     * @throws IllegalArgumentException if {@code length} is negative or the arrays are too short
     */
    private static double dpSolution(double[] realNodes, double[] fakeNodes, int length) {
        return dpSolution(realNodes, fakeNodes, length, null);
    }

    /**
     * Bottom-up dynamic programming that can also record the optimal subtree roots.
     *
     * @param realNodes real-key probabilities, at least {@code length + 1} entries
     * @param fakeNodes dummy-key probabilities, at least {@code length + 1} entries
     * @param length number of real keys n (note: shadows the static field, which is n + 1)
     * @param root optional output table (may be {@code null}). When non-null, {@code root[i][j]}
     *     receives the key chosen as root of the optimal subtree over keys i..j, for
     *     1 <= i <= j <= n. It must have at least n + 1 rows of at least n + 1 entries. On ties
     *     the smallest such key is recorded.
     * @return the minimum expected search cost e[1][n]
     * @throws IllegalArgumentException if {@code length} is negative, the probability arrays are
     *     too short, or {@code root} has too few rows
     */
    private static double dpSolution(double[] realNodes, double[] fakeNodes, int length, int[][] root) {
        if (length < 0 || realNodes.length <= length || fakeNodes.length <= length) {
            throw new IllegalArgumentException("need 0 <= length < realNodes.length and length < fakeNodes.length, got length="
                    + length + ", realNodes.length=" + realNodes.length + ", fakeNodes.length=" + fakeNodes.length);
        }
        if (root != null && root.length <= length) {
            throw new IllegalArgumentException("root needs at least " + (length + 1) + " rows, got " + root.length);
        }
        // costE[i][j] = e[i][j]. Row length + 1 exists because of the empty range e[n + 1][n] = q_n.
        double[][] costE = new double[length + 2][length + 1];
        // weight[i][j] = w(i, j): probabilities of keys i..j plus dummy keys (i - 1)..j.
        double[][] weight = new double[length + 2][length + 1];
        for (int i = 1; i <= length + 1; i++) {
            costE[i][i - 1] = fakeNodes[i - 1];
            weight[i][i - 1] = fakeNodes[i - 1];
        }
        // Solve ranges in increasing size so every sub-range is already known.
        for (int size = 1; size <= length; size++) {
            for (int start = 1; start <= length - size + 1; start++) {
                int end = start + size - 1;
                weight[start][end] = weight[start][end - 1] + realNodes[end] + fakeNodes[end];
                double best = Double.MAX_VALUE;
                int bestRoot = start;
                for (int r = start; r <= end; r++) {
                    double candidate = costE[start][r - 1] + costE[r + 1][end] + weight[start][end];
                    if (candidate < best) {
                        best = candidate;
                        bestRoot = r;
                    }
                }
                costE[start][end] = best;
                if (root != null) {
                    root[start][end] = bestRoot;
                }
            }
        }
        return costE[1][length];
    }
}
用一组数据进行测试(数据取自《算法导论》15.5 节的示例,三种解法输出的最小期望搜索代价都应约为 2.75):
/**
 * Demo driver: prints the minimum expected search cost of the sample data three times,
 * computed in order by the brute-force recursion, the bottom-up DP, and the memoized recursion.
 */
public static void main(String[] args) {
    // Key indices run 1..lastKey because intlNodes[0] is a placeholder.
    int lastKey = intlNodes.length - 1;
    System.out.println(solution(1, lastKey, intlNodes, leafNodes));
    System.out.println(dpSolution(intlNodes, leafNodes, lastKey));

    int memoSize = length + 1;
    double[][] memo = new double[memoSize][memoSize];
    // -1 is not a positive cost, so solutionWithMemo treats these cells as not computed yet.
    for (double[] row : memo) {
        Arrays.fill(row, -1);
    }
    System.out.println(solutionWithMemo(1, lastKey, intlNodes, leafNodes, memo));
}