[Solved] Maximum Segment Sum After Removals LeetCode Contest

You are given two 0-indexed integer arrays nums and removeQueries, both of length n. For the ith query, the element in nums at the index removeQueries[i] is removed, splitting nums into different segments.

A segment is a contiguous sequence of positive integers in nums. A segment sum is the sum of every element in a segment.

Return an integer array answer, of length n, where answer[i] is the maximum segment sum after applying the ith removal.

Note: The same index will not be removed more than once.

Maximum Segment Sum After Removals LeetCode Contest

Example 1:

Input: nums = [1,2,5,6,1], removeQueries = [0,3,2,4,1]
Output: [14,7,2,2,0]
Explanation: Using 0 to indicate a removed element, the answer is as follows:
Query 1: Remove the 0th element, nums becomes [0,2,5,6,1] and the maximum segment sum is 14 for segment [2,5,6,1].
Query 2: Remove the 3rd element, nums becomes [0,2,5,0,1] and the maximum segment sum is 7 for segment [2,5].
Query 3: Remove the 2nd element, nums becomes [0,2,0,0,1] and the maximum segment sum is 2 for segment [2]. 
Query 4: Remove the 4th element, nums becomes [0,2,0,0,0] and the maximum segment sum is 2 for segment [2]. 
Query 5: Remove the 1st element, nums becomes [0,0,0,0,0] and the maximum segment sum is 0, since there are no segments.
Finally, we return [14,7,2,2,0].

Example 2:

Input: nums = [3,2,11,1], removeQueries = [3,2,1,0]
Output: [16,5,3,0]
Explanation: Using 0 to indicate a removed element, the answer is as follows:
Query 1: Remove the 3rd element, nums becomes [3,2,11,0] and the maximum segment sum is 16 for segment [3,2,11].
Query 2: Remove the 2nd element, nums becomes [3,2,0,0] and the maximum segment sum is 5 for segment [3,2].
Query 3: Remove the 1st element, nums becomes [3,0,0,0] and the maximum segment sum is 3 for segment [3].
Query 4: Remove the 0th element, nums becomes [0,0,0,0] and the maximum segment sum is 0, since there are no segments.
Finally, we return [16,5,3,0].

Constraints:

  • n == nums.length == removeQueries.length
  • 1 <= n <= 10^5
  • 1 <= nums[i] <= 10^9
  • 0 <= removeQueries[i] < n
  • All the values of removeQueries are unique.

Solution

// Returns the root of the set containing index i in the disjoint-set array ds.
// Every index visited on the way is re-pointed directly at that root (path
// compression). An entry ds[k] < 0 marks k as a root; any other entry is the
// index of k's parent.
int find(int i, vector<long long>& ds) {
    int root = i;
    while (ds[root] >= 0)
        root = static_cast<int>(ds[root]);
    // Second pass: re-link the whole path i -> ... -> root straight to root.
    for (int cur = i; cur != root;) {
        const int parent = static_cast<int>(ds[cur]);
        ds[cur] = root;
        cur = parent;
    }
    return root;
}
// Unions the sets containing s1 and s2 in the disjoint-set array ds.
// In ds, a negative entry marks a root holding the negated sum of its set; any
// other entry is a parent index. The root of s1's set is attached under the
// root of s2's set, and s2's root absorbs s1's (negated) sum.
// Merging two indices that are already in the same set is a no-op. Without
// this guard, the root's sum would be doubled and the root would become its
// own parent, so find() would never terminate.
void merge(int s1, int s2, vector<long long>& ds) {
    const int p1 = find(s1, ds);
    const int p2 = find(s2, ds);
    if (p1 == p2)
        return;
    ds[p2] += ds[p1];
    ds[p1] = p2;
}
vector<long long> maximumSegmentSum(vector<int>& nums, vector<int>& rq) {
    vector<long long> res(nums.size()), ds(nums.size(), INT_MAX);
    for (int i = rq.size() - 1; i > 0; --i) {
        int j = rq[i];
        ds[j] = -nums[j];
        if (j > 0 && ds[j - 1] != INT_MAX)
            merge(j, j - 1, ds);
        if (j < nums.size() - 1 && ds[j + 1] != INT_MAX)
            merge(j, j + 1, ds);
        res[i - 1] = max(res[i], -ds[find(j, ds)]);
    }
    return res;
}
class Solution {
    /**
     * Inclusive index range [start, end] of a segment that is still intact.
     * Static so instances do not hold a hidden reference to the Solution.
     */
    static class Pair {
        final int start, end;

        Pair(int start, int end) {
            this.start = start;
            this.end = end;
        }
    }

    /** Multiset of longs, stored as an ordered value -> occurrence-count map. */
    static class MultiSet {
        private final TreeMap<Long, Integer> freq = new TreeMap<>();

        /** Adds one occurrence of {@code x}. */
        void add(long x) {
            freq.merge(x, 1, Integer::sum);
        }

