Problem#
Given the `root` of a binary tree, split the binary tree into two subtrees by removing one edge such that the product of the sums of the subtrees is maximized.
Return the maximum product of the sums of the two subtrees. Since the answer may be too large, return it modulo $10^9 + 7$.
Note that you need to maximize the answer before taking the mod and not after taking it.
Examples#
Example 1#
1
2
3
Input: root = [ 1 , 2 , 3 , 4 , 5 , 6 ]
Output: 110
Explanation: Remove the edge between node 1 and node 2 (the red edge in the original figure) to get 2 binary trees with sums 11 and 10. Their product is 110 (11 * 10).
Example 2#
1
2
3
Input: root = [ 1 , null , 2 , 3 , 4 , null , null , 5 , 6 ]
Output: 90
Explanation: Remove the edge between node 2 and node 4 (the red edge in the original figure) to get 2 binary trees with sums 15 and 6. Their product is 90 (15 * 6).
Solution#
Method 1 – DFS Subtree Sum Enumeration#
Intuition#
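The main idea is to compute the sum of all nodes in the tree, then for every possible way to split the tree (by removing one edge), calculate the product of the sums of the resulting two subtrees. Removing the edge above a node whose subtree sum is s leaves two trees with sums s and total − s, so every split corresponds to exactly one subtree sum. The maximum product is the answer.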
Approach#
Traverse the tree to compute the total sum of all nodes.
Use DFS to compute the sum of each subtree and keep track of all possible subtree sums.
For each subtree sum s, calculate the product s × (total − s). The product can be much larger than the modulus, so compute it with 64-bit (or otherwise exact) arithmetic.
Return the maximum product modulo $10^9 + 7$. Apply the modulus only after the maximum has been found.
Code#
C++#
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// A binary-tree node: an integer payload and raw pointers to the two children
// (the solutions below treat nullptr as "no child").
struct TreeNode {
    int val;          // value stored at this node
    TreeNode* left;   // left child
    TreeNode* right;  // right child
};
class Solution {
public :
long total = 0 , ans = 0 ;
long dfs (TreeNode* root) {
if (! root) return 0 ;
long sum = root-> val + dfs(root-> left) + dfs(root-> right);
ans = max(ans, sum * (total - sum));
return sum;
}
int maxProduct (TreeNode* root) {
function< long (TreeNode* )> sumTree = [& ](TreeNode* node) {
if (! node) return 0L ;
return node-> val + sumTree(node-> left) + sumTree(node-> right);
};
total = sumTree(root);
dfs(root);
return ans % 1000000007 ;
}
};
Go#
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// TreeNode is a binary-tree node: an integer value plus pointers to its
// children (maxProduct treats a nil pointer as an absent child).
type TreeNode struct {
	Val   int
	Left  *TreeNode
	Right *TreeNode
}
// maxProduct returns the largest product of the two tree sums obtainable by
// deleting a single edge. The product is reduced modulo 1000000007 only after
// the maximum has been chosen.
func maxProduct(root *TreeNode) int {
	const modulus = 1000000007

	// First pass: accumulate the sum of every node value.
	var total int64
	var accumulate func(n *TreeNode)
	accumulate = func(n *TreeNode) {
		if n == nil {
			return
		}
		total += int64(n.Val)
		accumulate(n.Left)
		accumulate(n.Right)
	}
	accumulate(root)

	// Second pass: post-order subtree sums. Cutting the edge above a subtree
	// with sum s yields the product s * (total - s).
	var best int64
	var subtree func(n *TreeNode) int64
	subtree = func(n *TreeNode) int64 {
		if n == nil {
			return 0
		}
		s := int64(n.Val) + subtree(n.Left) + subtree(n.Right)
		if p := s * (total - s); p > best {
			best = p
		}
		return s
	}
	subtree(root)

	return int(best % modulus)
}
Java#
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/** A binary-tree node holding an int value and references to its two children. */
class TreeNode {
    /** Value stored at this node. */
    int val;
    /** Left child; the solution treats null as "no child". */
    TreeNode left;
    /** Right child; the solution treats null as "no child". */
    TreeNode right;
}
class Solution {
    /** Sum of all node values in the tree passed to the current maxProduct call. */
    long total = 0;
    /** Best product found so far; reset at the start of every maxProduct call. */
    long ans = 0;

    /**
     * Returns the maximum product of the sums of the two trees obtained by removing one edge,
     * modulo 1000000007. The maximum is taken over exact 64-bit products before the modulus.
     *
     * @param root root of the tree (may be null, giving 0)
     * @return the maximum split product modulo 1000000007
     */
    public int maxProduct(TreeNode root) {
        total = sumTree(root);
        // Without this reset, a second call on the same instance would start from the
        // previous tree's answer and could return a stale, larger value.
        ans = 0;
        dfs(root);
        return (int) (ans % 1000000007);
    }

    /** Returns the sum of every value in the subtree rooted at {@code node} (0 for null). */
    private long sumTree(TreeNode node) {
        if (node == null) return 0;
        return node.val + sumTree(node.left) + sumTree(node.right);
    }

    /**
     * Returns the sum of the subtree rooted at {@code node}. As a side effect, it raises
     * {@code ans} to the product produced by cutting the edge directly above {@code node}.
     */
    private long dfs(TreeNode node) {
        if (node == null) return 0;
        long sum = node.val + dfs(node.left) + dfs(node.right);
        ans = Math.max(ans, sum * (total - sum));
        return sum;
    }
}
Kotlin#
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/** A binary-tree node: an Int payload plus optional children that start out empty. */
class TreeNode(var `val`: Int) {
    /** Left child, or null when absent. */
    var left: TreeNode? = null

    /** Right child, or null when absent. */
    var right: TreeNode? = null
}
class Solution {
    /** Sum of all node values in the tree passed to the current [maxProduct] call. */
    var total = 0L

    /** Best product found so far; reset at the start of every [maxProduct] call. */
    var ans = 0L

    /**
     * Returns the maximum product of the sums of the two trees obtained by removing one edge,
     * modulo 1000000007. The maximum is taken over exact Long products before the modulus.
     */
    fun maxProduct(root: TreeNode?): Int {
        total = sumTree(root)
        // Reset so a reused Solution instance does not return a previous tree's larger answer.
        ans = 0L
        dfs(root)
        return (ans % 1000000007L).toInt()
    }

    /** Returns the sum of every value in the subtree rooted at [node] (0 for null). */
    fun sumTree(node: TreeNode?): Long {
        if (node == null) return 0L
        return node.`val` + sumTree(node.left) + sumTree(node.right)
    }

    /**
     * Returns the sum of the subtree rooted at [node]. As a side effect, it raises [ans] to
     * the product produced by cutting the edge directly above [node].
     */
    fun dfs(node: TreeNode?): Long {
        if (node == null) return 0L
        val sum = node.`val` + dfs(node.left) + dfs(node.right)
        ans = maxOf(ans, sum * (total - sum))
        return sum
    }
}
Python#
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class TreeNode:
    """A binary-tree node holding an integer value and optional left/right children."""

    def __init__(self, val: int = 0, left: 'TreeNode' = None, right: 'TreeNode' = None):
        # Payload first, then the two child links (None means "no child").
        self.val, self.left, self.right = val, left, right
class Solution:
    def maxProduct(self, root: 'TreeNode') -> int:
        """Return the maximum product of the two tree sums obtained by cutting one edge.

        The maximum is taken over exact integer products, and only the final result is
        reduced modulo 10**9 + 7. An empty tree (``root is None``) yields 0.

        Both traversals are iterative, so deep (skewed) trees do not exceed Python's
        default recursion limit the way a recursive DFS would.

        Side effects: stores the tree total in ``self.total`` and the unreduced best
        product in ``self.ans``.
        """
        MOD = 10 ** 9 + 7
        self.total = self.dfs_sum(root)
        self.ans = 0
        for s in self._subtree_sums(root):
            # Cutting the edge above a subtree with sum s leaves parts s and total - s.
            self.ans = max(self.ans, s * (self.total - s))
        return self.ans % MOD

    def dfs_sum(self, node: 'TreeNode') -> int:
        """Return the sum of every value in the tree rooted at ``node`` (0 for None)."""
        total = 0
        stack = [node] if node is not None else []
        while stack:
            cur = stack.pop()
            total += cur.val
            if cur.left is not None:
                stack.append(cur.left)
            if cur.right is not None:
                stack.append(cur.right)
        return total

    @staticmethod
    def _subtree_sums(root: 'TreeNode'):
        """Yield the sum of each non-empty subtree, in post-order, without recursion."""
        if root is None:
            return
        pending = [(root, False)]  # (node, children already scheduled?)
        done = []  # finished child sums awaiting their parent; top holds the newest
        while pending:
            node, expanded = pending.pop()
            if not expanded:
                # Revisit this node after its children; push left last so it runs first.
                pending.append((node, True))
                if node.right is not None:
                    pending.append((node.right, False))
                if node.left is not None:
                    pending.append((node.left, False))
                continue
            s = node.val
            # Children finished left-then-right, so the right sum is on top of `done`.
            if node.right is not None:
                s += done.pop()
            if node.left is not None:
                s += done.pop()
            done.append(s)
            yield s
Rust#
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/// A binary-tree node that owns its children through `Box`; `None` marks a missing child.
struct TreeNode {
    /// Value stored at this node.
    val: i32,
    /// Owned left child, if any.
    left: Option<Box<TreeNode>>,
    /// Owned right child, if any.
    right: Option<Box<TreeNode>>,
}
impl Solution {
    /// Returns the largest product of the two tree sums obtainable by deleting one edge.
    /// The result is reduced modulo 1_000_000_007 only after the maximum has been chosen.
    pub fn max_product(root: Option<Box<TreeNode>>) -> i32 {
        /// Sum of every value in the (possibly empty) subtree.
        fn tree_sum(node: &Option<Box<TreeNode>>) -> i64 {
            node.as_ref()
                .map_or(0, |n| n.val as i64 + tree_sum(&n.left) + tree_sum(&n.right))
        }

        /// Post-order walk returning the subtree sum at `node`; raises `best` to the
        /// product obtained by cutting the edge directly above `node`.
        fn walk(node: &Option<Box<TreeNode>>, total: i64, best: &mut i64) -> i64 {
            if let Some(n) = node {
                let left = walk(&n.left, total, best);
                let right = walk(&n.right, total, best);
                let here = n.val as i64 + left + right;
                let product = here * (total - here);
                if product > *best {
                    *best = product;
                }
                here
            } else {
                0
            }
        }

        let total = tree_sum(&root);
        let mut best: i64 = 0;
        walk(&root, total, &mut best);
        (best % 1_000_000_007) as i32
    }
}
TypeScript#
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/** A binary-tree node with a numeric value and nullable child links. */
class TreeNode {
    val: number;
    left: TreeNode | null;
    right: TreeNode | null;

    /** Missing (undefined or null) arguments fall back to 0 / null. */
    constructor(val?: number, left?: TreeNode | null, right?: TreeNode | null) {
        this.val = val === undefined || val === null ? 0 : val;
        this.left = left === undefined || left === null ? null : left;
        this.right = right === undefined || right === null ? null : right;
    }
}
class Solution {
    /**
     * Returns the maximum product of the sums of the two trees obtained by removing one
     * edge, modulo 1000000007. The maximum is selected on exact values before reducing.
     *
     * The raw product s * (total - s) can exceed Number.MAX_SAFE_INTEGER (2^53 - 1) and be
     * silently rounded, corrupting both the comparison and the final `%`. Because
     * s * (total - s) = total^2 / 4 - (s - total / 2)^2, the best cut is the subtree sum s
     * that minimizes |2s - total|. That comparison stays exact, and the winning product is
     * then reduced with an exact modular multiplication.
     */
    maxProduct(root: TreeNode | null): number {
        const MOD = 1000000007;
        const total = this.sumTree(root);
        let best = 0;
        let bestGap = Infinity;
        // Iterative post-order walk: no recursion, so deep (skewed) trees cannot overflow
        // the call stack.
        const pending: Array<[TreeNode, boolean]> = root ? [[root, false]] : [];
        const childSums: number[] = [];
        while (pending.length > 0) {
            const [node, expanded] = pending.pop()!;
            if (!expanded) {
                // Revisit this node after its children; push left last so it runs first.
                pending.push([node, true]);
                if (node.right) pending.push([node.right, false]);
                if (node.left) pending.push([node.left, false]);
                continue;
            }
            // Children finished left-then-right, so the right sum is on top of childSums.
            let s = node.val;
            if (node.right) s += childSums.pop()!;
            if (node.left) s += childSums.pop()!;
            childSums.push(s);
            const gap = Math.abs(2 * s - total);
            if (gap < bestGap) {
                bestGap = gap;
                best = s;
            }
        }
        const rest = total - best;
        // Like a running maximum that starts at 0: a non-positive best product yields 0.
        // (The sign of a floating-point product is exact even if its magnitude is rounded.)
        if (best * rest <= 0) return 0;
        return this.mulMod(Math.abs(best) % MOD, Math.abs(rest) % MOD);
    }

    /** Returns the sum of every value in the tree rooted at `node` (0 for an empty tree). */
    sumTree(node: TreeNode | null): number {
        let sum = 0;
        const stack: TreeNode[] = node ? [node] : [];
        while (stack.length > 0) {
            const cur = stack.pop()!;
            sum += cur.val;
            if (cur.left) stack.push(cur.left);
            if (cur.right) stack.push(cur.right);
        }
        return sum;
    }

    /** Returns (a * b) % 1000000007 exactly, for integers 0 <= a, b < 1000000007. */
    private mulMod(a: number, b: number): number {
        const MOD = 1000000007;
        // Split b into 15-bit halves so every partial product stays below 2^53.
        const hi = Math.floor(b / 32768);
        const lo = b % 32768;
        return ((((a * hi) % MOD) * 32768) % MOD + a * lo) % MOD;
    }
}
Complexity#
⏰ Time complexity: O(n), where n is the number of nodes. Each node is visited twice (once for the total sum, once for the subtree sums).
🧺 Space complexity: O(h), where h is the height of the tree, due to the recursion stack.