Problem

You are given the root of a binary tree and an integer distance. A pair of two different leaf nodes of a binary tree is said to be good if the length of the shortest path between them is less than or equal to distance.

Return the number of good leaf node pairs in the tree.

Examples

Example 1:

  1
 / \
2   3
 \
  4
Input:
root = [1,2,3,null,4], distance = 3
Output:
 1
Explanation: The leaf nodes of the tree are 3 and 4 and the length of the shortest path between them is 3. This is the only good pair.

Example 2:

     1
   /   \
  2     3
 / \   / \
4   5 6   7
Input:
root = [1,2,3,4,5,6,7], distance = 3
Output:
 2
Explanation: The good pairs are [4,5] and [6,7], each with a shortest path of length 2. The pair [4,6] is not good because the length of the shortest path between them is 4.

Example 3:

      7
    /   \
   1     4
  /     / \
 6     5   3
            \
             2
Input:
root = [7,1,4,6,null,5,3,null,null,null,null,null,2], distance = 3
Output:
 1
Explanation: The only good pair is [2,5].
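
The definition of a good pair can also be checked directly, which is handy for cross-checking a faster solution on small trees. The following is a minimal brute-force sketch, not the method described in this article; the names `count_good_pairs_brute_force` and `walk` and the local `TreeNode` class are assumptions made for illustration.

```python
from itertools import combinations


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def count_good_pairs_brute_force(root, distance):
    """Count good leaf pairs straight from the definition (hypothetical name)."""
    leaf_paths = []  # one root-to-leaf path of 'L'/'R' moves per leaf

    def walk(node, path):
        if node is None:
            return
        if node.left is None and node.right is None:
            leaf_paths.append(path)
            return
        walk(node.left, path + "L")
        walk(node.right, path + "R")

    walk(root, "")

    good = 0
    for a, b in combinations(leaf_paths, 2):
        # Length of the shared prefix = depth of the lowest common ancestor.
        common = 0
        while common < min(len(a), len(b)) and a[common] == b[common]:
            common += 1
        # Path length = edges from that ancestor down to each leaf.
        if (len(a) - common) + (len(b) - common) <= distance:
            good += 1
    return good


# Example 1: leaves 3 and 4 are 3 edges apart, so distance = 3 gives one pair.
root = TreeNode(1, TreeNode(2, None, TreeNode(4)), TreeNode(3))
print(count_good_pairs_brute_force(root, 3))  # 1
```

This compares every pair of leaves, so it is only meant as a reference for small inputs.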

Solution

Method 1 - Postorder DFS

Here is a step-by-step plan to solve this problem:

  1. Perform a DFS traversal of the binary tree.
  2. For each node, collect the distances to leaves in its subtree.
  3. Merge the lists of distances from the left and right subtrees. Because we move from a child to its parent once the DFS call returns, we increment each distance by 1 while merging, since every leaf is one edge farther from the parent.
  4. Count pairs of leaf nodes from the merged list that satisfy the distance constraint.
  5. Return the number of valid pairs.

Also, when returning distances to the parent, we keep only values that are at most distance - 1 after the increment. Any leaf in another subtree is at least 1 edge away from the parent, so a value of distance or more would produce a path of length at least distance + 1. That exceeds the limit and can never form a good pair, so such values do not need to be propagated further up the tree.
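
The pruning rule can be sketched in isolation. The helper name `lift` is an assumption made for this illustration and is not part of the solution code; it shows which distances survive when a list is handed from a child to its parent.

```python
def lift(child_distances, distance):
    """Shift leaf distances one edge up, dropping those that can no longer pair.

    `lift` is a hypothetical helper name used only for this illustration.
    """
    # A kept value will later be added to at least 1 from a sibling subtree,
    # so only shifted values strictly below `distance` are still useful.
    return [d + 1 for d in child_distances if d + 1 < distance]


# With distance = 3, a leaf 1 edge away moves to 2 edges and is kept;
# a leaf already 2 edges away would move to 3 and is dropped.
print(lift([1, 2], 3))  # [2]
```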

Code

Java
import java.util.Arrays;

public class Solution {

	public int countPairs(TreeNode root, int distance) {
		int[] ans = new int[1]; // Using an array to allow updates within the dfs function
		dfs(root, distance, ans);
		return ans[0];
	}

	private int[] dfs(TreeNode node, int distance, int[] ans) {
		if (node == null) {
			return new int[0];
		}

		if (node.left == null && node.right == null) { // If it's a leaf node
			return new int[]{1};
		}

		int[] leftDistances = dfs(node.left, distance, ans);
		int[] rightDistances = dfs(node.right, distance, ans);

		// Count pairs across left and right subtrees
		for (int ld: leftDistances) {
			for (int rd: rightDistances) {
				if (ld + rd <= distance) {
					ans[0]++;
				}
			}
		}

		// Prepare and return distance list for the current node
		int[] distances = new int[leftDistances.length + rightDistances.length];
		int index = 0;

		for (int ld: leftDistances) {
			if (ld + 1 < distance) { // Only consider up to `distance-1`
				distances[index++] = ld + 1;
			}
		}

		for (int rd: rightDistances) {
			if (rd + 1 < distance) { // Only consider up to `distance-1`
				distances[index++] = rd + 1;
			}
		}

		return Arrays.copyOf(distances, index);
	}
}

Python
def countPairs(root: TreeNode, distance: int) -> int:
    result = [0]  # Using a list to allow updates within the DFS function

    def dfs(node):
        if not node:
            return []

        if not node.left and not node.right:  # If it's a leaf node
            return [1]

        left_distances = dfs(node.left)
        right_distances = dfs(node.right)

        # Count pairs across left and right subtrees
        for ld in left_distances:
            for rd in right_distances:
                if ld + rd <= distance:
                    result[0] += 1

        # Prepare and return distance list for the current node
        distances = []
        for ld in left_distances:
            if ld + 1 < distance:  # Only consider up to `distance-1`
                distances.append(ld + 1)
        for rd in right_distances:
            if rd + 1 < distance:  # Only consider up to `distance-1`
                distances.append(rd + 1)

        return distances

    dfs(root)
    return result[0]
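
To run the Python solution outside LeetCode, a `TreeNode` class and a way to build trees from the level-order lists used in the examples are needed. The following is a minimal sketch under that assumption: `TreeNode` mirrors the shape LeetCode normally supplies, and `build_tree` is a hypothetical helper for local testing only.

```python
from collections import deque


class TreeNode:
    """Minimal binary tree node, mirroring the shape LeetCode provides."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values):
    """Build a tree from a LeetCode-style level-order list (None = no child).

    `build_tree` is a hypothetical helper for local testing only.
    """
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


# Example 1 from the problem statement: node 4 is the right child of node 2.
root = build_tree([1, 2, 3, None, 4])
print(root.left.right.val)  # 4
```

With these definitions placed before the `countPairs` function above (its `TreeNode` annotation is evaluated when the function is defined), `countPairs(build_tree([1, 2, 3, None, 4]), 3)` should reproduce the output of Example 1.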

Complexity

  • Time: O(n*d^2) as an upper bound, where n is the number of nodes in the tree and d is the number of leaf nodes. At each node we compare every left distance with every right distance, and each list holds at most d entries.
  • Space: O(n). The recursion stack takes O(h) space, which is O(log n) for a balanced tree and O(n) for a skewed tree. The distance lists take O(d) space, which is also O(n) in the worst case.

Dry Run

Let's take Example 2 for the dry run, where distance = 3 and the tree is:

graph TD;
	1;
	1 --- 2;
	1 --- 3;
	2 --- 4;
	2 --- 5;
	3 --- 6;
	3 --- 7;
  

We initialize ans[0] = 0 and start DFS from node 1.

DFS Order
  • We reach node 4, which is a leaf node. We return [1], its distance as seen from its parent.
  • Similarly, for node 5, we return [1].
  • Node 2 - now we reach node 2:
    • Left distances from child 4: [1].
    • Right distances from child 5: [1].
    • Count pair (4, 5): 1 + 1 = 2 <= 3 → increment ans[0] to 1.
    • Prepare distances: [2, 2] and return [2, 2].
  • Node 6 is a leaf node ⇨ return [1].
  • Node 7 is a leaf node ⇨ return [1].
  • Node 3 - now we come to node 3:
    • Left distances from child 6: [1].
    • Right distances from child 7: [1].
    • Count pair (6, 7): 1 + 1 = 2 <= 3 → increment ans[0] to 2.
    • Prepare distances: [2, 2] and return [2, 2].
  • Node 1 - finally, we are at node 1:
    • Left distances from child 2: [2, 2].
    • Right distances from child 3: [2, 2].
    • Count pair (4, 6): 2 + 2 = 4 > 3 → not counted.
    • Count pair (4, 7): 2 + 2 = 4 > 3 → not counted.
    • Count pair (5, 6): 2 + 2 = 4 > 3 → not counted.
    • Count pair (5, 7): 2 + 2 = 4 > 3 → not counted.
    • Prepare distances (if any): every incoming distance is 2, and 2 + 1 = 3 is not less than distance = 3, so nothing is kept and we return an empty list.
  • The DFS ends and we return ans[0] = 2.
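
The dry run can be reproduced with a small traced variant of the DFS. This is a sketch for illustration only: `count_pairs_traced`, the `trace` list, and the local `TreeNode` class are assumptions, and the logic is meant to follow the same postorder steps as the solution above.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right


def count_pairs_traced(root, distance):
    """Postorder DFS that also records what each node returns to its parent."""
    count = 0
    trace = []  # (node value, list returned to the parent)

    def dfs(node):
        nonlocal count
        if node is None:
            return []
        if node.left is None and node.right is None:
            trace.append((node.val, [1]))
            return [1]
        left, right = dfs(node.left), dfs(node.right)
        # Count pairs whose path passes through this node.
        count += sum(1 for ld in left for rd in right if ld + rd <= distance)
        # Shift distances one edge up, keeping only those that can still pair.
        returned = [d + 1 for d in left + right if d + 1 < distance]
        trace.append((node.val, returned))
        return returned

    dfs(root)
    return count, trace


# Example 2: leaves 4, 5, 6, 7 and distance = 3.
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)),
                TreeNode(3, TreeNode(6), TreeNode(7)))
count, trace = count_pairs_traced(root, 3)
for val, returned in trace:
    print(val, returned)
print("good pairs:", count)
```

Each printed line shows a node followed by the list it hands to its parent, in the same order as the steps listed in the dry run.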