[LeetCode] 689. Maximum Sum of 3 Non-Overlapping Subarrays

Problem

In a given array nums of positive integers, find three non-overlapping subarrays with maximum sum.

Each subarray will be of size k, and we want to maximize the sum of all 3*k entries.

Return the result as a list of indices representing the starting position of each interval (0-indexed). If there are multiple answers, return the lexicographically smallest one.
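The tie-breaking rule is easiest to see in a brute-force version. Below is a minimal sketch in Java. It is not the post's solution, and the class name `BruteForceThreeSubarrays` is hypothetical. It tries every triple of start indices in lexicographic order and keeps the first triple that reaches the largest total, so by construction it returns the lexicographically smallest answer. It runs in roughly O(n^3 · k) time, which makes it practical only as a reference checker on small inputs.

```java
import java.util.Arrays;

// Hypothetical reference checker (not the post's solution): tries every valid
// triple of window starts a < b < c with b >= a + k and c >= b + k.
class BruteForceThreeSubarrays {

    // Returns the lexicographically smallest triple of start indices with the maximum
    // total, or null when nums is too short to hold three windows of size k.
    static int[] maxSumOfThreeSubarrays(int[] nums, int k) {
        int n = nums.length;
        if (n < 3 * k) return null;
        int[] best = null;
        long bestSum = Long.MIN_VALUE;
        for (int a = 0; a + 3 * k <= n; a++) {
            for (int b = a + k; b + 2 * k <= n; b++) {
                for (int c = b + k; c + k <= n; c++) {
                    long total = windowSum(nums, a, k) + windowSum(nums, b, k) + windowSum(nums, c, k);
                    // Triples are visited in lexicographic order, so strict '>' keeps the smallest on ties.
                    if (total > bestSum) {
                        bestSum = total;
                        best = new int[]{a, b, c};
                    }
                }
            }
        }
        return best;
    }

    // Sum of nums[start .. start + k - 1].
    private static long windowSum(int[] nums, int start, int k) {
        long s = 0;
        for (int i = start; i < start + k; i++) s += nums[i];
        return s;
    }

    public static void main(String[] args) {
        // Prints [0, 3, 5] for the example below.
        System.out.println(Arrays.toString(maxSumOfThreeSubarrays(new int[]{1, 2, 1, 2, 6, 7, 5, 1}, 2)));
    }
}
```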

Example:

Input: [1,2,1,2,6,7,5,1], 2
Output: [0, 3, 5]
Explanation: Subarrays [1, 2], [2, 6], [7, 5] correspond to the starting indices [0, 3, 5].
We could have also taken [2, 1], but an answer of [1, 3, 5] would be lexicographically larger.
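As a worked check of the example (not part of the original explanation), write W_j for the sum of the length-2 window starting at index j:

$$
\begin{aligned}
W &= (W_0, \dots, W_6) = (3,\ 3,\ 3,\ 8,\ 13,\ 12,\ 6) \\
W_0 + W_3 + W_5 &= 3 + 8 + 12 = 23 \\
W_1 + W_3 + W_5 &= 3 + 8 + 12 = 23 \\
W_0 + W_4 + W_6 &= 3 + 13 + 6 = 22
\end{aligned}
$$

[0, 3, 5] and [1, 3, 5] tie at 23, so the lexicographic rule picks [0, 3, 5]. The largest single window, W_4 = 13, is not used, because every valid triple that contains it totals at most 22.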

Note:

nums.length will be between 1 and 20000.
nums[i] will be between 1 and 65535.
k will be between 1 and floor(nums.length / 3).
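These bounds show that a 32-bit `int` is wide enough for the prefix sums in the solution. Every entry is positive, so no prefix sum and no three-window total can exceed the sum of the whole array:

$$
\text{sum}[n] = \sum_{t=0}^{n-1} nums_t \le 20000 \times 65535 = 1{,}310{,}700{,}000 < 2^{31} - 1 = 2{,}147{,}483{,}647
$$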

Solution

class Solution {
    public int[] maxSumOfThreeSubarrays(int[] nums, int k) {
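        // Split nums into three parts: [0, i-1] holds the left window, [i, i+k-1] is the
        // middle window, and [i+k, n-1] holds the right window. The left part needs i >= k;
        // the right part needs (n-1) - (i+k) + 1 >= k, i.e. i <= n - 2k.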
        
        if (nums == null || nums.length < 3*k) return null;
        
        int n = nums.length;
        
        int[] sum = new int[n+1];
        int[] left = new int[n];
        int[] right = new int[n];
        int[] res = new int[3];
        
        int max = 0; // all entries are positive, so any valid triple has a total above 0
        
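        // Prefix sums: sum[j] = nums[0] + ... + nums[j-1], so the window starting at s has sum sum[s+k] - sum[s].
        for (int i = 0; i < n; i++) {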
            sum[i+1] = sum[i] + nums[i];
        }
        
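        // left[i] = start of a max-sum window lying entirely in nums[0..i].
        // Strict '>' keeps the earliest start on ties (lexicographically smallest).
        int leftMax = sum[k]-sum[0];
        left[k-1] = 0;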
        for (int i = k; i < n; i++) {
            if (sum[i+1]-sum[i+1-k] > leftMax) {
                left[i] = i+1-k;
                leftMax = sum[i+1]-sum[i+1-k];
            } else {
                left[i] = left[i-1];
            }
        }
        
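        // right[i] = start of a max-sum window lying entirely in nums[i..n-1].
        // Use '>=' so the earlier start wins ties. With '>' the scan keeps the later start,
        // e.g. nums = [1,1,1,1], k = 1 would return [0, 1, 3] instead of [0, 1, 2].
        int rightMax = sum[n]-sum[n-k];
        right[n-k] = n-k;
        for (int i = n-1-k; i >= 0; i--) {
            if (sum[i+k]-sum[i] >= rightMax) {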
                right[i] = i;
                rightMax = sum[i+k]-sum[i];
            } else {
                right[i] = right[i+1];
            }
        }
        
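        // Fix the middle window at i and combine it with the best left and right windows.
        // Strict '>' keeps the smallest i on ties.
        for (int i = k; i <= n-2*k; i++) {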
            int l = left[i-1];
            int r = right[i+k];
            int curMax = sum[l+k]-sum[l] + (sum[i+k]-sum[i]) + (sum[r+k]-sum[r]);
            if (curMax > max) {
                max = curMax;
                res[0] = l;
                res[1] = i;
                res[2] = r;
            }
        }
        
        return res;
    }
}
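The comparison in the `right[]` scan decides which start index is kept when two windows have equal sums. The small hypothetical class `RightTieBreakDemo` rebuilds that table with either comparison on nums = [1,1,1,1], k = 1. With `>`, every entry points at the last window (index 3), so the middle loop pairs l = 0, i = 1 with r = 3. With `>=`, the earliest window wins, and the triple becomes [0, 1, 2], which is the lexicographically smallest answer. This is a sketch for illustration, not part of the original post.

```java
import java.util.Arrays;

// Hypothetical demo class: builds the right[] table from the solution with either tie rule.
class RightTieBreakDemo {

    // right[i] = start of a max-sum window of length k that starts at index >= i.
    // preferLeft = true uses '>=' (smallest start wins ties); false uses '>' (largest start wins).
    // Only right[0..n-k] is filled; for k > 1 the trailing entries stay 0 and are never read.
    static int[] buildRight(int[] nums, int k, boolean preferLeft) {
        int n = nums.length;
        int[] sum = new int[n + 1];
        for (int i = 0; i < n; i++) sum[i + 1] = sum[i] + nums[i];
        int[] right = new int[n];
        int rightMax = sum[n] - sum[n - k];
        right[n - k] = n - k;
        for (int i = n - 1 - k; i >= 0; i--) {
            int cur = sum[i + k] - sum[i];
            boolean take = preferLeft ? cur >= rightMax : cur > rightMax;
            if (take) {
                right[i] = i;
                rightMax = cur;
            } else {
                right[i] = right[i + 1];
            }
        }
        return right;
    }

    public static void main(String[] args) {
        int[] nums = {1, 1, 1, 1};
        // With k = 1 and i = 1, the middle loop reads right[2].
        System.out.println(Arrays.toString(buildRight(nums, 1, false))); // [3, 3, 3, 3]
        System.out.println(Arrays.toString(buildRight(nums, 1, true)));  // [0, 1, 2, 3]
    }
}
```

The middle loop only reads right[i+k] with i <= n-2k, so it never touches an index beyond n-k.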