        /**
         * Removes one occurrence of {@code x}.
         *
         * @return {@code false} if {@code x} was not present, {@code true} otherwise
         */
        boolean remove(long x) {
            Integer f = freq.get(x);
            if (f == null) {
                return false;
            }
            if (f == 1) {
                freq.remove(x);
            } else {
                freq.put(x, f - 1);
            }
            return true;
        }

        /** Returns the largest stored value, or 0 when the multiset is empty; never null. */
        Long getMax() {
            return freq.isEmpty() ? 0L : freq.lastKey();
        }
    }

    /**
     * Returns, for each query i, the largest segment sum left after removing
     * {@code arr[removeQueries[i]]} (0 once nothing remains).
     *
     * Intact segments are kept ordered by start, so the segment holding a removed
     * index is found with {@code floor()}. Their sums are kept in a multiset, so
     * the current maximum is available after every split.
     *
     * @param arr           the values; assumed positive, per the problem statement
     * @param removeQueries distinct indices into {@code arr}, in removal order
     * @return one answer per query, as longs, because segment sums can exceed the int range
     */
    public long[] maximumSegmentSum(int[] arr, int[] removeQueries) {
        int n = arr.length;
        // Prefix sums: the sum of arr[l..r] is pre[r + 1] - pre[l], computed in O(1).
        long[] pre = new long[n + 1];
        for (int i = 0; i < n; i++) {
            pre[i + 1] = pre[i] + arr[i];
        }

        // Ordered by start. Integer.compare is the idiomatic, overflow-proof comparator.
        TreeSet<Pair> segments = new TreeSet<>((a, b) -> Integer.compare(a.start, b.start));
        MultiSet sums = new MultiSet();
        // Initially the whole array is one segment.
        segments.add(new Pair(0, n - 1));
        sums.add(pre[n] - pre[0]);

        long[] ans = new long[removeQueries.length];
        for (int i = 0; i < removeQueries.length; i++) {
            int q = removeQueries[i];
            // q has not been removed yet, so it lies in the segment with the greatest start <= q.
            Pair p = segments.floor(new Pair(q, q));
            segments.remove(p);
            sums.remove(pre[p.end + 1] - pre[p.start]);
            // Split p around q, keeping only the non-empty halves.
            if (p.start <= q - 1) {
                segments.add(new Pair(p.start, q - 1));
                sums.add(pre[q] - pre[p.start]);
            }
            if (q + 1 <= p.end) {
                segments.add(new Pair(q + 1, p.end));
                sums.add(pre[p.end + 1] - pre[q + 1]);
            }
            // getMax() already yields 0 for an empty multiset, so no null check is needed.
            ans[i] = sums.getMax();
        }
        return ans;
    }
}
class Uni:
    """Disjoint-set (union-find) over indices 0..n-1 that tracks a sum per set.

    ``merge`` always keeps the smaller root index as the representative and adds
    the absorbed root's ``val`` into it.
    """

    def __init__(self, n):
        self.rep = list(range(n))  # rep[x] == x marks x as a root
        self.val = [0] * n  # per-index value; the set total lives at the root

    def find(self, x):
        """Return the root of ``x``'s set, pointing every node on the path at it.

        Iterative on purpose: ``merge`` links roots by index rather than by size
        or rank, so a parent chain can become O(n) long. A recursive walk would
        then exceed Python's default recursion limit (about 1000) on large inputs.
        """
        root = x
        while self.rep[root] != root:
            root = self.rep[root]
        # Path compression: re-link every node from x up to the root directly to it.
        while x != root:
            nxt = self.rep[x]
            self.rep[x] = root
            x = nxt
        return root

    def merge(self, x, y):
        """Union the sets of ``x`` and ``y``; a no-op if they are already joined.

        The smaller root index becomes the representative and accumulates the
        other root's ``val``.
        """
        x, y = self.find(x), self.find(y)
        if x != y:
            x, y = min(x, y), max(x, y)
            self.rep[y] = x
            self.val[x] += self.val[y]

class Solution:
    def maximumSegmentSum(self, nums: List[int], removeQueries: List[int]) -> List[int]:
        """Return the maximum segment sum after each removal in ``removeQueries``.

        Works backwards: each removed index is restored, in reverse query order,
        and unioned with any already-restored neighbour. The answer for a query
        is the best sum seen before that index is restored. A neighbour counts as
        restored when its ``val`` is non-zero, which relies on every value in
        ``nums`` being positive.
        """
        size = len(nums)
        dsu = Uni(size)
        best = 0
        history = []
        for pos in reversed(removeQueries):
            history.append(best)
            dsu.val[pos] = nums[pos]
            if pos > 0 and dsu.val[pos - 1]:
                dsu.merge(pos - 1, pos)
            if pos + 1 < size and dsu.val[pos + 1]:
                dsu.merge(pos + 1, pos)
            best = max(best, dsu.val[dsu.find(pos)])
        history.reverse()
        return history

Happy Learning – If you require any further information, feel free to contact me.

Share your love
Saurav Hathi

Saurav Hathi

I'm currently studying for a Bachelor of Computer Science at Lovely Professional University in Punjab.

📌 Nodejs and Android 😎
📌 Java

Articles: 444

Leave a Reply

Your email address will not be published. Required fields are marked *