Problem

You are tasked with implementing a segment tree in Java and Python to find the maximum value in a given range [left, right] (inclusive) of an array. A segment tree supports both range queries and point updates in O(log n) time.

The main operations are:

  1. Build: Construct the segment tree from a given array.
  2. Update: Modify a specific index in the array and update the segment tree accordingly.
  3. Query: Find the maximum value in a given range [left, right].

Examples

Example 1

Input: [2, 6, 10, 4, 7, 28, 9, 11, 6, 33]
Query Range: [0, 5]
Output: 28
Explanation: The maximum value in range `[0, 5]` (elements: `[2, 6, 10, 4, 7, 28]`) is 28.

Example 2

Input: [2, 6, 10, 4, 7, 28, 9, 11, 6, 33]
Change index `3` to `40`
Query Range: [2, 6]
Output: 40
Explanation: After updating the element at index `3` to `40`, the maximum value in range `[2, 6]` (elements: `[10, 40, 7, 28, 9]`) is 40.
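As a reference point, both examples can be reproduced with a brute-force scan. This is a minimal sketch, and the helper name `brute_force_max` is hypothetical (it is not part of the original solution). Updates are O(1), but every query scans up to n elements, which is the cost the segment tree is meant to avoid. Python slices are half-open, so the inclusive range [left, right] becomes `a[left:right + 1]`.

```python
def brute_force_max(a: list[int], left: int, right: int) -> int:
    """Maximum over the inclusive range [left, right] by a linear scan (hypothetical helper)."""
    if not 0 <= left <= right < len(a):
        raise IndexError(f"invalid range [{left}, {right}] for length {len(a)}")
    return max(a[left:right + 1])  # half-open slice covers indices left..right


a = [2, 6, 10, 4, 7, 28, 9, 11, 6, 33]
print(brute_force_max(a, 0, 5))  # 28 (Example 1)

a[3] = 40  # a point update is a plain assignment: O(1)
print(brute_force_max(a, 2, 6))  # 40 (Example 2)
```

The brute force is also handy as an oracle when testing the segment tree on random inputs.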

Solution

Method 1 - Using the segment tree

A segment tree is an efficient data structure for range queries such as maximum, minimum, or sum. The tree is built such that:

  • Each leaf node contains an element of the array.
  • Each internal node stores the aggregate of its two children (in this case, their maximum), so the root holds the maximum of the whole array.
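To make the node layout concrete, here is a minimal sketch of building the tree with a pluggable combine function, assuming the same 2 * n array layout as the Code section (`build_tree` is a hypothetical helper name). Passing `max` reproduces the max tree, while `min` or `operator.add` give the minimum and sum variants mentioned above.

```python
import operator
from typing import Callable


def build_tree(a: list[int], combine: Callable[[int, int], int]) -> list[int]:
    """Bottom-up segment tree stored in an array of size 2 * n (hypothetical helper).

    Leaves sit at indices n..2n-1 and internal node i aggregates children
    2i and 2i + 1. Index 0 is unused and stays 0, as in the Code section.
    """
    n = len(a)
    tree = [0] * (2 * n)
    tree[n:] = a  # leaves: one array element each
    for i in range(n - 1, 0, -1):  # children (larger indices) are filled first
        tree[i] = combine(tree[2 * i], tree[2 * i + 1])
    return tree


small = [2, 6, 10, 4]
print(build_tree(small, max))           # [0, 10, 6, 10, 2, 6, 10, 4]
print(build_tree(small, min))           # [0, 2, 2, 4, 2, 6, 10, 4]
print(build_tree(small, operator.add))  # [0, 22, 8, 14, 2, 6, 10, 4]
```

For `small`, node 1 (the root) holds the aggregate of the whole array, node 2 covers indices 0 and 1, and node 3 covers indices 2 and 3.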

Key insight:

  • A query for the maximum over [left, right] does not scan every element. It combines O(log n) tree nodes whose segments together cover exactly that range, with at most two nodes taken per level of the tree.
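The key insight can be checked by recording which nodes a query combines instead of their values. The sketch below (`covering_nodes` is a hypothetical helper) follows the same index arithmetic as the `query` methods in the Code section, including their half-open range [left, right).

```python
def covering_nodes(n: int, left: int, right: int) -> list[int]:
    """Tree indices combined by a bottom-up query over the half-open range
    [left, right), for a tree whose n leaves sit at indices n..2n-1."""
    from_left: list[int] = []
    from_right: list[int] = []
    left += n
    right += n
    while left < right:
        if left & 1:  # left is a right child; its sibling left - 1 is outside
            from_left.append(left)
            left += 1
        if right & 1:  # right - 1 is a left child; its sibling right is excluded
            right -= 1
            from_right.append(right)
        left >>= 1
        right >>= 1
    # Right-side nodes were collected from the right end inward; reverse them.
    return from_left + from_right[::-1]


print(covering_nodes(10, 0, 6))  # [5, 3]
print(covering_nodes(10, 2, 7))  # [3, 16]
```

With n = 10, the inclusive range [0, 5] from Example 1 is covered by node 5 (indices 0 and 1) and node 3 (indices 2 to 5). The range [2, 6] from Example 2 is covered by node 3 and leaf 16 (index 6).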

Approach

  1. Tree Representation:
    • Use an array of size 2 * n: the leaves occupy indices n to 2n - 1, internal node i has children 2i and 2i + 1, and index 0 is unused.
    • The leaves of the segment tree contain the original array values.
  2. Build:
    • Populate the leaves with array values.
    • Compute internal nodes from index n - 1 down to 1, so that both children of a node are ready before the node itself.
  3. Update:
    • Overwrite the leaf at index n + pos with the new value.
    • Walk up through the parents (index i // 2) to the root, recomputing each node as the maximum of its two children.
  4. Query:
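    • The code's query takes a half-open range [left, right), so the inclusive problem range [left, right] is queried as query(left, right + 1).
    • Move both bounds up one level per step. Whenever the node at the left bound, or the node just before the exclusive right bound, has its sibling outside the range, fold that node into the running maximum before moving up.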
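The update step touches only the ancestors of one leaf. Here is a minimal sketch that lists the tree indices an update rewrites, using the 2 * n layout from step 1, where the parent of node i is i // 2 (`update_path` is a hypothetical helper name).

```python
def update_path(n: int, pos: int) -> list[int]:
    """Tree indices rewritten by a point update at array index pos,
    from the leaf up to the root (node 1)."""
    node = pos + n  # leaf position of array index pos
    path = [node]
    while node > 1:
        node //= 2  # parent of node i is i // 2
        path.append(node)
    return path


print(update_path(10, 3))  # [13, 6, 3, 1]
```

For Example 2 (n = 10, index 3), only four entries change: leaf 13 and its ancestors 6, 3 and 1, each recomputed from children 2i and 2i + 1. The path length grows with log2 n, not with n.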

Code

Java

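/**
 * Iterative (bottom-up) segment tree for range-maximum queries.
 * Leaves live at segTree[n .. 2n-1]; node i has children 2i and 2i+1; index 0 is unused.
 */
public class SegmentTreeMaxQuery {

    private final int n;
    private final int[] segTree;

    public SegmentTreeMaxQuery(int n) {
        this.n = n;
        segTree = new int[2 * n];
    }

    // Copies the array into the leaves, then fills internal nodes from n - 1 down to 1
    // so that both children of a node are computed before the node itself. O(n).
    void build(int[] a) {
        for (int i = 0; i < n; i++) {
            segTree[n + i] = a[i];
        }

        for (int i = n - 1; i >= 1; i--) {
            segTree[i] = Math.max(segTree[2 * i], segTree[2 * i + 1]);
        }
    }

    // Sets a[pos] = value and recomputes every ancestor of that leaf. O(log n).
    void update(int pos, int value) {
        pos += n;
        segTree[pos] = value;

        while (pos > 1) {
            pos >>= 1; // move to the parent
            segTree[pos] = Math.max(segTree[2 * pos], segTree[2 * pos + 1]);
        }
    }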

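    // Returns the maximum over the half-open range [left, right), i.e. indices left..right-1.
    // Returns Integer.MIN_VALUE if the range is empty. O(log n).
    int query(int left, int right) {
        left += n;
        right += n;
        int maxVal = Integer.MIN_VALUE;

        while (left < right) {
            // left is a right child whose sibling left - 1 is outside the range: take it alone.
            if ((left & 1) == 1) {
                maxVal = Math.max(maxVal, segTree[left++]);
            }
            // right is exclusive; if it is odd, right - 1 is a left child whose
            // sibling (right) is excluded: take right - 1 alone.
            if ((right & 1) == 1) {
                maxVal = Math.max(maxVal, segTree[--right]);
            }
            left >>= 1;
            right >>= 1;
        }
        return maxVal;
    }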

    public static void main(String[] args) {
        int[] a = { 2, 6, 10, 4, 7, 28, 9, 11, 6, 33 };
        SegmentTreeMaxQuery st = new SegmentTreeMaxQuery(a.length);

        st.build(a);

        // query is half-open, so the inclusive range [0, 5] is query(0, 6).
        System.out.printf(
            "Maximum in range %d to %d is %d%n",
            0,
            5,
            st.query(0, 6)
        ); // 28

        st.update(3, 40);

        // The inclusive range [2, 6] is query(2, 7).
        System.out.printf(
            "Maximum in range %d to %d is %d%n",
            2,
            6,
            st.query(2, 7)
        ); // 40
    }
}

Python

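class SegmentTreeMaxQuery:
    """Iterative (bottom-up) segment tree for range-maximum queries.

    Leaves live at seg_tree[n .. 2n-1]; node i has children 2i and 2i+1; index 0 is unused.
    """

    def __init__(self, n):
        self.n = n
        self.seg_tree = [0] * (2 * n)

    def build(self, a):
        # Copy the array into the leaves.
        for i in range(self.n):
            self.seg_tree[self.n + i] = a[i]

        # Fill internal nodes bottom-up so children are ready before their parent. O(n).
        for i in range(self.n - 1, 0, -1):
            self.seg_tree[i] = max(self.seg_tree[2 * i], self.seg_tree[2 * i + 1])

    def update(self, pos, value):
        # Overwrite the leaf, then recompute every ancestor. O(log n).
        pos += self.n
        self.seg_tree[pos] = value

        while pos > 1:
            pos //= 2  # move to the parent
            self.seg_tree[pos] = max(self.seg_tree[2 * pos], self.seg_tree[2 * pos + 1])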

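    def query(self, left, right):
        """Maximum over the half-open range [left, right); -inf if the range is empty."""
        left += self.n
        right += self.n
        max_val = float('-inf')

        while left < right:
            # left is odd, so it is a right child whose sibling left - 1 is outside: take it alone.
            if left % 2 == 1:
                max_val = max(max_val, self.seg_tree[left])
                left += 1
            # right is exclusive; if it is odd, right - 1 is a left child whose
            # sibling (right) is excluded: take right - 1 alone.
            if right % 2 == 1:
                right -= 1
                max_val = max(max_val, self.seg_tree[right])

            left //= 2
            right //= 2
        return max_val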

# Driver code. query() is half-open, so the inclusive range [0, 5] becomes query(0, 6).
a = [2, 6, 10, 4, 7, 28, 9, 11, 6, 33]
st = SegmentTreeMaxQuery(len(a))

st.build(a)
print(f"Maximum in range 0 to 5 is {st.query(0, 6)}")  # 28

st.update(3, 40)
print(f"Maximum in range 2 to 6 is {st.query(2, 7)}")  # 40

Complexity

  • ⏰ Time complexity
    • Tree Build: O(n) (populate leaves and compute internal nodes).
    • Update: O(log n) (move up the tree for updates).
    • Query: O(log n) (at most two nodes are combined per level of the tree).
  • 🧺 Space complexity: O(n) (a single array of 2 * n entries).
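The O(log n) query bound can be checked empirically by counting the nodes the bottom-up loop combines. This sketch (`query_node_count` is a hypothetical helper) replays the query's index arithmetic without tree values and compares the worst case over every half-open range against 2 * (n.bit_length() + 1). That bound holds because the loop runs at most bit_length(2n) = n.bit_length() + 1 times and takes at most two nodes per iteration.

```python
def query_node_count(n: int, left: int, right: int) -> int:
    """Number of tree nodes the bottom-up query combines for [left, right)."""
    count = 0
    left += n
    right += n
    while left < right:
        if left & 1:   # left is a right child: take it alone
            count += 1
            left += 1
        if right & 1:  # right - 1 is a left child: take it alone
            right -= 1
            count += 1
        left >>= 1
        right >>= 1
    return count


for n in (10, 100, 300):
    worst = max(
        query_node_count(n, left, right)
        for left in range(n)
        for right in range(left + 1, n + 1)
    )
    bound = 2 * (n.bit_length() + 1)  # at most 2 nodes per loop iteration
    assert worst <= bound
    print(f"n={n}: worst case {worst} nodes (bound {bound}), versus up to {n} elements scanned")
```

The node count grows with the number of levels, while a brute-force scan grows with the range length.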