Find the subtree with the maximum sum

Problem:
Given a binary tree in which each node holds an integer value (possibly negative), write code to find the subtree whose node values have the maximum sum. Here a subtree means a node together with all of its descendants.

My solution:
This is naturally a recursive problem. For each subtree, remember its total sum and the maximum subtree sum found anywhere inside it, together with the root node of that best subtree. I use a small class, SubTreeResult, to hold all of this information.

Note: the code below has not been thoroughly tested, so do not treat it as a correct solution until you have verified it yourself.

package alg;

import java.util.ArrayList;

/**
 * Finds the subtree of a binary tree whose node values have the largest sum.
 * A "subtree" here is a node together with all of its descendants.
 */
public class MaxSumTree {

    /** A binary-tree node holding an integer value (which may be negative). */
    class Node {
        int value;
        Node left, right;

        Node(int v, Node l, Node r) {
            value = v;
            left = l;
            right = r;
        }
    }

    /**
     * Result of examining one subtree: its root, its total sum, and the best
     * (maximum-sum) subtree found anywhere inside it.
     */
    class SubTreeResult {
        /** Root of the subtree this result describes. */
        Node subTreeRoot;
        /** Root of the maximum-sum subtree found within {@code subTreeRoot}. */
        Node maxSumNode;
        /** Sum of all values in the subtree rooted at {@code subTreeRoot}. */
        int sum;
        /** Sum of the subtree rooted at {@code maxSumNode}. */
        int maxSum;

        SubTreeResult(Node r, Node maxNode, int maxSum, int sum) {
            this.subTreeRoot = r;
            this.maxSumNode = maxNode;
            this.sum = sum;
            this.maxSum = maxSum;
        }
    }

    /**
     * Finds the maximum-sum subtree within the tree rooted at {@code r}.
     *
     * <p>Visits every node once; recursion depth equals the tree height. On ties,
     * the subtree rooted closer to {@code r} wins, and a subtree from the left side
     * wins over one from the right side.
     *
     * @param r root of the tree to search; may be {@code null}
     * @return the result for {@code r}, or {@code null} if {@code r} is {@code null}
     *     (an empty tree has no subtree)
     * @throws ArithmeticException if any subtree sum overflows {@code int}
     */
    public SubTreeResult findMaxSubTree(Node r) {
        if (r == null) {
            return null;
        }
        SubTreeResult leftResult = findMaxSubTree(r.left);
        SubTreeResult rightResult = findMaxSubTree(r.right);

        int sum = r.value;
        // Best result among the children; on equal maxSum the left child is kept.
        SubTreeResult bestChild = null;
        if (leftResult != null) {
            sum = checkedAdd(sum, leftResult.sum);
            bestChild = leftResult;
        }
        if (rightResult != null) {
            sum = checkedAdd(sum, rightResult.sum);
            if (bestChild == null || rightResult.maxSum > bestChild.maxSum) {
                bestChild = rightResult;
            }
        }

        // The whole tree rooted at r wins ties against any subtree inside it.
        if (bestChild == null || sum >= bestChild.maxSum) {
            return new SubTreeResult(r, r, sum, sum);
        }
        return new SubTreeResult(r, bestChild.maxSumNode, bestChild.maxSum, sum);
    }

    /** Adds two ints, throwing instead of silently wrapping around on overflow. */
    private static int checkedAdd(int a, int b) {
        long total = (long) a + b;
        if (total != (int) total) {
            throw new ArithmeticException("integer overflow: " + a + " + " + b);
        }
        return (int) total;
    }

    /**
     * Builds a binary tree from values laid out in level order, as in an
     * array-backed heap: the children of index {@code i} are at {@code 2*i + 1}
     * (left) and {@code 2*i + 2} (right).
     *
     * @param list node values in level order; must not contain {@code null}
     * @return the root node, or {@code null} if {@code list} is empty
     */
    public Node buildBinaryTree(ArrayList<Integer> list) {
        int size = list.size();
        if (size == 0) {
            return null;
        }

        Node[] nodes = new Node[size];
        for (int i = 0; i < size; i++) {
            nodes[i] = new Node(list.get(i), null, null);
            if (i > 0) {
                Node parent = nodes[(i - 1) / 2];
                // Odd indices are left children, even indices are right children.
                if (i % 2 == 1) {
                    parent.left = nodes[i];
                } else {
                    parent.right = nodes[i];
                }
            }
        }
        return nodes[0];
    }

    /**
     * Prints the tree's values, one per line, in pre-order (node, left, right).
     *
     * @param root root of the tree to print; {@code null} prints nothing
     */
    public void printTree(Node root) {
        if (root == null) {
            return;
        }
        System.out.println(root.value);
        printTree(root.left);
        printTree(root.right);
    }

    /**
     * Builds a sample tree, prints it, and reports its maximum-sum subtree.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        MaxSumTree mst = new MaxSumTree();
        int[] values = {3, -4, 2, 7, -5, 9, -8, -2, 5, 20, 100, -30, -150};
        ArrayList<Integer> list = new ArrayList<Integer>();
        for (int v : values) {
            list.add(v);
        }

        Node root = mst.buildBinaryTree(list);
        mst.printTree(root);
        SubTreeResult str = mst.findMaxSubTree(root);
        if (str == null) {
            System.out.println("empty tree");
            return;
        }
        System.out.println("max sum subtree root:" + str.maxSumNode.value + ", max sum:" + str.maxSum);
    }

}

Related tags: java, Algorithm, tree