Distance between two nodes in binary tree
Problem
Given the root of a binary tree and two integers p and q, return the distance between the nodes of value p and value q in the tree.
The distance between two nodes is the number of edges on the path from one to the other.
Examples
Example 1
```mermaid
graph TD;
    A[3] --> B[5]
    A[3] --> C[1]
    B[5] --> D[6]
    B[5] --> E[2]
    C[1] --> F[0]
    C[1] --> G[8]
    E[2] --> H[7]
    E[2] --> I[4]
```
Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 0
Output: 3
Explanation: There are 3 edges between 5 and 0: 5-3-1-0.
Example 2
Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 7
Output: 2
Explanation: There are 2 edges between 5 and 7: 5-2-7.
Example 3
Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 5
Output: 0
Explanation: The distance between a node and itself is 0.
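The examples write the tree as a level-order list in which `null` marks a missing child. A minimal Python sketch for turning such a list into nodes might look like this. It assumes a LeetCode-style `TreeNode` with `val`, `left` and `right`; `build_tree` is a hypothetical helper name, and `null` becomes Python's `None`.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a binary tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two list entries (if any) are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


# The tree used in all three examples.
root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
```

Missing children still consume a slot in the list, which is why the two `None` entries under node 6 are needed before 7 and 4 are attached to node 2.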
Solution
The distance between two nodes in a binary tree is the minimum number of edges that must be traversed to move from one node to the other. Since any two nodes in a binary tree share a common ancestor (at worst, the root), we can calculate their distance by first identifying their lowest common ancestor (LCA); see [Lowest Common Ancestor of a Binary Tree](lowest-common-ancestor-of-a-binary-tree).
After that, there are two ways to calculate the distance:
- Calculate the distances of nodes `p` and `q` from their LCA.
- Calculate the distances of nodes `p`, `q`, and their LCA from the root node, and plug them into a formula.
Method 1 - Finding the distances from the lowest common ancestor
Here is the approach:
- Find the Lowest Common Ancestor (LCA):
  - The LCA of two nodes is the deepest node that is an ancestor of both nodes. It marks the point where the paths to `p` and `q` diverge.
- Find the distance from the LCA to each node:
  - Use Depth First Search (DFS) to calculate the distance from the LCA to `p` and from the LCA to `q`.
- Calculate the total distance:
  - The total distance between `p` and `q` is the sum of these two distances.
Code
Java
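```java
// Assumes the usual TreeNode (int val; TreeNode left, right), unique node values,
// and that both p and q exist in the tree.
class Solution {
    public int findDistance(TreeNode root, int p, int q) {
        // Step 1: the paths to p and q diverge at their lowest common ancestor.
        TreeNode lca = findLCA(root, p, q);
        // Step 2: measure how far below the LCA each node lies.
        int distP = findDist(lca, p, 0);
        int distQ = findDist(lca, q, 0);
        // Step 3: the path p -> lca -> q has distP + distQ edges.
        return distP + distQ;
    }

    private TreeNode findLCA(TreeNode root, int p, int q) {
        if (root == null || root.val == p || root.val == q) {
            return root;
        }
        TreeNode left = findLCA(root.left, p, q);
        TreeNode right = findLCA(root.right, p, q);
        // p and q were found in different subtrees, so this node is their LCA.
        if (left != null && right != null) {
            return root;
        }
        return left != null ? left : right;
    }

    // Number of edges from root down to target, or -1 if target is not in this subtree.
    private int findDist(TreeNode root, int target, int d) {
        if (root == null) {
            return -1;
        }
        if (root.val == target) {
            return d;
        }
        int left = findDist(root.left, target, d + 1);
        if (left != -1) {
            return left;
        }
        return findDist(root.right, target, d + 1);
    }
}
```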
Python
```python
from __future__ import annotations

from typing import Optional

# TreeNode is assumed to be the usual definition with `val`, `left` and `right`.


class Solution:
    def find_distance(self, root: Optional[TreeNode], p: int, q: int) -> int:
        # The path from p to q always passes through their lowest common ancestor.
        lca = self.find_lca(root, p, q)
        dist_p = self.find_dist(lca, p, 0)
        dist_q = self.find_dist(lca, q, 0)
        return dist_p + dist_q

    def find_lca(self, root: Optional[TreeNode], p: int, q: int) -> Optional[TreeNode]:
        if not root or root.val == p or root.val == q:
            return root
        left = self.find_lca(root.left, p, q)
        right = self.find_lca(root.right, p, q)
        # p and q lie in different subtrees, so this node is their LCA.
        if left and right:
            return root
        return left if left else right

    def find_dist(self, root: Optional[TreeNode], target: int, d: int) -> int:
        # Edges from root down to target, or -1 if target is not in this subtree.
        if not root:
            return -1
        if root.val == target:
            return d
        left = self.find_dist(root.left, target, d + 1)
        if left != -1:
            return left
        return self.find_dist(root.right, target, d + 1)
```
Complexity
- ⏰ Time complexity: `O(n)`
  - Finding the LCA requires traversing the entire binary tree in the worst case.
  - Calculating the distances to `p` and `q` takes two additional traversals, limited to the subtree rooted at the LCA, so the total stays `O(n)`.
- 🧺 Space complexity: `O(h)`, where `h` is the height of the tree, for the recursion stack. This is `O(log n)` for a balanced tree and `O(n)` for a skewed one.
Method 2 - Finding the LCA and then using maths to get the distance
Once we have the LCA, we can use the following formula:
lca = LCA(root, p, q);
distance(p, q) = distance(root, p) + distance(root, q) - 2 * distance(root, lca),
or, equivalently, writing level(x) for the depth of node x (its distance from the root):
distance(p, q) = level(p) + level(q) - 2 * level(lca).
For example, take nodes 6 and 4 in the tree from Example 1: distance(root, 6) = 2 and distance(root, 4) = 3. The LCA of 6 and 4 is 5, and distance(root, 5) = 1. Using the formula:
distance(6, 4) = 2 + 3 - 2 * 1 = 3
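The arithmetic can be checked mechanically on the Example 1 tree. This sketch hardcodes that tree as a child-to-parent map (an assumption made for brevity) and finds the LCA by walking parent pointers, which is a different LCA technique from the recursive `findLCA` used in this method's code. It only serves to evaluate the formula for nodes 6 and 4 and for the three examples; `PARENT`, `path_to_root`, `level`, `lca` and `distance` are illustrative names.

```python
from typing import Dict, List

# Example 1 tree, stored as child -> parent; the root (3) has no entry.
PARENT: Dict[int, int] = {5: 3, 1: 3, 6: 5, 2: 5, 0: 1, 8: 1, 7: 2, 4: 2}


def path_to_root(node: int) -> List[int]:
    """Return [node, parent, grandparent, ..., root]."""
    path = [node]
    while path[-1] in PARENT:
        path.append(PARENT[path[-1]])
    return path


def level(node: int) -> int:
    # Depth in edges: the path to the root has one more node than edges.
    return len(path_to_root(node)) - 1


def lca(p: int, q: int) -> int:
    # Walk up from q; the first node on p's path to the root (p included) is the LCA.
    ancestors_of_p = set(path_to_root(p))
    for node in path_to_root(q):
        if node in ancestors_of_p:
            return node
    raise ValueError("p and q are not in the same tree")


def distance(p: int, q: int) -> int:
    # distance(p, q) = level(p) + level(q) - 2 * level(lca)
    return level(p) + level(q) - 2 * level(lca(p, q))


for p, q in [(6, 4), (5, 0), (5, 7), (5, 5)]:
    print(f"distance({p}, {q}) = {level(p)} + {level(q)} - 2 * {level(lca(p, q))} = {distance(p, q)}")
```

The pair (5, 5) shows why the formula needs no special case: the LCA of a node with itself is the node, so the two levels cancel to 0.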
The problem is therefore reduced to two subtasks:
- [Find distance from root to given node in binary tree](find-distance-from-root-to-given-node-in-binary-tree) (depth or level).
- Identify the [LCA](lowest-common-ancestor-of-a-binary-tree) of two nodes in a binary tree.
To solve this efficiently, we can implement helper functions for both tasks, achieving a time complexity of O(n).
Approach
- Finding the LCA:
  - Use a recursive method to identify the node that is the lowest common ancestor.
  - Traverse the left and right subtrees to locate nodes `a` and `b`.
  - If the two nodes are found in opposite branches, the current node is the LCA.
- Finding levels:
  - Perform a DFS traversal to compute the level of any node relative to the root.
- Final distance:
  - Using the formula above, compute the total distance between the two nodes.
Code
Java
```java
// Assumes the usual TreeNode (int val; TreeNode left, right), unique node values,
// and that both a and b exist in the tree.
class Solution {
    private TreeNode findLCA(TreeNode root, int a, int b) {
        if (root == null || root.val == a || root.val == b) {
            return root;
        }
        TreeNode left = findLCA(root.left, a, b);
        TreeNode right = findLCA(root.right, a, b);
        // a and b were found in different subtrees, so this node is their LCA.
        if (left != null && right != null) {
            return root;
        }
        return left != null ? left : right;
    }

    // Depth of target measured from root, or -1 if target is not in this subtree.
    private int findLevel(TreeNode root, int target, int level) {
        if (root == null) {
            return -1;
        }
        if (root.val == target) {
            return level;
        }
        int left = findLevel(root.left, target, level + 1);
        if (left != -1) {
            return left;
        }
        return findLevel(root.right, target, level + 1);
    }

    public int findDistance(TreeNode root, int a, int b) {
        TreeNode lca = findLCA(root, a, b);
        int levelA = findLevel(root, a, 0);
        int levelB = findLevel(root, b, 0);
        int levelLCA = findLevel(root, lca.val, 0);
        // distance(a, b) = level(a) + level(b) - 2 * level(lca)
        return levelA + levelB - 2 * levelLCA;
    }
}
```
Python
```python
from __future__ import annotations

from typing import Optional

# TreeNode is assumed to be the usual definition with `val`, `left` and `right`.


class Solution:
    def find_lca(self, root: Optional[TreeNode], a: int, b: int) -> Optional[TreeNode]:
        if not root or root.val == a or root.val == b:
            return root
        left = self.find_lca(root.left, a, b)
        right = self.find_lca(root.right, a, b)
        # a and b lie in different subtrees, so this node is their LCA.
        if left and right:
            return root
        return left if left else right

    def find_level(self, root: Optional[TreeNode], target: int, level: int) -> int:
        # Depth of target measured from root, or -1 if target is not in this subtree.
        if not root:
            return -1
        if root.val == target:
            return level
        left = self.find_level(root.left, target, level + 1)
        if left != -1:
            return left
        return self.find_level(root.right, target, level + 1)

    def find_distance(self, root: Optional[TreeNode], a: int, b: int) -> int:
        lca = self.find_lca(root, a, b)
        level_a = self.find_level(root, a, 0)
        level_b = self.find_level(root, b, 0)
        level_lca = self.find_level(root, lca.val, 0)
        # distance(a, b) = level(a) + level(b) - 2 * level(lca)
        return level_a + level_b - 2 * level_lca
```
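As a sanity check, the two methods can be compared on the Example 1 tree. The sketch below is a condensed, self-contained restatement of the helpers above as plain functions; `TreeNode`, `build_tree`, `depth_of`, `method1`, `method2` and `all_values` are hypothetical names introduced here, and the tree is assumed to have unique values. Both methods should reproduce the example outputs and agree on every pair of nodes.

```python
from collections import deque
from itertools import combinations
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    # Level-order builder; None marks a missing child (LeetCode-style "null").
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue, i = deque([root]), 1
    while queue and i < len(values):
        node = queue.popleft()
        for side in ("left", "right"):
            if i < len(values) and values[i] is not None:
                child = TreeNode(values[i])
                setattr(node, side, child)
                queue.append(child)
            i += 1
    return root


def find_lca(root, a, b):
    if not root or root.val in (a, b):
        return root
    left, right = find_lca(root.left, a, b), find_lca(root.right, a, b)
    if left and right:
        return root  # a and b are in different subtrees
    return left or right


def depth_of(root, target, d=0):
    # Edges from root down to target, or -1 if target is not in this subtree.
    if not root:
        return -1
    if root.val == target:
        return d
    left = depth_of(root.left, target, d + 1)
    return left if left != -1 else depth_of(root.right, target, d + 1)


def method1(root, p, q):
    # Distances measured downward from the LCA.
    lca = find_lca(root, p, q)
    return depth_of(lca, p) + depth_of(lca, q)


def method2(root, p, q):
    # Levels measured from the root, combined with the formula.
    lca = find_lca(root, p, q)
    return depth_of(root, p) + depth_of(root, q) - 2 * depth_of(root, lca.val)


def all_values(root):
    return [] if not root else [root.val] + all_values(root.left) + all_values(root.right)


root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
for p, q in [(5, 0), (5, 7), (5, 5)]:
    print(p, q, method1(root, p, q), method2(root, p, q))

# Both methods should agree on every pair of distinct node values.
mismatches = [(p, q) for p, q in combinations(all_values(root), 2)
              if method1(root, p, q) != method2(root, p, q)]
print("mismatches:", mismatches)
```

Method 1 only searches inside the LCA's subtree, while Method 2 runs three searches from the root; the comparison shows they produce the same distances on this tree.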
Complexity
- ⏰ Time complexity: `O(n)`
  - Finding the LCA requires traversing the entire binary tree once in the worst case.
  - Calculating the three levels takes three more DFS traversals from the root; a constant number of `O(n)` traversals is still `O(n)`.
- 🧺 Space complexity: `O(h)` for the DFS recursion stack, where `h` is the height of the binary tree.