Problem
You are given the root of a binary tree and an integer distance. A pair of two different leaf nodes of a binary tree is said to be good if the length of the shortest path between them is less than or equal to distance.
Return the number of good leaf node pairs in the tree.
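It can help to have a direct translation of the definition to test the efficient method against. The sketch below is a brute-force reference, not the article's method. It treats the tree as an undirected graph, runs a breadth-first search from every leaf, and counts the leaf pairs whose shortest path is at most `distance`. `TreeNode` is a minimal stand-in for the usual LeetCode class, and `count_pairs_brute_force` is a hypothetical name. It runs one BFS per leaf, so it is only meant for small trees.

```python
from collections import deque
from itertools import combinations
from typing import Dict, List, Optional


class TreeNode:
    """Minimal stand-in for the usual LeetCode node (assumed fields: val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_brute_force(root: Optional[TreeNode], distance: int) -> int:
    """Count good leaf pairs straight from the definition (slow reference version)."""
    if root is None:
        return 0
    # Treat the tree as an undirected graph: every edge is usable in both directions.
    adj: Dict[TreeNode, List[TreeNode]] = {root: []}
    leaves: List[TreeNode] = []
    stack = [root]
    while stack:
        node = stack.pop()
        if node.left is None and node.right is None:
            leaves.append(node)
        for child in (node.left, node.right):
            if child is not None:
                adj[node].append(child)
                adj[child] = [node]  # first time this child is seen
                stack.append(child)

    def bfs(source: TreeNode) -> Dict[TreeNode, int]:
        # Shortest path length (in edges) from source to every node.
        dist = {source: 0}
        queue = deque([source])
        while queue:
            cur = queue.popleft()
            for nxt in adj[cur]:
                if nxt not in dist:
                    dist[nxt] = dist[cur] + 1
                    queue.append(nxt)
        return dist

    dist_from = {leaf: bfs(leaf) for leaf in leaves}
    return sum(1 for a, b in combinations(leaves, 2) if dist_from[a][b] <= distance)


example1 = TreeNode(1, TreeNode(2, None, TreeNode(4)), TreeNode(3))
print(count_pairs_brute_force(example1, 3))  # 1
```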
Examples
Example 1:
    1
   / \
  2   3
   \
    4
Input: root = [1,2,3,null,4], distance = 3
Output: 1
Explanation: The leaf nodes of the tree are 3 and 4 and the length of the shortest path between them is 3. This is the only good pair.
Example 2:
      1
    /   \
   2     3
  / \   / \
 4   5 6   7
Input: root = [1,2,3,4,5,6,7], distance = 3
Output: 2
Explanation: The good pairs are [4,5] and [6,7] with shortest path = 2. The pair [4,6] is not good because the length of the shortest path between them is 4.
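The path lengths quoted in the examples can be reproduced with the lowest-common-ancestor identity. The shortest path between nodes a and b has depth(a) + depth(b) - 2·depth(lca(a, b)) edges. The sketch below assumes node values are unique, which holds for all three examples, so that nodes can be looked up by value. `path_length` is a hypothetical helper.

```python
from typing import Dict, Optional


class TreeNode:
    """Minimal stand-in for the usual LeetCode node (val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def path_length(root: TreeNode, a: int, b: int) -> int:
    """Edges on the shortest path between the nodes valued a and b (values assumed unique)."""
    parent: Dict[int, Optional[int]] = {root.val: None}
    depth: Dict[int, int] = {root.val: 0}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child.val] = node.val
                depth[child.val] = depth[node.val] + 1
                stack.append(child)
    # Lift the deeper node to the other's depth, then lift both until they meet.
    x, y = a, b
    while depth[x] > depth[y]:
        x = parent[x]
    while depth[y] > depth[x]:
        y = parent[y]
    while x != y:
        x, y = parent[x], parent[y]
    return depth[a] + depth[b] - 2 * depth[x]


example2 = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6), TreeNode(7)))
print(path_length(example2, 4, 5))  # 2
print(path_length(example2, 4, 6))  # 4
```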
Example 3:
      7
     / \
    1   4
   /   / \
  6   5   3
           \
            2
Input: root = [7,1,4,6,null,5,3,null,null,null,null,null,2], distance = 3
Output: 1
Explanation: The only good pair is [2,5].
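The `root = [...]` inputs use LeetCode's level-order serialization. Nodes are listed level by level, each non-null node consumes the next two entries as its left and right child, and `null` marks a missing child. The minimal decoder below shows why node 2 in Example 3 ends up as the right child of node 3. It assumes Python `None` stands in for `null`, and `build_tree` and `leaf_values` are hypothetical helpers.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal stand-in for the usual LeetCode node (val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Decode a level-order list (None = missing child) into a binary tree."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def leaf_values(root: Optional[TreeNode]) -> List[int]:
    """Leaf values from left to right."""
    if root is None:
        return []
    if root.left is None and root.right is None:
        return [root.val]
    return leaf_values(root.left) + leaf_values(root.right)


example3 = build_tree([7, 1, 4, 6, None, 5, 3, None, None, None, None, None, 2])
print(leaf_values(example3))  # [6, 5, 2]
```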
Solution
Method 1 - Postorder DFS
Here is a step-by-step plan to solve this problem:
- Perform a postorder DFS traversal of the binary tree.
- For each node, collect from its left and right children the distances from the current node to the leaves in each subtree (a leaf reports a distance of 1 to its parent).
- Count the pairs that take one distance from the left list and one from the right list and whose sum is at most distance. The shortest path between those two leaves passes through the current node.
- Merge the left and right lists into a single list for the parent. Because we move one level up from child to parent, every distance is incremented by 1 during the merge.
- After the traversal, return the total number of pairs counted.
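The per-node work in the plan can be isolated into one function. In this sketch, the hypothetical helper `combine` takes the two child lists, counts the cross pairs, and returns the incremented list for the parent. The child lists hold the distances from the current node to the leaves below it. The `d + 1 < distance` filter used in the article's code is deliberately left out here. Dropping it only makes the lists longer and does not change the count. `count_pairs_unpruned` is a hypothetical driver that applies `combine` in postorder.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal stand-in for the usual LeetCode node (val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def combine(left: List[int], right: List[int], distance: int) -> Tuple[int, List[int]]:
    """One postorder step at an internal node.

    left/right: distances from the current node to the leaves of each child subtree.
    Returns (good pairs whose path passes through this node, distances for the parent)."""
    pairs = sum(1 for ld in left for rd in right if ld + rd <= distance)
    # The parent is one edge further from every leaf below this node.
    up = [d + 1 for d in left + right]
    return pairs, up


def count_pairs_unpruned(root: Optional[TreeNode], distance: int) -> int:
    total = 0

    def dfs(node: Optional[TreeNode]) -> List[int]:
        nonlocal total
        if node is None:
            return []
        if node.left is None and node.right is None:
            return [1]  # a leaf is one edge away from its parent
        pairs, up = combine(dfs(node.left), dfs(node.right), distance)
        total += pairs
        return up

    dfs(root)
    return total


print(combine([1], [1], 3))  # (1, [2, 2])
```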
Also, when passing distances up to the parent, we keep only incremented distances that are at most distance - 1, that is, d + 1 < distance for a child distance d. If d >= distance - 1, the incremented value is at least distance. Any leaf on the other side of an ancestor is at least one more edge away, so the resulting path would have length at least distance + 1, which exceeds the limit. Such values can never form a good pair, so there is no need to propagate them further up the tree.
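The pruning argument can also be checked empirically. The sketch below runs the same postorder DFS twice on random trees, once with the `d + 1 < distance` filter and once without, and compares the counts. `count_pairs`, `random_tree`, the seed, and the size limits are arbitrary choices for illustration. Agreement on random inputs is a sanity check, not a proof.

```python
import random
from typing import List, Optional


class TreeNode:
    """Minimal stand-in for the usual LeetCode node (val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs(root: Optional[TreeNode], distance: int, prune: bool) -> int:
    """Postorder leaf-distance DFS; `prune` toggles the `d + 1 < distance` filter."""
    total = 0

    def dfs(node: Optional[TreeNode]) -> List[int]:
        nonlocal total
        if node is None:
            return []
        if node.left is None and node.right is None:
            return [1]
        left, right = dfs(node.left), dfs(node.right)
        total += sum(1 for ld in left for rd in right if ld + rd <= distance)
        up = [d + 1 for d in left + right]
        if prune:
            up = [d for d in up if d < distance]
        return up

    dfs(root)
    return total


def random_tree(n: int, rng: random.Random) -> Optional[TreeNode]:
    """Grow a random binary tree with n nodes by filling random empty child slots."""
    if n == 0:
        return None
    root = TreeNode(0)
    nodes = [root]
    for val in range(1, n):
        while True:  # a tree always has a free child slot, so this terminates
            parent = rng.choice(nodes)
            side = rng.choice(("left", "right"))
            if getattr(parent, side) is None:
                child = TreeNode(val)
                setattr(parent, side, child)
                nodes.append(child)
                break
    return root


rng = random.Random(0)
for _ in range(200):
    tree = random_tree(rng.randint(1, 30), rng)
    for dist in range(1, 11):
        assert count_pairs(tree, dist, prune=True) == count_pairs(tree, dist, prune=False)
print("pruned and unpruned counts agree")
```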
Code
Java
import java.util.Arrays;

// Assumes the standard LeetCode TreeNode class (int val; TreeNode left, right).
public class Solution {
    public int countPairs(TreeNode root, int distance) {
        int[] ans = new int[1]; // Using an array to allow updates within the dfs function
        dfs(root, distance, ans);
        return ans[0];
    }

    // Returns the distances from node's parent to the leaves in node's subtree.
    private int[] dfs(TreeNode node, int distance, int[] ans) {
        if (node == null) {
            return new int[0];
        }
        if (node.left == null && node.right == null) { // If it's a leaf node
            return new int[]{1};
        }
        int[] leftDistances = dfs(node.left, distance, ans);
        int[] rightDistances = dfs(node.right, distance, ans);
        // Count pairs across left and right subtrees
        for (int ld : leftDistances) {
            for (int rd : rightDistances) {
                if (ld + rd <= distance) {
                    ans[0]++;
                }
            }
        }
        // Prepare and return distance list for the current node
        int[] distances = new int[leftDistances.length + rightDistances.length];
        int index = 0;
        for (int ld : leftDistances) {
            if (ld + 1 < distance) { // Only consider up to `distance-1`
                distances[index++] = ld + 1;
            }
        }
        for (int rd : rightDistances) {
            if (rd + 1 < distance) { // Only consider up to `distance-1`
                distances[index++] = rd + 1;
            }
        }
        return Arrays.copyOf(distances, index);
    }
}
Python
from typing import List, Optional


class TreeNode:  # Standard LeetCode definition, included so the snippet runs on its own
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def countPairs(root: Optional[TreeNode], distance: int) -> int:
    result = [0]  # Using a list to allow updates within the DFS function

    def dfs(node: Optional[TreeNode]) -> List[int]:
        if not node:
            return []
        if not node.left and not node.right:  # If it's a leaf node
            return [1]
        left_distances = dfs(node.left)
        right_distances = dfs(node.right)
        # Count pairs across left and right subtrees
        for ld in left_distances:
            for rd in right_distances:
                if ld + rd <= distance:
                    result[0] += 1
        # Prepare and return distance list for the current node
        distances = []
        for ld in left_distances:
            if ld + 1 < distance:  # Only consider up to `distance-1`
                distances.append(ld + 1)
        for rd in right_distances:
            if rd + 1 < distance:  # Only consider up to `distance-1`
                distances.append(rd + 1)
        return distances

    dfs(root)
    return result[0]
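A common alternative, not the method used in this article, replaces each distance list with a fixed-size array of counts. In that array, `counts[k]` is the number of leaves at distance k, for k from 1 to `distance`. Pair counting then multiplies bucket sizes instead of comparing individual leaves, so each node costs O(distance^2) no matter how many leaves sit below it. The sketch assumes the usual LeetCode `TreeNode` shape (val, left, right), redefined so the block runs on its own. `count_pairs_buckets` is a hypothetical name.

```python
from typing import List, Optional


class TreeNode:
    """Minimal stand-in for the usual LeetCode node (val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_buckets(root: Optional[TreeNode], distance: int) -> int:
    """Bucket-count variant of the postorder DFS.

    dfs(child) returns counts where counts[k] is the number of leaves at distance k
    from the child's parent, for 1 <= k <= distance (index 0 is unused)."""
    total = 0

    def dfs(node: Optional[TreeNode]) -> List[int]:
        nonlocal total
        counts = [0] * (distance + 1)
        if node is None:
            return counts
        if node.left is None and node.right is None:
            counts[1] = 1
            return counts
        left = dfs(node.left)
        right = dfs(node.right)
        # Pair every left bucket i with every right bucket j such that i + j <= distance.
        for i in range(1, distance + 1):
            for j in range(1, distance - i + 1):
                total += left[i] * right[j]
        # Shift by one edge for the parent; anything that would exceed distance is dropped.
        for k in range(1, distance):
            counts[k + 1] = left[k] + right[k]
        return counts

    dfs(root)
    return total


example2 = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6), TreeNode(7)))
print(count_pairs_buckets(example2, 3))  # 2
```

The lists in the article's version can be shorter than `distance` when few leaves are nearby. The arrays here trade that for a per-node cost that does not depend on the leaf count.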
Complexity
- Time: O(n * d^2), where n is the number of nodes in the tree and d is the number of leaf nodes.
- Space: O(n). The recursion stack takes O(h) space, where h is the height of the tree: O(log n) for a balanced tree and O(n) for a skewed tree. The distance lists take O(d) space, which is also O(n) in the worst case.
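To see where the pair-counting time goes, one can count the (ld, rd) comparisons the DFS performs. Every leaf appears at most once in each child list, and two leaves meet in a node's left and right lists only at their lowest common ancestor. So there are at most d(d - 1)/2 comparisons in total, which suggests the O(n * d^2) figure is a comfortable upper bound rather than a tight one. The sketch below adds a comparison counter to the DFS in a hypothetical `count_pairs_with_stats`. It runs the counter on perfect binary trees with a distance of 100, which is large enough that nothing is pruned.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal stand-in for the usual LeetCode node (val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_pairs_with_stats(root: Optional[TreeNode], distance: int) -> Tuple[int, int]:
    """The article's DFS, additionally counting (ld, rd) comparisons in the pair loop."""
    ans = 0
    comparisons = 0

    def dfs(node: Optional[TreeNode]) -> List[int]:
        nonlocal ans, comparisons
        if node is None:
            return []
        if node.left is None and node.right is None:
            return [1]
        left, right = dfs(node.left), dfs(node.right)
        for ld in left:
            for rd in right:
                comparisons += 1
                if ld + rd <= distance:
                    ans += 1
        return [d + 1 for d in left + right if d + 1 < distance]

    dfs(root)
    return ans, comparisons


def perfect_tree(depth: int) -> TreeNode:
    """Perfect binary tree with 2**depth leaves (node values do not matter here)."""
    if depth == 0:
        return TreeNode()
    return TreeNode(0, perfect_tree(depth - 1), perfect_tree(depth - 1))


for depth in range(6):
    leaves = 2 ** depth
    _, comparisons = count_pairs_with_stats(perfect_tree(depth), 100)
    # Columns: depth, leaves, comparisons made, d(d - 1)/2
    print(depth, leaves, comparisons, leaves * (leaves - 1) // 2)
```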
Dry Run
Let's take Example 2 for a dry run, where distance = 3 and the tree is:
      1
    /   \
   2     3
  / \   / \
 4   5 6   7
We initialize ans[0] = 0 and start the DFS from node 1.
DFS Order
- Node 4 is a leaf node, so it returns its distance to its parent: [1].
- Similarly, node 5 returns [1].
- Node 2:
  - Left distances from child 4: [1].
  - Right distances from child 5: [1].
  - Count pair (4, 5): 1 + 1 = 2 <= 3 → increment ans[0] to 1.
  - Prepare distances: [2, 2] and return [2, 2].
- Node 6 is a leaf node → return [1].
- Node 7 is a leaf node → return [1].
- Node 3:
  - Left distances from child 6: [1].
  - Right distances from child 7: [1].
  - Count pair (6, 7): 1 + 1 = 2 <= 3 → increment ans[0] to 2.
  - Prepare distances: [2, 2] and return [2, 2].
- Node 1:
  - Left distances from child 2: [2, 2].
  - Right distances from child 3: [2, 2].
  - Count pairs (4, 6), (4, 7), (5, 6), (5, 7): 2 + 2 = 4 > 3 → none counted.
  - Prepare distances: every child distance is 2, which is not less than distance - 1, so every incremented value (3) is dropped and node 1 returns an empty list.
The DFS finishes with ans[0] = 2, which matches the expected output of Example 2.
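The dry run can be reproduced mechanically. For every internal node, record the child lists it receives and the list it returns. `traced_count_pairs` is a hypothetical instrumented copy of the Python solution, run here on the Example 2 tree. Nodes appear in postorder, the same order as the walkthrough.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Minimal stand-in for the usual LeetCode node (val, left, right)."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def traced_count_pairs(
    root: Optional[TreeNode], distance: int
) -> Tuple[int, List[Tuple[int, List[int], List[int], List[int]]]]:
    """Same DFS as the solution; also records (node, left list, right list, returned list)."""
    ans = 0
    trace: List[Tuple[int, List[int], List[int], List[int]]] = []

    def dfs(node: Optional[TreeNode]) -> List[int]:
        nonlocal ans
        if node is None:
            return []
        if node.left is None and node.right is None:
            return [1]
        left = dfs(node.left)
        right = dfs(node.right)
        for ld in left:
            for rd in right:
                if ld + rd <= distance:
                    ans += 1
        up = [d + 1 for d in left if d + 1 < distance] + [d + 1 for d in right if d + 1 < distance]
        trace.append((node.val, left, right, up))
        return up

    dfs(root)
    return ans, trace


example2 = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6), TreeNode(7)))
ans, trace = traced_count_pairs(example2, 3)
for val, left, right, up in trace:
    print(f"node {val}: left={left} right={right} returns {up}")
print("ans =", ans)
# node 2: left=[1] right=[1] returns [2, 2]
# node 3: left=[1] right=[1] returns [2, 2]
# node 1: left=[2, 2] right=[2, 2] returns []
# ans = 2
```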