Problem

Given the root of a binary tree and two integers p and q, return the distance between the nodes of value p and value q in the tree.

The distance between two nodes is the number of edges on the path from one to the other.

Examples

Example 1

graph TD;
    A[3] --> B[5]
    A[3] --> C[1]
    B[5] --> D[6]
    B[5] --> E[2]
    C[1] --> F[0]
    C[1] --> G[8]
    E[2] --> H[7]
    E[2] --> I[4]
  
Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 0
Output: 3
Explanation: There are 3 edges between 5 and 0: 5-3-1-0.

Example 2

Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 7
Output: 2
Explanation: There are 2 edges between 5 and 7: 5-2-7.

Example 3

Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 5
Output: 0
Explanation: The distance between a node and itself is 0.
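
The examples give the tree as a level-order list with null for missing children. To experiment with them locally, a small builder helps. The following is a minimal sketch, not part of the original problem: the TreeNode shape (val, left, right) is assumed from the solution code further down, and build_tree is a hypothetical helper name.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Assumed node shape: a value plus left and right children."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
print(root.val, root.left.val, root.right.val)  # 3 5 1
print(root.left.right.left.val, root.left.right.right.val)  # 7 4
```

Only existing nodes are queued, so a null entry never consumes child slots of its own; this matches the list shown in the examples.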

Solution

The distance between two nodes in a binary tree is the minimum number of edges that must be traversed to move from one node to the other. Since any two nodes in a binary tree share at least one common ancestor (the root), we can calculate their distance by first identifying their Lowest Common Ancestor (LCA).

After that, there are two ways to calculate the distance:

  1. Calculate the distance of nodes p and q from their LCA.
  2. Calculate the distance of nodes p, q and their LCA from the root node and combine them with a formula.

Method 1 - Finding distance from lowest common ancestor

Here is the approach:

  1. Find the Lowest Common Ancestor (LCA):
    • The LCA of two nodes is the deepest node that is an ancestor of both nodes. This helps identify the common point where the paths to p and q diverge.
  2. Find the Distance from LCA to each Node:
    • Calculate the distance from the LCA to p and the LCA to q using Depth First Search (DFS).
  3. Calculate Total Distance:
    • The total distance between p and q is the sum of these two distances.
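
To make the idea of the paths "diverging" at the LCA concrete, here is a small illustration that materialises both root-to-node paths and compares them. This is a swapped-in, path-based sketch for intuition only, not the recursive LCA method implemented in this section; path_to and distance_via_paths are hypothetical names, and the sketch assumes node values are unique and that both p and q exist in the tree.

```python
from typing import List, Optional


class TreeNode:
    """Assumed node shape: a value plus left and right children."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def path_to(root: Optional[TreeNode], target: int) -> Optional[List[int]]:
    """Return the values on the path from root down to target, or None."""
    if root is None:
        return None
    if root.val == target:
        return [root.val]
    for child in (root.left, root.right):
        sub = path_to(child, target)
        if sub is not None:
            return [root.val] + sub
    return None


def distance_via_paths(root: TreeNode, p: int, q: int) -> int:
    """Assumes both p and q are present in the tree."""
    path_p = path_to(root, p)
    path_q = path_to(root, q)
    # Count the shared prefix; its last node is the LCA.
    common = 0
    while (common < len(path_p) and common < len(path_q)
           and path_p[common] == path_q[common]):
        common += 1
    # Nodes after the shared prefix are the edges from the LCA to p and to q.
    return (len(path_p) - common) + (len(path_q) - common)


# The tree from Example 1, built by hand.
root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))

print(distance_via_paths(root, 5, 0))  # 3
print(distance_via_paths(root, 5, 7))  # 2
print(distance_via_paths(root, 5, 5))  # 0
```

For p = 5 and q = 0 the paths are [3, 5] and [3, 1, 0]; they share only the root, so the LCA is 3 and the distance is 1 + 2 = 3.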

Code

class Solution {
    public int findDistance(TreeNode root, int p, int q) {
        TreeNode lca = findLCA(root, p, q);
        int distP = findDist(lca, p, 0);
        int distQ = findDist(lca, q, 0);
        return distP + distQ;
    }
    private TreeNode findLCA(TreeNode root, int p, int q) {
        if (root == null || root.val == p || root.val == q) {
            return root;
        }
        TreeNode left = findLCA(root.left, p, q);
        TreeNode right = findLCA(root.right, p, q);
        if (left != null && right != null) {
            return root;
        }
        return left != null ? left : right;
    }

    private int findDist(TreeNode root, int target, int d) {
        if (root == null) {
            return -1;
        }
        if (root.val == target) {
            return d;
        }
        int left = findDist(root.left, target, d + 1);
        if (left != -1) {
            return left;
        }
        return findDist(root.right, target, d + 1);
    }
}
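from typing import Optional

# TreeNode (val, left, right) is assumed to be provided by the judge environment.
class Solution: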
    def find_distance(self, root: Optional[TreeNode], p: int, q: int) -> int:
        lca = self.find_lca(root, p, q)
        dist_p = self.find_dist(lca, p, 0)
        dist_q = self.find_dist(lca, q, 0)
        return dist_p + dist_q
    def find_lca(self, root: Optional[TreeNode], p: int, q: int) -> Optional[TreeNode]:
        if not root or root.val == p or root.val == q:
            return root
        left = self.find_lca(root.left, p, q)
        right = self.find_lca(root.right, p, q)
        if left and right:
            return root
        return left if left else right

    def find_dist(self, root: Optional[TreeNode], target: int, d: int) -> int:
        if not root:
            return -1
        if root.val == target:
            return d
        left = self.find_dist(root.left, target, d + 1)
        if left != -1:
            return left
        return self.find_dist(root.right, target, d + 1)

Complexity

  • ⏰ Time complexity: O(n)
    • Finding the LCA requires traversal of the entire binary tree in the worst case.
    • Calculating distances to p and q involves two additional traversals in the subtree determined by the LCA.
  • 🧺 Space complexity: O(h), where h is the height of the tree. The recursion stack can grow as deep as the tree: O(log n) if the binary tree is balanced, and O(n) in the worst case of a skewed tree.

Method 2 - Finding the LCA and then using node levels to get the distance

Once we have calculated the LCA, we can use the following formula:

lca = LCA(root, p, q)
distance(p, q) = distance(root, p) + distance(root, q) - 2 * distance(root, lca)
Since level(x) = distance(root, x), this can be written as:
distance(p, q) = level(p) + level(q) - 2 * level(lca)

For example, in the tree from Example 1, distance(root, 6) = 2 and distance(root, 4) = 3. The LCA of 4 and 6 is 5, and distance(root, 5) = 1. Using the formula:

distance(6, 4) = 2 + 3 - 2 * 1 = 3
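
The arithmetic above can be checked mechanically. The sketch below is only a verification aid under stated assumptions: it encodes the Example 1 tree as a list of parent-child value pairs (valid because the values are unique), computes every node's level with a breadth-first pass, and takes the LCA of 6 and 4 to be 5 as stated in the text rather than computing it. The names edges, children and levels_from are hypothetical.

```python
from collections import deque

# Example 1 tree as (parent, child) value pairs; values are unique.
edges = [(3, 5), (3, 1), (5, 6), (5, 2), (1, 0), (1, 8), (2, 7), (2, 4)]

children = {}
for parent, child in edges:
    children.setdefault(parent, []).append(child)


def levels_from(root):
    """Map each node value to its number of edges from root."""
    level = {root: 0}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in children.get(node, []):
            level[child] = level[node] + 1
            queue.append(child)
    return level


level = levels_from(3)
lca = 5  # LCA of 6 and 4, taken from the text above
distance = level[6] + level[4] - 2 * level[lca]
print(level[6], level[4], level[lca], distance)  # 2 3 1 3
```

The subtraction removes the shared stretch from the root to the LCA, which is counted once in each of the two levels.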

The problem is therefore reduced to two subtasks:

  1. Find the distance from the root to a given node in the binary tree (its depth or level).
  2. Find the LCA of the two nodes in the binary tree.

To solve this efficiently, we can implement helper functions for both tasks, achieving a time complexity of O(n).

Approach

  1. Finding LCA:
    • Use a recursive method to identify the node that is the lowest common ancestor.
    • Traverse the left and right subtrees to locate nodes a and b.
    • If both nodes are found in opposite branches, the current node is the LCA.
  2. Finding Levels:
    • Perform a DFS traversal to compute the level of any node relative to the root.
  3. Final Distance:
    • Using the formula above, compute the total distance between the two nodes.

Code

class Solution {
    private TreeNode findLCA(TreeNode root, int a, int b) {
        if (root == null || root.val == a || root.val == b) {
            return root;
        }
        TreeNode left = findLCA(root.left, a, b);
        TreeNode right = findLCA(root.right, a, b);
        if (left != null && right != null) {
            return root;
        }
        return left != null ? left : right;
    }
    
    private int findLevel(TreeNode root, int target, int level) {
        if (root == null) {
            return -1;
        }
        if (root.val == target) {
            return level;
        }
        int left = findLevel(root.left, target, level + 1);
        if (left != -1) {
            return left;
        }
        return findLevel(root.right, target, level + 1);
    }
    
    public int findDistance(TreeNode root, int a, int b) {
        TreeNode lca = findLCA(root, a, b);
        int levelA = findLevel(root, a, 0);
        int levelB = findLevel(root, b, 0);
        int levelLCA = findLevel(root, lca.val, 0);
        return levelA + levelB - 2 * levelLCA;
    }
}
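from typing import Optional

# TreeNode (val, left, right) is assumed to be provided by the judge environment.
class Solution: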
    def find_lca(self, root: Optional[TreeNode], a: int, b: int) -> Optional[TreeNode]:
        if not root or root.val == a or root.val == b:
            return root
        left = self.find_lca(root.left, a, b)
        right = self.find_lca(root.right, a, b)
        if left and right:
            return root
        return left if left else right

    def find_level(self, root: Optional[TreeNode], target: int, level: int) -> int:
        if not root:
            return -1
        if root.val == target:
            return level
        left = self.find_level(root.left, target, level + 1)
        if left != -1:
            return left
        return self.find_level(root.right, target, level + 1)

    def find_distance(self, root: Optional[TreeNode], a: int, b: int) -> int:
        lca = self.find_lca(root, a, b)
        level_a = self.find_level(root, a, 0)
        level_b = self.find_level(root, b, 0)
        level_lca = self.find_level(root, lca.val, 0)
        return level_a + level_b - 2 * level_lca

Complexity

  • ⏰ Time complexity: O(n)
    • Finding LCA requires traversing the entire binary tree once in the worst case.
    • Calculating levels requires three additional traversals from the root: one each for a, b, and the LCA.
  • 🧺 Space complexity: O(h). The recursion stack space for DFS, where h is the height of the binary tree.