Find the Maximum Sum of Node Values
Problem
There exists an undirected tree with n nodes numbered 0 to n - 1. You are given a 0-indexed 2D integer array edges of length n - 1, where edges[i] = [ui, vi] indicates that there is an edge between nodes ui and vi in the tree. You are also given a positive integer k, and a 0-indexed array of non-negative integers nums of length n, where nums[i] represents the value of the node numbered i.
Alice wants the sum of values of tree nodes to be maximum, for which Alice can perform the following operation any number of times (including zero) on the tree:
- Choose any edge [u, v] connecting the nodes u and v, and update their values as follows:
  - nums[u] = nums[u] XOR k
  - nums[v] = nums[v] XOR k
Return the maximum possible sum of the values Alice can achieve by performing the operation any number of times.
Examples
Example 1:
Tree: node 0 is connected to nodes 1 and 2.
Input: nums = [1,2,1], k = 3, edges = [[0,1],[0,2]]
Output: 6
Explanation: Alice can achieve the maximum sum of 6 using a single operation:
- Choose the edge [0,2]. nums[0] and nums[2] both become 1 XOR 3 = 2, so the array nums becomes [1,2,1] -> [2,2,2].
The total sum of values is 2 + 2 + 2 = 6.
It can be shown that 6 is the maximum achievable sum of values.
Example 2:
Tree: node 0 is connected to node 1.
Input: nums = [2,3], k = 7, edges = [[0,1]]
Output: 9
Explanation: Alice can achieve the maximum sum of 9 using a single operation:
- Choose the edge [0,1]. nums[0] becomes 2 XOR 7 = 5 and nums[1] becomes 3 XOR 7 = 4, so the array nums becomes [2,3] -> [5,4].
The total sum of values is 5 + 4 = 9.
It can be shown that 9 is the maximum achievable sum of values.
Example 3:
Tree: node 0 is connected to nodes 1, 2, 3, 4, and 5.
Input: nums = [7,7,7,7,7,7], k = 3, edges = [[0,1],[0,2],[0,3],[0,4],[0,5]]
Output: 42
Explanation: The maximum achievable sum is 42 which can be achieved by Alice performing no operations.
Constraints:
- 2 <= n == nums.length <= 2 * 10^4
- 1 <= k <= 10^9
- 0 <= nums[i] <= 10^9
- edges.length == n - 1
- edges[i].length == 2
- 0 <= edges[i][0], edges[i][1] <= n - 1
- The input is generated such that edges represent a valid tree.
Solution
Method 1 - Top Down DP approach with memoization
Suppose we want to replace the values of two arbitrary, non-adjacent nodes U and V with their XORed values (U XOR k and V XOR k). Because the tree is undirected, connected, and acyclic, there is exactly one path between any two nodes. Let this path have length L (that is, L edges), and denote its intermediate nodes by P = {P1, P2, P3, ..., P(L-1)}, in order from U to V. Performing the operation once on every edge of this path takes exactly L operations.
After these L operations, every intermediate node in P lies on exactly two edges of the path, so it is XORed with k twice and keeps its original value, because XOR is associative and self-inverse (A XOR k XOR k = A). The endpoints U and V each lie on exactly one edge of the path, so their values become U XOR k and V XOR k. We call such a sequence an "effective operation": it XORs two arbitrary nodes and leaves every other node untouched.
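To make the effective operation concrete, here is a minimal Python sketch that finds the unique path between two nodes with a BFS and then applies the edge operation once on every edge of that path. The helper names `tree_path` and `effective_operation` are hypothetical, introduced only for this illustration. On the path 0 - 1 - 2 - 3, only the endpoints should change.

```python
from collections import deque
from typing import Dict, List


def tree_path(n: int, edges: List[List[int]], u: int, v: int) -> List[int]:
    """Hypothetical helper: unique u -> v path in a tree, via BFS parent links."""
    adj: Dict[int, List[int]] = {i: [] for i in range(n)}
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    parent = {u: -1}
    queue = deque([u])
    while queue:
        node = queue.popleft()
        for nxt in adj[node]:
            if nxt not in parent:
                parent[nxt] = node
                queue.append(nxt)
    path = [v]
    while path[-1] != u:  # walk back from v to u
        path.append(parent[path[-1]])
    return path[::-1]


def effective_operation(nums: List[int], k: int, edges: List[List[int]], u: int, v: int) -> List[int]:
    """Hypothetical helper: apply the edge operation on every edge of the u-v path."""
    result = list(nums)
    path = tree_path(len(nums), edges, u, v)
    for a, b in zip(path, path[1:]):  # L edges -> L operations
        result[a] ^= k
        result[b] ^= k
    return result


# Path 0 - 1 - 2 - 3: nodes 1 and 2 are XORed twice, so they are unchanged.
print(effective_operation([5, 6, 7, 8], 3, [[0, 1], [1, 2], [2, 3]], 0, 3))  # [6, 6, 7, 11]
```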
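Combining several effective operations changes the values of some m nodes (m <= n, where n is the total number of nodes), and m is always even: each operation toggles exactly two nodes between their original and XORed values, so the number of XORed nodes changes by -2, 0, or +2 and stays even. Conversely, any even-sized set of nodes can be XORed by pairing its nodes up and applying one effective operation per pair. The edges therefore impose no further restriction: the problem reduces to choosing an even number of nodes to XOR with k so that the total sum is maximised.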
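Both directions of this claim can be checked exhaustively on small trees. The sketch below (hypothetical helpers `reachable_masks` and `even_masks`) runs a BFS over flip patterns, where bit i of a mask records whether node i currently holds nums[i] XOR k, and compares the reachable patterns with all even-sized subsets. This is an illustration on two 6-node trees, not a proof.

```python
from collections import deque
from typing import List, Set


def reachable_masks(edges: List[List[int]]) -> Set[int]:
    """Hypothetical helper: all flip patterns reachable from 'nothing flipped'."""
    seen = {0}
    queue = deque([0])
    while queue:
        mask = queue.popleft()
        for u, v in edges:
            nxt = mask ^ (1 << u) ^ (1 << v)  # one operation toggles both endpoints
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen


def even_masks(n: int) -> Set[int]:
    """Hypothetical helper: all subsets of n nodes with an even number of members."""
    return {m for m in range(1 << n) if bin(m).count("1") % 2 == 0}


# Star tree from Example 3 and a simple path: both reach exactly the even subsets.
star = [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5]]
path = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]
print(reachable_masks(star) == even_masks(6))  # True
print(reachable_masks(path) == even_masks(6))  # True
```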
A brute-force approach to maximising the sum uses recursion that explores both possibilities for every node: keeping its value unchanged or XORing it with k. A combination of choices is valid only if it XORs an even number of nodes.
The recursion is based on these principles:
- If all nodes have been processed (index == nums.length), the result depends on whether the number of updated nodes is even: return 0 for even parity and INT_MIN (effectively negative infinity) for odd parity, so that an invalid combination never wins.
- Pass the parity of updated nodes (isEven) as a parameter. The parity flips between even and odd whenever a node's value is changed.
For each node:
- We can either leave it unchanged, retaining the current parity, and recursively calculate for the next node.
- Alternatively, apply an XOR operation to update its value and flip the parity, then recursively calculate the sum.
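A direct transcription of this plain recursion, before any caching, might look like the following sketch. The function name `max_sum_brute_force` is hypothetical; it takes no edges because, by the argument above, any even-sized set of nodes can be XORed. It makes O(2ⁿ) calls, so it is only practical for small inputs.

```python
from typing import List


def max_sum_brute_force(nums: List[int], k: int) -> int:
    """Hypothetical helper: plain keep/XOR recursion with no memoisation."""

    def solve(index: int, is_even: int) -> float:
        if index == len(nums):
            # Valid only if an even number of nodes was XORed.
            return 0 if is_even == 1 else float("-inf")
        keep = nums[index] + solve(index + 1, is_even)            # leave node unchanged
        flip = (nums[index] ^ k) + solve(index + 1, is_even ^ 1)  # XOR node, flip parity
        return max(keep, flip)

    return int(solve(0, 1))  # start with even parity: no node updated yet


print(max_sum_brute_force([1, 2, 1], 3))           # 6
print(max_sum_brute_force([2, 3], 7))              # 9
print(max_sum_brute_force([7, 7, 7, 7, 7, 7], 3))  # 42
```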
Plain recursion is exponential because it explores all 2ⁿ keep/XOR combinations. However, the result of a call depends only on the pair (index, isEven), so dynamic programming caches it in a 2D table (rows for node indices, columns for parity) and each state is solved only once.
Algorithm
Main Function: maximumValueSum(nums, k, edges)
- Initialise a 2D memoisation array memo with all values set to -1.
- Call the recursive function calculateMaxSum with:
  - initial index 0,
  - initial parity 1 (even parity at the start, since no node has been updated yet),
  - input array nums,
  - XOR value k, and
  - memoisation array memo.
- Return the result of calculateMaxSum.
Recursive Function: calculateMaxSum(index, isEven, nums, k, memo)
- If all nodes have been processed (index == nums.length), return 0 if isEven == 1, or INT_MIN otherwise.
- If the result for the current state exists in memo, retrieve and return it.
- Otherwise, explore two possibilities:
  - No XOR applied: add nums[index] to the result of processing the next node with the same parity.
  - XOR applied: add nums[index] XOR k to the result of processing the next node with the flipped parity.
- Memoise the maximum of the two cases for the current state and return it.
Code
Java
import java.util.Arrays;

class Solution {
    public long maximumValueSum(int[] nums, int k, int[][] edges) {
        // edges is not needed: any even-sized set of nodes can be XORed with k.
        long[][] memo = new long[nums.length][2];
        for (long[] row : memo) {
            Arrays.fill(row, -1); // -1 marks "not computed"; real results are never negative
        }
        return calculateMaxSum(0, 1, nums, k, memo); // start at node 0 with even parity
    }

    private long calculateMaxSum(int index, int isEven, int[] nums, int k, long[][] memo) {
        if (index == nums.length) {
            return isEven == 1 ? 0 : Integer.MIN_VALUE; // Base case: all nodes processed
        }
        if (memo[index][isEven] != -1) {
            return memo[index][isEven]; // Use cached result if available
        }
        // Option 1: No XOR operation
        long noXorSum = nums[index] + calculateMaxSum(index + 1, isEven, nums, k, memo);
        // Option 2: XOR operation applied (flips the parity)
        long xorSum = (nums[index] ^ k) + calculateMaxSum(index + 1, isEven ^ 1, nums, k, memo);
        // Memoise and return the maximum of both options
        return memo[index][isEven] = Math.max(noXorSum, xorSum);
    }
}
Python
import sys
from typing import List


class Solution:
    def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
        # The recursion goes one level deeper per node (n can reach 2 * 10^4),
        # which exceeds CPython's default recursion limit of 1000.
        sys.setrecursionlimit(max(sys.getrecursionlimit(), len(nums) + 1000))
        memo = [[-1 for _ in range(2)] for _ in range(len(nums))]
        return self.calculate_max_sum(0, 1, nums, k, memo)

    def calculate_max_sum(self, index: int, is_even: int, nums: List[int], k: int, memo: List[List[int]]) -> int:
        if index == len(nums):
            return 0 if is_even == 1 else float('-inf')  # Base case
        if memo[index][is_even] != -1:
            return memo[index][is_even]  # Return cached result
        # Option 1: No XOR operation
        no_xor_sum = nums[index] + self.calculate_max_sum(index + 1, is_even, nums, k, memo)
        # Option 2: XOR operation applied (flips the parity)
        xor_sum = (nums[index] ^ k) + self.calculate_max_sum(index + 1, is_even ^ 1, nums, k, memo)
        memo[index][is_even] = max(no_xor_sum, xor_sum)  # Cache the result
        return memo[index][is_even]
Complexity
- ⏰ Time complexity:
  - Without memoisation: each node branches two ways (keep or XOR), so the recursion is exponential: O(2ⁿ).
  - With memoisation: there are n × 2 states (node index × parity), each computed once with O(1) work, so the total is O(n).
- 🧺 Space complexity:
  - Without memoisation: only the recursion stack is used, so the space complexity is O(n) for the recursion depth. However, the exponential running time makes this infeasible for large n.
  - With memoisation: O(n) for the DP table (n × 2 states) plus O(n) for the maximum recursion depth, which is O(n) overall.
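The gap between the two time bounds can be observed by counting calls. The sketch below (a hypothetical `count_calls` helper that uses a dictionary as the memo instead of the 2D array) runs the same recursion with and without caching. For n nodes, the unmemoised version makes 2^(n+1) - 1 calls, while the memoised one computes each (index, parity) state once and makes 4n - 1 calls in total.

```python
from typing import Dict, List, Tuple


def count_calls(nums: List[int], k: int, use_memo: bool) -> Tuple[int, int]:
    """Hypothetical helper: return (best sum, number of recursive calls)."""
    memo: Dict[Tuple[int, int], float] = {}
    calls = 0

    def solve(index: int, is_even: int) -> float:
        nonlocal calls
        calls += 1
        if index == len(nums):
            return 0 if is_even == 1 else float("-inf")
        if use_memo and (index, is_even) in memo:
            return memo[(index, is_even)]  # cached state: no further calls
        best = max(nums[index] + solve(index + 1, is_even),
                   (nums[index] ^ k) + solve(index + 1, is_even ^ 1))
        if use_memo:
            memo[(index, is_even)] = best
        return best

    best = solve(0, 1)
    return int(best), calls


print(count_calls([7] * 10, 3, use_memo=False))  # (70, 2047): 2^(n+1) - 1 calls
print(count_calls([7] * 10, 3, use_memo=True))   # (70, 39): 4n - 1 calls
```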
Method 2 - Bottom up DP
Tabulation iterates over all combinations of the changing parameters systematically. Unlike recursion, it needs no call stack, which removes the call overhead and the risk of a stack overflow for large n (the DP table itself still takes O(n) space). Here, the two variables that change are the current node index (index) and the parity of operations (isEven). To explore all possibilities, we use two nested loops: one for index and another for isEven (0 for odd parity, 1 for even).
The base case is defined as:
- If index == nums.size(), the state depends on the parity:
  - dp[nums.size()][1] = 0 (valid state with even parity).
  - dp[nums.size()][0] = INT_MIN (invalid state with odd parity).
Here dp[index][isEven] is the maximum sum of nodes index to n - 1, given that the nodes before index were updated with parity isEven and that the total number of updated nodes must end up even. The final result is therefore stored in dp[0][1]. The outer loop iterates over index in reverse (from the last node to the first), because each row depends only on the row after it, while the inner loop updates the state for each isEven.
At each step, the tabulation matrix is updated to track the maximum sum for both cases:
- Performing the XOR operation.
- Skipping the XOR operation.
Finally, the result stored in dp[0][1] represents the maximum sum achievable after completing all operations.
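As a worked example, the following sketch fills the table for Example 1 (nums = [1,2,1], k = 3) and prints each row. `dp_table` is a hypothetical helper that mirrors the loops described above, using float('-inf') in place of INT_MIN. Row 3 is the base case, and dp[0][1] = 6 matches the expected output.

```python
from typing import List


def dp_table(nums: List[int], k: int) -> List[List[float]]:
    """Hypothetical helper: (n + 1) x 2 table; column 0 = odd parity, column 1 = even."""
    n = len(nums)
    dp = [[float("-inf"), float("-inf")] for _ in range(n + 1)]
    dp[n][1] = 0  # even parity after the last node is valid
    for index in range(n - 1, -1, -1):
        for is_even in range(2):
            perform_xor = dp[index + 1][is_even ^ 1] + (nums[index] ^ k)
            skip_xor = dp[index + 1][is_even] + nums[index]
            dp[index][is_even] = max(perform_xor, skip_xor)
    return dp


for index, row in enumerate(dp_table([1, 2, 1], 3)):
    print(index, row)
# 0 [5, 6]
# 1 [4, 3]
# 2 [2, 1]
# 3 [-inf, 0]
```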
Algorithm
- Initialisation:
  - Define a 2D DP array dp of size (n + 1) × 2 (n = number of nodes).
  - Set base cases: dp[n][1] = 0 (even parity is valid) and dp[n][0] = INT_MIN (odd parity is invalid).
- Tabulation:
  - Loop backward from n - 1 to 0 for index.
  - For each isEven (0 or 1), calculate:
    - Perform XOR: dp[index + 1][isEven ^ 1] + (nums[index] ^ k).
    - Skip XOR: dp[index + 1][isEven] + nums[index].
  - Update dp[index][isEven] with the maximum of the two options.
- Result: dp[0][1] contains the maximum sum when operations are performed with even parity.
Code
Java
class Solution {
    public long maximumValueSum(int[] nums, int k, int[][] edges) {
        int n = nums.length;
        long[][] dp = new long[n + 1][2]; // Define DP matrix.
        dp[n][1] = 0; // Base case: Valid assignment with even parity.
        dp[n][0] = Integer.MIN_VALUE; // Base case: Invalid assignment with odd parity.
        for (int index = n - 1; index >= 0; index--) {
            for (int isEven = 0; isEven <= 1; isEven++) {
                // Option 1: Apply XOR operation.
                long performXor = dp[index + 1][isEven ^ 1] + (nums[index] ^ k);
                // Option 2: Skip XOR operation.
                long skipXor = dp[index + 1][isEven] + nums[index];
                // Take the max of the two options.
                dp[index][isEven] = Math.max(performXor, skipXor);
            }
        }
        // Return the maximum sum with even parity at 0-index.
        return dp[0][1];
    }
}
Python
from typing import List


class Solution:
    def maximumValueSum(self, nums: List[int], k: int, edges: List[List[int]]) -> int:
        n = len(nums)
        dp = [[float('-inf') for _ in range(2)] for _ in range(n + 1)]  # Define DP matrix.
        dp[n][1] = 0  # Base case: Valid assignment with even parity.
        dp[n][0] = float('-inf')  # Base case: Invalid assignment with odd parity.
        for index in range(n - 1, -1, -1):  # Iterate backward.
            for is_even in range(2):  # Check both parity possibilities.
                # Option 1: Apply XOR operation.
                perform_xor = dp[index + 1][is_even ^ 1] + (nums[index] ^ k)
                # Option 2: Skip XOR operation.
                skip_xor = dp[index + 1][is_even] + nums[index]
                # Update DP matrix with the maximum value.
                dp[index][is_even] = max(perform_xor, skip_xor)
        # Return the maximum sum with even parity at the start.
        return dp[0][1]
Complexity
- ⏰ Time complexity: O(n)
  - Tabulation iterates through all n indices and two parity states (isEven = 0 or 1).
  - Overall complexity: O(n × 2) = O(n).
- 🧺 Space complexity: O(n)
  - The DP array requires space proportional to (n + 1) × 2, which is O(n).
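Because row index of the table only reads row index + 1, one optional refinement (not part of the method above) keeps just two variables for that row, reducing the extra space to O(1) while leaving the time at O(n). A minimal sketch, with the hypothetical name `maximum_value_sum_rolling`:

```python
from typing import List


def maximum_value_sum_rolling(nums: List[int], k: int) -> int:
    """Hypothetical helper: same recurrence as the tabulation, keeping only row index + 1."""
    odd, even = float("-inf"), 0  # row n: dp[n][0], dp[n][1]
    for value in reversed(nums):
        flipped = value ^ k
        # The tuple on the right is evaluated from the old row before either is overwritten.
        odd, even = (max(even + flipped, odd + value),
                     max(odd + flipped, even + value))
    return int(even)  # dp[0][1]


print(maximum_value_sum_rolling([1, 2, 1], 3))           # 6
print(maximum_value_sum_rolling([2, 3], 7))              # 9
print(maximum_value_sum_rolling([7, 7, 7, 7, 7, 7], 3))  # 42
```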