最优二叉搜索树 - Java实现(递归,带备忘录递归,动态规划)

对于一个给定的概率集合,我们希望构造一棵期望搜索代价最小的二叉搜索树,我们称之为最优二叉搜索树

最近复习了一下算法导论的动态规划,发现最优二叉搜索树的实现方法与矩阵链连乘问题极其相似。
根据动态规划bottom -> top的思想,我们应该从length = 1的情况逐渐求解到length = n的情况,对于每一个i = length,我们从起始位置 j = 1开始,加上length - 1作为当前序列的end index,遍历 j = 1 -> n - i + 1(因为每一次长度增长,对应的可遍历的序列数量就会 - 1,打个比方,如果二叉搜索树序列为<1,4,5,6>,第一次遍历时会分别求出cost[1,1],cost[2,2],cost[3,3]和cost[4,4],第二次遍历仍然从index = 1开始,但是end index = 1 + 2 - 1 = 1,也就是第二次遍历会求出cost[1,2],cost[2,3]和cost[3,4]…按照这种遍历方式,最终求出cost[1,n],n是二叉搜索树序列长度,同时使用一个n * n数组root记录从i -> j的节点的父节点)

public class OptimalBinarySearchTree {
    private static final double[] intlNodes = {0, 0.15, 0.1, 0.05, 0.1, 0.2};
    private static final double[] leafNodes = {0.05, 0.1, 0.05, 0.05, 0.05, 0.1};
    private static final int length = intlNodes.length;

    /**
     * This is a brutal solution
     * Top -> Bottom
     * @param i 关键字搜索起点
     * @param j 关键字搜索终点
     * @param realNodes 关键字数组
     * @param fakeNodes 伪关键字数组
     * @return  最小搜索期望
     */
    private static double solution(int i, int j, double[] realNodes, double[] fakeNodes){
        //平凡case,子树不再包含关键字,而只有一个伪关键字,返回当前伪关键字
        if(j == i - 1){
            return fakeNodes[j];
        }
        //否则计算从i -> j的关键字的最小搜索期望
        double minExpect = Double.MAX_VALUE;
        double weight = 0;
        //根据w[i,j] = 包含了i -> j的关键字节点和伪关键字节点的概率和
        weight += fakeNodes[i-1];
        for (int sw = i; sw <= j; sw++) {
            weight += realNodes[sw];
            weight += fakeNodes[sw];
        }
        for (int start = i; start <= j; start++) {
            minExpect = Math.min(minExpect, solution(i, start - 1, realNodes, fakeNodes)
             + solution(start + 1, j, realNodes, fakeNodes) + weight);
        }
        return minExpect;
    }

    /**
     * 带备忘自顶向下
     * @param i 关键字起始位置
     * @param j 关键字结束位置
     * @param realNodes 关键字数组
     * @param fakeNodes 伪关键字数组
     * @param memo  备忘
     * @return  最小搜索代价
     */
    private static double solutionWithMemo(int i, int j, double[] realNodes, double[] fakeNodes, double[][] memo){
        if (memo[i][j] > 0) {
            return memo[i][j];
        }
        if (j == i - 1) {
            return fakeNodes[j];
        }
        double weight = 0;
        weight += fakeNodes[i-1];
        for (int sw = i; sw <= j; sw++) {
            weight += realNodes[sw];
            weight += fakeNodes[sw];
        }

        double cost = Double.MAX_VALUE;
        for (int start = i; start <= j; start++) {
            cost = Math.min(cost, solutionWithMemo(i, start - 1, realNodes, fakeNodes, memo)
             + solutionWithMemo(start + 1, j, realNodes, fakeNodes, memo) + weight);
        }
        memo[i][j] = cost;
        return cost;
    }

    /**
     * dp solution
     * bottom -> top
     * @param realNodes 实际的节点数组
     * @param fakeNodes 伪节点数组
     * @return  最小期望搜索代价
     */
    private static double dpSolution(double[] realNodes, double[] fakeNodes, int length){
        //记录i -> j节点的最小搜索期望,由于包含了伪关键字dn,所以存在costE[length + 1][length]的情况
        double[][] costE = new double[length+2][length+1];
        //记录i -> j的搜索概率,包含了关键字i -> j的概率和以及关键字对应的伪关键字 (i - 1) -> j的概率和
        double[][] weight = new double[length+2][length+1];
        for (int i = 1; i <= length + 1; i++){
            costE[i][i-1] = fakeNodes[i-1];
            weight[i][i-1] = fakeNodes[i-1];
        }

        for (int i = 1; i <= length; i++) {
            //根据计算规模,确定每一计算规模下的计算次数,确定计算的节点起始和终止位置
            //根据已经求得的子问题计算当前最小期望
            int loops = length - i + 1;
            for (int j = 1; j <= loops; j++) {
                int end = j + i - 1;
                double cost = Double.MAX_VALUE;
                weight[j][end] = weight[j][end-1] + realNodes[end] + fakeNodes[end];
                for (int k = j; k <= end; k++){
                    cost = Math.min(cost, costE[j][k-1] + costE[k+1][end] + weight[j][end]);
                }
                costE[j][end] = cost;
            }
        }

        return costE[1][length];
    }
}

做一组数据测试:

    public static void main(String[] args) {
        System.out.println(solution(1, intlNodes.length - 1, intlNodes, leafNodes));
        System.out.println(dpSolution(intlNodes, leafNodes, intlNodes.length - 1));
        double[][] memo = new double[length+1][length+1];
        for (int i = 0; i <= length; i++) {
            Arrays.fill(memo[i], -1);
        }
        System.out.println(solutionWithMemo(1, length - 1, intlNodes, leafNodes, memo));
    }

你可能感兴趣的:(算法)