Monday, January 26, 2015

Maximum sum range query using segment tree

There is a classic SPOJ problem that can be solved with a segment tree. Here is the link to the problem:
http://www.spoj.com/problems/GSS3/

Basically, given a sequence of numbers A, it asks for the maximum sum within a query range (x, y), x <= y, where a sum is A[i] + A[i+1] + ... + A[j] for some x <= i <= j <= y.

To better understand the solution below, I suggest reading these two earlier posts first: http://blueocean-penn.blogspot.com/2015/01/segment-tree-order-statistics-binary.html and http://blueocean-penn.blogspot.com/2015/01/find-diameter-for-any-tree.html

Note that a single maximum-sum-subarray query can be answered in linear time with the following scan (Kadane's algorithm):
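public int maximumSum(int[] nums, int i, int j){
    // currSum: best sum of a subarray ending at index k
    // maxSum:  best sum seen so far in nums[i..k]
    int currSum = nums[i], maxSum = nums[i];
    for(int k = i+1; k<=j; k++){
        // Either extend the subarray ending at k-1 or start a new one at k.
        currSum = Math.max(nums[k], currSum + nums[k]);
        maxSum = Math.max(maxSum, currSum);
    }
    return maxSum;
}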
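For reference, here is a self-contained, static version of the same scan with a small driver. The class name KadaneDemo is just for illustration. Each call walks the whole range, so answering m queries this way costs O(m*n) in the worst case, which is what the segment tree below is meant to avoid.

```java
// Illustrative standalone version of the linear-time scan above.
class KadaneDemo {
    // Maximum sum of a non-empty subarray of nums[i..j] (inclusive bounds).
    static int maximumSum(int[] nums, int i, int j) {
        int currSum = nums[i], maxSum = nums[i];
        for (int k = i + 1; k <= j; k++) {
            // Either extend the best subarray ending at k-1, or start fresh at k.
            currSum = Math.max(nums[k], currSum + nums[k]);
            maxSum = Math.max(maxSum, currSum);
        }
        return maxSum;
    }

    public static void main(String[] args) {
        int[] a = {2, -1, 3, -5, 4};
        System.out.println(maximumSum(a, 0, 4)); // 4: [2, -1, 3] or [4]
        System.out.println(maximumSum(a, 1, 3)); // 3: [3]
        int[] neg = {-3, -1, -2};
        System.out.println(maximumSum(neg, 0, 2)); // -1: the subarray must be non-empty
    }
}
```

Because currSum starts at nums[i] rather than 0, an all-negative range returns its largest element instead of 0.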

/* At each node we store five pieces of information:
* data: the input number (set only at leaf nodes)
* sum: the sum of all elements in the interval represented by the current node
* prefixSum: the maximum sum of a run of elements starting at the left bound of the current interval
* postfixSum: the maximum sum of a run of elements ending at the right bound of the current interval
* bestSum: the maximum sum of any subarray within the current interval
*/
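To make the four aggregate definitions concrete, here is a brute-force sketch that computes them directly from the definitions for a range nums[lo..hi]. The class name NodeOracle and the {sum, prefixSum, postfixSum, bestSum} array layout are my own illustrative choices, not part of the original code; it is mainly useful for checking the segment tree's answers on small inputs.

```java
// Hypothetical brute-force oracle: computes the node aggregates straight from
// their definitions (the bestSum loop is O(n^2)). Assumes lo <= hi.
class NodeOracle {
    // Returns {sum, prefixSum, postfixSum, bestSum} for nums[lo..hi].
    static int[] summarize(int[] nums, int lo, int hi) {
        int sum = 0;
        for (int k = lo; k <= hi; k++) sum += nums[k];

        // prefixSum: best sum of nums[lo..k] over all k in [lo, hi]
        int prefix = Integer.MIN_VALUE, run = 0;
        for (int k = lo; k <= hi; k++) { run += nums[k]; prefix = Math.max(prefix, run); }

        // postfixSum: best sum of nums[k..hi] over all k in [lo, hi]
        int postfix = Integer.MIN_VALUE;
        run = 0;
        for (int k = hi; k >= lo; k--) { run += nums[k]; postfix = Math.max(postfix, run); }

        // bestSum: best sum of nums[i..j] over all lo <= i <= j <= hi
        int best = Integer.MIN_VALUE;
        for (int i = lo; i <= hi; i++) {
            run = 0;
            for (int j = i; j <= hi; j++) { run += nums[j]; best = Math.max(best, run); }
        }
        return new int[] {sum, prefix, postfix, best};
    }

    public static void main(String[] args) {
        int[] a = {3, -3, -4, -5, 1};
        System.out.println(java.util.Arrays.toString(summarize(a, 0, 4))); // [-8, 3, 1, 3]
    }
}
```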

example:

2 --------------------------------- 10
|<------prefixSum ....

2 --------------6 7-----------------10
|<--prefixSum..   |<--prefixSum
|<--- Sum ----->|

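prefixSum(2,10) = Max(prefixSum(2,6), Sum(2,6) + prefixSum(7,10))
(the best prefix either stays inside the left half, or covers the whole left half and continues into the right half)
postfixSum is computed symmetrically: postfixSum(2,10) = Max(postfixSum(7,10), postfixSum(2,6) + Sum(7,10))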


2 --------------------------------- 10
       |<------BestSum----->|

2 --------------6 7-----------------10
 |<--BestSum-->|   |<--BestSum-->|
 ...postfixSum-->||<--prefixSum

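BestSum(2,10) = Max(BestSum(2,6), BestSum(7,10), postfixSum(2,6)+prefixSum(7,10))
(the best subarray lies entirely in the left half, entirely in the right half, or crosses the split, in which case it is a suffix of the left half followed by a prefix of the right half)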
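The three rules above can be packaged as a single merge function. The sketch below uses a hypothetical immutable Summary class (not part of the original code) and replays the split [2, -1, 3] | [-5, 4], so each rule can be checked by hand. Since merge is associative, folding an array left to right with it gives the same result as any tree-shaped grouping, which is what lets a segment tree combine precomputed nodes.

```java
// Hypothetical immutable summary of an interval, with the combine rules above.
final class Summary {
    final int sum, prefixSum, postfixSum, bestSum;

    Summary(int sum, int prefixSum, int postfixSum, int bestSum) {
        this.sum = sum;
        this.prefixSum = prefixSum;
        this.postfixSum = postfixSum;
        this.bestSum = bestSum;
    }

    // A single element: every field equals its value.
    static Summary leaf(int v) { return new Summary(v, v, v, v); }

    // Combine interval l with the interval r immediately to its right.
    static Summary merge(Summary l, Summary r) {
        return new Summary(
            l.sum + r.sum,
            Math.max(l.prefixSum, l.sum + r.prefixSum),   // inside l, or all of l plus a prefix of r
            Math.max(r.postfixSum, l.postfixSum + r.sum), // inside r, or a suffix of l plus all of r
            Math.max(Math.max(l.bestSum, r.bestSum),      // entirely in one half...
                     l.postfixSum + r.prefixSum));        // ...or crossing the split
    }

    // Fold a non-empty array left to right; valid because merge is associative.
    static Summary of(int... nums) {
        Summary s = leaf(nums[0]);
        for (int k = 1; k < nums.length; k++) s = merge(s, leaf(nums[k]));
        return s;
    }

    @Override public String toString() {
        return "sum=" + sum + " prefix=" + prefixSum + " postfix=" + postfixSum + " best=" + bestSum;
    }
}

class MergeDemo {
    public static void main(String[] args) {
        Summary left = Summary.of(2, -1, 3);  // sum=4 prefix=4 postfix=4 best=4
        Summary right = Summary.of(-5, 4);    // sum=-1 prefix=-1 postfix=4 best=4
        System.out.println(Summary.merge(left, right)); // sum=3 prefix=4 postfix=4 best=4
    }
}
```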



public class SegmentTree {
    /* Per-node aggregates described in the comment above. Declared as a
     * static nested class so the file compiles with a single public class. */
    static class Node {
        int data;
        int sum;
        int prefixSum;
        int postfixSum;
        int bestSum;
    }

    Node[] tree;
    int len;

    public SegmentTree(int[] nums){
        if (nums.length == 0)
            throw new IllegalArgumentException("nums must not be empty");
        len = nums.length;
        // Smallest power of two >= len, computed with integers instead of
        // floating-point logarithms; 2*size-1 slots cover every node index.
        int size = 1;
        while (size < len) size <<= 1;
        tree = new Node[2 * size - 1];
        constructTree(0, nums, 0, len - 1);
    }

    // Time complexity for tree construction is O(n): there are 2n-1 nodes in total,
    // and each node's value is computed exactly once.
    private Node constructTree(int current, int[] nums, int start, int end){
        if(start==end){
            // Leaf: every aggregate equals the single element.
            tree[current] = new Node();
            tree[current].data = tree[current].sum = tree[current].prefixSum
                = tree[current].postfixSum = tree[current].bestSum = nums[start];
        }else{
            int left = current*2+1;
            int right = current*2+2;
            int mid = start + (end-start)/2;
            Node l = constructTree(left, nums, start, mid),
                 r = constructTree(right, nums, mid+1, end);
            tree[current] = new Node();
            tree[current].sum = l.sum + r.sum;
            tree[current].prefixSum = Math.max(l.prefixSum, l.sum + r.prefixSum);
            tree[current].postfixSum = Math.max(r.postfixSum, l.postfixSum + r.sum);
            tree[current].bestSum = Math.max(Math.max(l.bestSum, r.bestSum), l.postfixSum+r.prefixSum);
        }
        return tree[current];
    }

    // Time complexity of a query is O(log n): at most 4 nodes are visited on each
    // level, and there are O(log n) levels.
    // l and r are 0-based and inclusive; returns Integer.MIN_VALUE if [l, r] lies
    // completely outside [0, len-1].
    public int getMaxSum(int l, int r){
        Node result = getMaxSumRec(l, r, 0, 0, len-1);
        return result==null? Integer.MIN_VALUE : result.bestSum;
    }

    // Returns the combined Node for the part of [l, r] inside [start, end],
    // or null if the two intervals do not overlap.
    private Node getMaxSumRec(int l, int r, int index, int start, int end){
        if(l<=start && r>=end)
            return tree[index];
        else if(end<l || start>r){
            return null;
        }else{
            int mid = start + (end-start)/2;
            Node left = getMaxSumRec(l, r, 2*index+1, start, mid),
                 right = getMaxSumRec(l, r, 2*index+2, mid+1, end);
            if(left==null)
                return right;
            else if(right==null)
                return left;
            else{
                // Same combine rules as in constructTree, applied to partial results.
                Node result = new Node();
                result.sum = left.sum + right.sum;
                result.prefixSum = Math.max(left.prefixSum, left.sum + right.prefixSum);
                result.postfixSum = Math.max(right.postfixSum, left.postfixSum + right.sum);
                result.bestSum = Math.max(Math.max(left.bestSum, right.bestSum), left.postfixSum+right.prefixSum);
                return result;
            }
        }
    }
}
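As far as I can tell, the GSS3 statement also includes an operation that replaces a single element (in the SPOJ input, "0 x y" sets A[x] to y, with 1-based indices), which the class above does not support. Below is a self-contained sketch that adds a point update under that assumption; the class name UpdatableMaxSumTree and the update method are hypothetical, not part of the original code. An update rebuilds only the O(log n) ancestors of the changed leaf. Two small design changes versus the code above: the query splits on mid instead of returning null for non-overlapping children, and the array simply reserves 4n slots.

```java
// Hypothetical sketch: the same segment tree plus a point update.
// Indices are 0-based; SPOJ's 1-based x would be passed as x - 1.
class UpdatableMaxSumTree {
    private static final class Node {
        int sum, prefixSum, postfixSum, bestSum;

        Node(int v) { sum = prefixSum = postfixSum = bestSum = v; } // leaf

        Node(Node l, Node r) { // same combine rules as in the post
            sum = l.sum + r.sum;
            prefixSum = Math.max(l.prefixSum, l.sum + r.prefixSum);
            postfixSum = Math.max(r.postfixSum, l.postfixSum + r.sum);
            bestSum = Math.max(Math.max(l.bestSum, r.bestSum), l.postfixSum + r.prefixSum);
        }
    }

    private final Node[] tree;
    private final int len;

    UpdatableMaxSumTree(int[] nums) {
        if (nums.length == 0) throw new IllegalArgumentException("empty input");
        len = nums.length;
        tree = new Node[4 * len]; // 4n slots always suffice for the 2i+1 / 2i+2 layout
        build(0, nums, 0, len - 1);
    }

    private Node build(int idx, int[] nums, int start, int end) {
        if (start == end) return tree[idx] = new Node(nums[start]);
        int mid = start + (end - start) / 2;
        Node l = build(2 * idx + 1, nums, start, mid);
        Node r = build(2 * idx + 2, nums, mid + 1, end);
        return tree[idx] = new Node(l, r);
    }

    // Set A[pos] = value, then recompute the ancestors on the way back up.
    void update(int pos, int value) { update(0, 0, len - 1, pos, value); }

    private void update(int idx, int start, int end, int pos, int value) {
        if (start == end) { tree[idx] = new Node(value); return; }
        int mid = start + (end - start) / 2;
        if (pos <= mid) update(2 * idx + 1, start, mid, pos, value);
        else update(2 * idx + 2, mid + 1, end, pos, value);
        tree[idx] = new Node(tree[2 * idx + 1], tree[2 * idx + 2]);
    }

    // Maximum subarray sum inside [l, r], assuming 0 <= l <= r < len.
    int getMaxSum(int l, int r) { return query(0, 0, len - 1, l, r).bestSum; }

    private Node query(int idx, int start, int end, int l, int r) {
        if (l <= start && end <= r) return tree[idx];
        int mid = start + (end - start) / 2;
        if (r <= mid) return query(2 * idx + 1, start, mid, l, r);   // only the left child overlaps
        if (l > mid) return query(2 * idx + 2, mid + 1, end, l, r);  // only the right child overlaps
        return new Node(query(2 * idx + 1, start, mid, l, r),
                        query(2 * idx + 2, mid + 1, end, l, r));
    }

    public static void main(String[] args) {
        UpdatableMaxSumTree t = new UpdatableMaxSumTree(new int[] {2, -1, 3, -5, 4});
        System.out.println(t.getMaxSum(0, 4)); // 4
        t.update(3, 5);                        // array becomes {2, -1, 3, 5, 4}
        System.out.println(t.getMaxSum(0, 4)); // 13
        System.out.println(t.getMaxSum(1, 2)); // 3
    }
}
```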

10 comments:

  1. I think this problem can be solved by a greedy algorithm.
    This is my solution: http://allenlipeng47.com/PersonalPage/index/view/116/nkey

    1. You can look at the following method in the post above: public int maximumSum(int[] nums, int i, int j){..}. However, we need to remember the purpose of a segment tree. ;)

    2. I don't understand your prefixSum or postfixSum. But it looks cool that a segment tree can solve this problem. When I run the code with [3, -3, -4, -5, 1], it throws a NullPointerException.

    3. I just updated the code; I forgot to initialize a new Node():

      r = constructTree(right, nums, mid+1, end);
      tree[current] = new Node(); <=== missing this
      tree[current].sum = l.sum + r.sum;

    4. It is hard for me to understand. Do you have any other material or links that explain it?

    5. The idea of using prefixSum, postfixSum, ... is very similar to the way we calculate the diameter of a tree. Let me draw a picture to explain it tomorrow. It is essentially dynamic programming: DP(current) = fn(DP(current.left), DP(current.right)); we just need to figure out what the function fn() should be.

    6. Please see my little drawing in the post.

    7. Thank you! I finally understood. The key lies in:
      result.bestSum = Math.max(Math.max(left.bestSum, right.bestSum), left.postfixSum + right.prefixSum);

    8. Prefix and postfix sums make the segment tree more powerful.
