Problem:
Given a binary tree in which each node holds an integer value (possibly negative), write code to find the subtree with the maximum sum.
My solution:
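This is a recursive problem: for each subtree, you need to remember the total sum of its nodes, the best (maximum) subtree sum found inside it, and the root node of that best subtree. I use a small struct-like class, SubTreeResult, to carry all of this information back up the recursion.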
Note: The code pasted here has not been tested sufficiently, so don't take it as a correct solution unless you verify it yourself.
package alg;
import java.util.ArrayList;
/**
 * Finds the subtree with the largest node-value sum in a binary tree whose
 * nodes carry (possibly negative) integer values.
 *
 * <p>A "subtree" here is a node together with all of its descendants. All sums
 * are computed with {@code int} arithmetic and therefore wrap on overflow.
 */
public class MaxSumTree {

    /** A binary tree node holding an integer value and optional children. */
    class Node {
        int value;
        Node left, right;

        /**
         * @param v value stored in the node
         * @param l left child, or {@code null}
         * @param r right child, or {@code null}
         */
        Node(int v, Node l, Node r) {
            value = v;
            left = l;
            right = r;
        }
    }

    /**
     * Result of analysing the subtree rooted at {@link #subTreeRoot}.
     *
     * <ul>
     *   <li>{@code sum} - sum of every node in the subtree rooted at {@code subTreeRoot};</li>
     *   <li>{@code maxSum} - largest sum of any subtree contained in it;</li>
     *   <li>{@code maxSumNode} - root of the subtree that achieves {@code maxSum}.</li>
     * </ul>
     */
    class SubTreeResult {
        Node subTreeRoot, maxSumNode;
        int sum, maxSum;

        /**
         * @param r       root of the analysed subtree
         * @param maxNode root of the best subtree found inside it
         * @param maxSum  sum of the best subtree
         * @param sum     sum of the whole analysed subtree
         */
        SubTreeResult(Node r, Node maxNode, int maxSum, int sum) {
            this.subTreeRoot = r;
            this.maxSumNode = maxNode;
            this.sum = sum;
            this.maxSum = maxSum;
        }
    }

    /**
     * Computes, for the tree rooted at {@code r}, its total sum and the subtree
     * with the maximum sum.
     *
     * <p>Ties are resolved towards the node closest to the root: if the whole
     * tree's sum equals the best sum found in a child, {@code r} itself is
     * reported. When both children offer the same best sum, the left one wins.
     *
     * @param r root of the tree to analyse; may be {@code null} for an empty tree
     * @return the analysis result, or {@code null} if {@code r} is {@code null}
     *         (an empty tree has no subtree to report)
     */
    public SubTreeResult findMaxSubTree(Node r) {
        // buildBinaryTree returns null for an empty list, so an empty tree must
        // be accepted here instead of failing with a NullPointerException.
        if (r == null) {
            return null;
        }

        // A missing child yields a null result, which simply contributes nothing.
        SubTreeResult leftResult = findMaxSubTree(r.left);
        SubTreeResult rightResult = findMaxSubTree(r.right);

        int sum = r.value;
        if (leftResult != null) {
            sum += leftResult.sum;
        }
        if (rightResult != null) {
            sum += rightResult.sum;
        }

        SubTreeResult bestChild = betterOf(leftResult, rightResult);
        if (bestChild == null || sum >= bestChild.maxSum) {
            // Leaf node, or the whole subtree is at least as good as anything below it.
            return new SubTreeResult(r, r, sum, sum);
        }
        return new SubTreeResult(r, bestChild.maxSumNode, bestChild.maxSum, sum);
    }

    /**
     * Picks the child result with the larger {@code maxSum}; the left result
     * wins a tie. Returns {@code null} only when both arguments are {@code null}.
     */
    private static SubTreeResult betterOf(SubTreeResult left, SubTreeResult right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }
        return left.maxSum >= right.maxSum ? left : right;
    }

    /**
     * Builds a binary tree from a level-order (array/heap layout) list of
     * values: the element at index {@code i} becomes the left child of the
     * element at index {@code (i - 1) / 2} when {@code i} is odd and its right
     * child when {@code i} is even.
     *
     * @param list node values in level order; must not be {@code null} and must
     *             not contain {@code null} elements
     * @return the root of the tree, or {@code null} if {@code list} is empty
     */
    public Node buildBinaryTree(ArrayList<Integer> list) {
        int size = list.size();
        if (size == 0) {
            return null;
        }

        Node[] nodes = new Node[size];
        nodes[0] = new Node(list.get(0), null, null);
        for (int i = 1; i < size; i++) {
            nodes[i] = new Node(list.get(i), null, null);
            // The parent always has a smaller index, so it already exists.
            Node parent = nodes[(i - 1) / 2];
            if (i % 2 == 0) {
                parent.right = nodes[i];
            } else {
                parent.left = nodes[i];
            }
        }
        return nodes[0];
    }

    /**
     * Prints every node value on its own line in pre-order
     * (node, then left subtree, then right subtree).
     *
     * @param root root of the tree to print; {@code null} prints nothing
     */
    public void printTree(Node root) {
        if (root == null) {
            return;
        }
        System.out.println(root.value);
        printTree(root.left);
        printTree(root.right);
    }

    /**
     * Demo: builds a sample tree, prints it, and reports its maximum-sum subtree.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        MaxSumTree mst = new MaxSumTree();
        ArrayList<Integer> list = new ArrayList<Integer>();
        list.add(3);
        list.add(-4);
        list.add(2);
        list.add(7);
        list.add(-5);
        list.add(9);
        list.add(-8);
        list.add(-2);
        list.add(5);
        list.add(20);
        list.add(100);
        list.add(-30);
        list.add(-150);
        Node root = mst.buildBinaryTree(list);
        mst.printTree(root);
        // The sample list is non-empty, so the result is never null here.
        SubTreeResult str = mst.findMaxSubTree(root);
        System.out.println("max sum subtree root:" + str.maxSumNode.value + ", max sub:" + str.maxSum);
    }
}