Problem

There exists an undirected tree with n nodes numbered 0 to n - 1. You are given a 0-indexed 2D integer array edges of length n - 1, where edges[i] = [ui, vi] indicates that there is an edge between nodes ui and vi in the tree. You are also given a positive integer k, and a 0-indexed array of non-negative integers nums of length n, where nums[i] represents the value of the node numbered i.

Alice wants the sum of values of tree nodes to be maximum, for which Alice can perform the following operation any number of times (including zero) on the tree:

  • Choose any edge [u, v] connecting the nodes u and v, and update their values as follows:
    • nums[u] = nums[u] XOR k
    • nums[v] = nums[v] XOR k

Return the maximum possible sum of the values Alice can achieve by performing the operation any number of times.

Examples

Example 1:

Figure: node 0 connected to nodes 1 and 2.
Input: nums = [1,2,1], k = 3, edges = [[0,1],[0,2]]
Output: 6
Explanation: Alice can achieve the maximum sum of 6 using a single operation:
- Choose the edge [0,2]. nums[0] and nums[2] become: 1 XOR 3 = 2, and the array nums becomes: [1,2,1] -> [2,2,2].
The total sum of values is 2 + 2 + 2 = 6.
It can be shown that 6 is the maximum achievable sum of values.

Example 2:

Figure: node 0 connected to node 1.
Input: nums = [2,3], k = 7, edges = [[0,1]]
Output: 9
Explanation: Alice can achieve the maximum sum of 9 using a single operation:
- Choose the edge [0,1]. nums[0] becomes: 2 XOR 7 = 5 and nums[1] becomes: 3 XOR 7 = 4, and the array nums becomes: [2,3] -> [5,4].
The total sum of values is 5 + 4 = 9.
It can be shown that 9 is the maximum achievable sum of values.

Example 3:

Figure: node 0 connected to nodes 1, 2, 3, 4 and 5.
Input: nums = [7,7,7,7,7,7], k = 3, edges = [[0,1],[0,2],[0,3],[0,4],[0,5]]
Output: 42
Explanation: The maximum achievable sum is 42 which can be achieved by Alice performing no operations.

Constraints:

  • 2 <= n == nums.length <= 2 * 10^4
  • 1 <= k <= 10^9
  • 0 <= nums[i] <= 10^9
  • edges.length == n - 1
  • edges[i].length == 2
  • 0 <= edges[i][0], edges[i][1] <= n - 1
  • The input is generated such that edges represent a valid tree.

Solution

Method 1 - Top-down DP approach with memoization

Suppose we want to XOR the values of two arbitrary, non-adjacent nodes U and V with k. Because the tree is connected and acyclic, there is exactly one path between U and V. Let this path have L edges, and let its intermediate nodes be P = {P1, P2, P3, ..., P(L-1)}, listed in order from U to V. If we perform the operation once on every edge of this path, we perform exactly L operations.

After these L operations, the value of every node in P is unchanged: each intermediate node is an endpoint of exactly two path edges, so it is XORed with k twice, and XOR is self-inverse (A XOR k XOR k = A). Only the endpoints are touched once, so nums[U] becomes nums[U] XOR k and nums[V] becomes nums[V] XOR k. This sequence of operations is referred to as an “effective operation.”
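To see the cancellation concretely, the sketch below applies the operation to every edge on the U-to-V path of a small path-shaped tree and checks that only the endpoints change. The helper names (apply_operation, path_between, effective_operation) and the example tree are hypothetical, chosen only for illustration.

```python
from collections import deque


def apply_operation(nums, k, u, v):
    """Apply one operation on edge [u, v] in place."""
    nums[u] ^= k
    nums[v] ^= k


def path_between(n, edges, u, v):
    """Return the unique tree path from u to v as a list of nodes (BFS from u)."""
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    parent = [-1] * n
    parent[u] = u  # mark the root as visited
    queue = deque([u])
    while queue:
        node = queue.popleft()
        for nxt in adj[node]:
            if parent[nxt] == -1:
                parent[nxt] = node
                queue.append(nxt)
    # Walk back from v to u using the parent pointers.
    path = [v]
    while path[-1] != u:
        path.append(parent[path[-1]])
    return path[::-1]


def effective_operation(nums, k, edges, u, v):
    """Operate once on every edge of the u-v path; returns a new list."""
    result = list(nums)
    path = path_between(len(nums), edges, u, v)
    for a, b in zip(path, path[1:]):
        apply_operation(result, k, a, b)
    return result


nums = [1, 2, 4, 8, 16]
edges = [[0, 1], [1, 2], [2, 3], [3, 4]]  # a path 0-1-2-3-4
after = effective_operation(nums, 3, edges, 0, 4)
print(after)  # [2, 2, 4, 8, 19]: only nums[0] and nums[4] changed
```

The three intermediate nodes are each touched by two operations and keep their values, while the endpoints become 1 XOR 3 = 2 and 16 XOR 3 = 19.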

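Performing multiple effective operations changes the values of m nodes (where m <= n, the total number of nodes), and m is always even: every operation XORs exactly two nodes, and a node XORed twice returns to its original value, so the number of nodes XORed an odd number of times always stays even. Conversely, any even-sized set of nodes can be XORed: pair its nodes up arbitrarily and apply one effective operation per pair. The edges therefore impose no further restriction, and the problem reduces to choosing which nodes to XOR with k, subject only to choosing an even number of them.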
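This parity claim can be checked exhaustively on a small tree. Since operations commute and using the same edge twice cancels out, only the set of edges used an odd number of times matters, so enumerating all 2^(n-1) edge subsets covers every reachable outcome. The sketch below (hypothetical helpers reachable_flip_sets and even_subsets, with an arbitrary five-node example tree) records which nodes end up XORed and compares that against the even-sized subsets.

```python
from itertools import combinations


def reachable_flip_sets(n, edges):
    """Return every set of XORed nodes reachable by some sequence of operations."""
    reachable = set()
    # Each edge is effectively used 0 or 1 times, so try all 2^(n-1) subsets.
    for mask in range(1 << len(edges)):
        flipped = [0] * n
        for i, (u, v) in enumerate(edges):
            if mask >> i & 1:
                flipped[u] ^= 1
                flipped[v] ^= 1
        reachable.add(frozenset(x for x in range(n) if flipped[x]))
    return reachable


def even_subsets(n):
    """All subsets of range(n) with an even number of elements."""
    return {frozenset(c) for size in range(0, n + 1, 2)
            for c in combinations(range(n), size)}


n = 5
edges = [[0, 1], [0, 2], [2, 3], [2, 4]]
print(reachable_flip_sets(n, edges) == even_subsets(n))  # True
```

Both directions of the argument show up here: no odd-sized set is ever produced, and every even-sized set is.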

A brute-force approach to maximising the sum uses recursion: for each node, explore both possibilities, keeping its value unchanged or replacing it with its XOR. A complete choice is valid only if an even number of nodes were XORed.

The recursion is based on these principles:

  • Include the parity of the number of updated nodes (isEven) as a parameter. The parity flips between even and odd whenever a node’s value is changed.
  • If all nodes have been processed (index == nums.length), the result depends on whether the number of updated nodes is even: return 0 for even parity and negative infinity (in code, a very large negative sentinel) for odd parity, so an invalid choice never wins the maximum.

For each node:

  1. We can either leave it unchanged, retaining the current parity, and recursively calculate for the next node.
  2. Alternatively, apply an XOR operation to update its value and flip the parity, then recursively calculate the sum.
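Before adding memoisation, the recursion described above can be written directly. This is a minimal Python sketch (the function name max_sum_brute_force is hypothetical); it runs in O(2^n) time, so it is only practical for small inputs, but it is handy as a reference when testing the DP versions.

```python
import math


def max_sum_brute_force(nums, k):
    """Exhaustive recursion: try both choices for every node (O(2^n) time)."""

    def solve(index, is_even):
        if index == len(nums):
            # Valid only if an even number of nodes were XORed.
            return 0 if is_even else -math.inf
        keep = nums[index] + solve(index + 1, is_even)
        flip = (nums[index] ^ k) + solve(index + 1, not is_even)
        return max(keep, flip)

    return solve(0, True)  # zero nodes XORed so far: even parity


print(max_sum_brute_force([1, 2, 1], 3))  # 6
```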

Plain recursion explores all 2^n combinations of choices, which is exponential. However, the best result for a given node index and parity never changes, so dynamic programming can cache it in a 2D table (rows for node indices, columns for parity) and compute each of the 2n states only once.
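The gap between the two is easy to measure. The sketch below (hypothetical helper count_calls, with an arbitrary ten-element input) counts how many times the recursive function is invoked with and without a cache. The plain version makes 2^(n+1) - 1 calls; the memoised one makes 4n - 1, because each of the 2n - 1 reachable non-base states is computed once and makes two calls, plus the initial call.

```python
import math


def count_calls(nums, k, use_memo):
    """Return (answer, number of calls made to the recursive helper)."""
    calls = 0
    memo = {}

    def solve(index, is_even):
        nonlocal calls
        calls += 1
        if index == len(nums):
            return 0 if is_even else -math.inf
        if use_memo and (index, is_even) in memo:
            return memo[(index, is_even)]  # cache hit: no further calls
        best = max(nums[index] + solve(index + 1, is_even),
                   (nums[index] ^ k) + solve(index + 1, 1 - is_even))
        if use_memo:
            memo[(index, is_even)] = best
        return best

    answer = solve(0, 1)
    return answer, calls


nums = [5, 1, 9, 2, 7, 3, 8, 4, 6, 0]
print(count_calls(nums, 3, use_memo=False))  # (57, 2047)
print(count_calls(nums, 3, use_memo=True))   # (57, 39)
```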

Algorithm

Main Function: maximumValueSum(nums, k, edges)

  1. Initialise a 2D memoisation array memo of size n × 2 with all values set to -1 (meaning “not computed yet”; every real result is non-negative).
  2. Call the recursive function calculateMaxSum with:
    • Initial index 0,
    • Initial parity isEven = 1 (no nodes have been XORed yet, and zero is even),
    • Input array nums,
    • XOR value k, and
    • Memoisation array memo.
  3. Return the result of calculateMaxSum.

Recursive Function: calculateMaxSum(index, isEven, nums, k, memo)

  1. If all nodes have been processed (index == nums.length):
    • Return 0 if isEven == 1, or negative infinity otherwise.
  2. If the result for the current state exists in memo, retrieve and return it.
  3. Explore two possibilities:
    • No XOR applied: add nums[index] to the result of recursively processing the next node with the same parity.
    • XOR applied: add nums[index] XOR k to the result of recursively processing the next node with the parity flipped.
  4. Memoise the maximum of the two cases for the current state and return it.

Code

import java.util.Arrays;

class Solution {
    public long maximumValueSum(int[] nums, int k, int[][] edges) {
        // edges is unused: any even-sized set of nodes can be XORed with k.
        long[][] memo = new long[nums.length][2];
        for (long[] row : memo) {
            Arrays.fill(row, -1); // -1 marks "not computed yet"
        }
        return calculateMaxSum(0, 1, nums, k, memo);
    }

    private long calculateMaxSum(int index, int isEven, int[] nums, int k, long[][] memo) {
        if (index == nums.length) {
            // Base case: all nodes processed; odd parity is invalid (acts as negative infinity).
            return isEven == 1 ? 0 : Long.MIN_VALUE;
        }
        if (memo[index][isEven] != -1) {
            return memo[index][isEven]; // Use cached result if available
        }
        // Option 1: No XOR operation
        long noXorSum = nums[index] + calculateMaxSum(index + 1, isEven, nums, k, memo);
        // Option 2: XOR operation applied
        long xorSum = (nums[index] ^ k) + calculateMaxSum(index + 1, isEven ^ 1, nums, k, memo);
        // Memoise and return the maximum of both options
        return memo[index][isEven] = Math.max(noXorSum, xorSum);
    }
}
import sys
from typing import List


class Solution:
    def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
        # The recursion is len(nums) levels deep, which can exceed Python's
        # default limit of 1000 for large inputs, so raise the limit first.
        sys.setrecursionlimit(max(sys.getrecursionlimit(), 2 * len(nums) + 1000))
        memo = [[-1 for _ in range(2)] for _ in range(len(nums))]
        return self.calculate_max_sum(0, 1, nums, k, memo)

    def calculate_max_sum(self, index: int, is_even: int, nums: List[int], k: int, memo: List[List[int]]) -> int:
        if index == len(nums):
            return 0 if is_even == 1 else float('-inf')  # Base case: odd parity is invalid

        if memo[index][is_even] != -1:
            return memo[index][is_even]  # Return cached result

        # Option 1: No XOR operation
        no_xor_sum = nums[index] + self.calculate_max_sum(index + 1, is_even, nums, k, memo)
        # Option 2: XOR operation applied
        xor_sum = (nums[index] ^ k) + self.calculate_max_sum(index + 1, is_even ^ 1, nums, k, memo)

        memo[index][is_even] = max(no_xor_sum, xor_sum)  # Cache the result
        return memo[index][is_even]

Complexity

  • ⏰ Time complexity:

    • Without Memoisation: The recursion is exponential without caching results. For each node, two branches (include XOR or exclude XOR) are explored, resulting in a time complexity of O(2ⁿ).
    • With Memoisation: There are n × 2 states (node index and parity), and each is computed once with O(1) work, so the total time complexity is O(n).
  • 🧺 Space complexity:

    • Without Memoisation: Only the recursion stack is used, so the space complexity is O(n) for the recursion depth. However, the exponential running time makes this infeasible for large n.
    • With Memoisation: The space complexity is O(n) for the DP table (storing results for n × 2 states) and O(n) for the maximum recursion depth, leading to O(n) overall.

Method 2 - Bottom-up DP

Tabulation iterates systematically over every combination of the parameters that change between subproblems. Unlike recursion, it avoids the overhead of the call stack (and, in Python, the default recursion-depth limit). Here the two changing parameters are the current node index (index) and the parity of the number of XORed nodes (isEven), so we use two nested loops: one over index and one over isEven (0 for odd parity, 1 for even).

The base case is defined as:

  • If index == n, where n is the number of nodes, the state depends on the parity:
    • dp[n][1] = 0 (valid state with even parity).
    • dp[n][0] = negative infinity (invalid state with odd parity).

The goal is to determine the maximum sum of node values where operations affect an even number of nodes. This final result is stored in dp[0][1]. Using nested loops, the outer loop iterates over index in reverse (from last node to first), while the inner loop updates states for each isEven.

At each step, the tabulation matrix is updated to track the maximum sum for both cases:

  • Performing the XOR operation.
  • Skipping the XOR operation.

Finally, the result stored in dp[0][1] represents the maximum sum achievable after completing all operations.
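As a worked example, the table for Example 1 (nums = [1,2,1], k = 3) can be filled in with the short Python sketch below (the function name build_table is hypothetical; it follows the same recurrence as the code in this method). Each row is a node index, column 0 is odd parity and column 1 is even parity; the answer is dp[0][1] = 6.

```python
import math


def build_table(nums, k):
    """Fill the (n + 1) x 2 tabulation table bottom-up and return it."""
    n = len(nums)
    dp = [[-math.inf, -math.inf] for _ in range(n + 1)]
    dp[n][1] = 0  # even parity after the last node: valid
    for index in range(n - 1, -1, -1):
        for is_even in range(2):
            perform_xor = dp[index + 1][is_even ^ 1] + (nums[index] ^ k)
            skip_xor = dp[index + 1][is_even] + nums[index]
            dp[index][is_even] = max(perform_xor, skip_xor)
    return dp


for row in build_table([1, 2, 1], 3):
    print(row)
# [5, 6]
# [4, 3]
# [2, 1]
# [-inf, 0]
```

Reading the top row: XORing nodes 0 and 2 (an even count) gives 2 + 2 + 2 = 6, which is dp[0][1].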

Algorithm

  1. Initialisation:
    • Define a 2D DP array dp of size (n + 1) × 2 (n = number of nodes).
    • Set base cases:
      • dp[n][1] = 0 (even parity is valid).
      • dp[n][0] = negative infinity (odd parity is invalid).
  2. Tabulation:
    • Loop backward from n - 1 to 0 for index.
    • For each isEven (0 or 1), calculate:
      • Perform XOR: Result is dp[index + 1][isEven ^ 1] + (nums[index] ^ k).
      • Skip XOR: Result is dp[index + 1][isEven] + nums[index].
    • Update dp[index][isEven] with the maximum of the two options.
  3. Result: dp[0][1] holds the maximum sum over all choices that XOR an even number of nodes.

Code

class Solution {
    public long maximumValueSum(int[] nums, int k, int[][] edges) {
        int n = nums.length;
        long[][] dp = new long[n + 1][2];  // Define DP matrix.
        
        dp[n][1] = 0;  // Base case: Valid assignment with even parity.
        dp[n][0] = Long.MIN_VALUE;  // Base case: invalid assignment with odd parity (acts as negative infinity).
        
        for (int index = n - 1; index >= 0; index--) {
            for (int isEven = 0; isEven <= 1; isEven++) {
                // Option 1: Apply XOR operation.
                long performXor = dp[index + 1][isEven ^ 1] + (nums[index] ^ k);
                // Option 2: Skip XOR operation.
                long skipXor = dp[index + 1][isEven] + nums[index];
                
                // Take the max of the two options.
                dp[index][isEven] = Math.max(performXor, skipXor);
            }
        }
        
        // Return the maximum sum with even parity at 0-index.
        return dp[0][1];
    }
}
from typing import List


class Solution:
    def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
        n = len(nums)
        dp = [[float('-inf') for _ in range(2)] for _ in range(n + 1)]  # Define DP matrix.
        
        dp[n][1] = 0  # Base case: Valid assignment with even parity.
        dp[n][0] = float('-inf')  # Base case: Invalid assignment with odd parity.
        
        for index in range(n - 1, -1, -1):  # Iterate backward.
            for is_even in range(2):  # Check both parity possibilities.
                # Option 1: Apply XOR operation.
                perform_xor = dp[index + 1][is_even ^ 1] + (nums[index] ^ k)
                # Option 2: Skip XOR operation.
                skip_xor = dp[index + 1][is_even] + nums[index]
                
                # Update DP matrix with the maximum value.
                dp[index][is_even] = max(perform_xor, skip_xor)
        
        # Return the maximum sum with even parity at the start.
        return dp[0][1]
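Because the DP never looks at the edges, it is worth checking it against a search that does. The sketch below (hypothetical helpers max_sum_dp, max_sum_over_edge_subsets and random_tree) compares the bottom-up recurrence with an exhaustive search over every subset of edges on random small trees; if the reduction to "XOR any even number of nodes" is right, the two always agree.

```python
import math
import random


def max_sum_dp(nums, k):
    """Bottom-up DP from Method 2, written as a standalone function."""
    n = len(nums)
    dp = [[-math.inf, -math.inf] for _ in range(n + 1)]
    dp[n][1] = 0
    for index in range(n - 1, -1, -1):
        for is_even in range(2):
            dp[index][is_even] = max(
                dp[index + 1][is_even ^ 1] + (nums[index] ^ k),
                dp[index + 1][is_even] + nums[index],
            )
    return dp[0][1]


def max_sum_over_edge_subsets(nums, k, edges):
    """Exhaustive search: apply each subset of edges once (O(2^(n-1) * n))."""
    best = -math.inf
    for mask in range(1 << len(edges)):
        values = list(nums)
        for i, (u, v) in enumerate(edges):
            if mask >> i & 1:
                values[u] ^= k
                values[v] ^= k
        best = max(best, sum(values))
    return best


def random_tree(n, rng):
    """Random labelled tree: node i attaches to a random earlier node."""
    return [[i, rng.randrange(i)] for i in range(1, n)]


rng = random.Random(0)
for _ in range(200):
    n = rng.randint(2, 8)
    nums = [rng.randint(0, 20) for _ in range(n)]
    k = rng.randint(1, 20)
    edges = random_tree(n, rng)
    assert max_sum_dp(nums, k) == max_sum_over_edge_subsets(nums, k, edges)
print("DP matches exhaustive search")
```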

Complexity

  • ⏰ Time complexity: O(n)
    • Tabulation involves iterating through all n indices and two parity states (isEven = 0 or 1).
    • Overall complexity: O(n × 2) = O(n)
  • 🧺 Space complexity: O(n)
    • The DP array has (n + 1) × 2 entries, which is O(n).
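Since row index of the table depends only on row index + 1, the (n + 1) × 2 array can be replaced by two running values, reducing the extra space to O(1). This is a sketch of that standard rolling-array optimisation rather than part of the method above; the function name maximum_value_sum_o1 and the variables even_best and odd_best are hypothetical.

```python
import math
from typing import List


def maximum_value_sum_o1(nums: List[int], k: int) -> int:
    """Method 2 with the DP table collapsed to two variables (O(1) extra space)."""
    even_best, odd_best = 0, -math.inf  # dp[n][1] and dp[n][0]
    for value in reversed(nums):
        flipped = value ^ k
        # Tuple assignment uses the previous row's values on the right-hand side.
        even_best, odd_best = (
            max(odd_best + flipped, even_best + value),  # new dp[index][1]
            max(even_best + flipped, odd_best + value),  # new dp[index][0]
        )
    return even_best


print(maximum_value_sum_o1([1, 2, 1], 3))  # 6
print(maximum_value_sum_o1([2, 3], 7))  # 9
print(maximum_value_sum_o1([7, 7, 7, 7, 7, 7], 3))  # 42
```

The time complexity stays O(n); only the table storage goes away.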