The distance between two nodes in a binary tree is the minimum number of edges that must be traversed to get from one node to the other. Every pair of nodes in a binary tree shares at least one common ancestor (the root, if nothing deeper), and the shortest path between them passes through their Lowest Common Ancestor (LCA). We can therefore calculate the distance by first finding the LCA of the two nodes.
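Once the LCA is known, there are two ways to calculate the distance:
1. Calculate the distances of nodes p and q from their LCA and add them.
2. Calculate the distances of p, q, and their LCA from the root node, then combine them: the root-to-LCA segment is counted in both root paths, so it is subtracted twice.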
Method 1 - Finding distance from lowest common ancestor
Here is the approach:
1. Find the Lowest Common Ancestor (LCA): the LCA of two nodes is the deepest node that is an ancestor of both. It is the point where the paths from the root to p and to q diverge.
2. Find the distance from the LCA to each node: calculate the distance from the LCA to p and from the LCA to q using Depth First Search (DFS) starting at the LCA.
3. Calculate the total distance: the distance between p and q is the sum of these two distances.
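
The steps above can be sketched as follows. This is a minimal sketch under stated assumptions, not a definitive implementation: it assumes node values are unique, and the `TreeNode` class and the helper names `find_lca`, `depth_below`, and `distance_via_lca` are illustrative rather than taken from the original.

```python
class TreeNode:
    """Minimal binary tree node (assumed definition for this sketch)."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_lca(root, a, b):
    """Step 1: return the deepest node that is an ancestor of both a and b."""
    if root is None or root.val == a or root.val == b:
        return root
    left = find_lca(root.left, a, b)
    right = find_lca(root.right, a, b)
    if left is not None and right is not None:
        return root  # a and b were found in different subtrees
    return left if left is not None else right


def depth_below(node, target, depth=0):
    """Step 2: DFS from node; return the edge count down to target, or -1."""
    if node is None:
        return -1
    if node.val == target:
        return depth
    found = depth_below(node.left, target, depth + 1)
    if found != -1:
        return found
    return depth_below(node.right, target, depth + 1)


def distance_via_lca(root, a, b):
    """Step 3: add the two depths measured from the LCA (-1 if a value is missing)."""
    lca = find_lca(root, a, b)
    if lca is None:
        return -1
    dist_a = depth_below(lca, a)
    dist_b = depth_below(lca, b)
    if dist_a == -1 or dist_b == -1:
        return -1
    return dist_a + dist_b
```

Because both searches start at the LCA rather than at the root, no subtraction is needed: the shared root-to-LCA segment is never counted.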
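The second approach measures everything from the root instead, using the formula:

distance(p, q) = distance(root, p) + distance(root, q) - 2 * distance(root, LCA(p, q))

For example, if distance(root, 6) = 2 and distance(root, 4) = 3, and the LCA of 4 and 6 is 5 with distance(root, 5) = 1, then distance(6, 4) = 2 + 3 - 2 * 1 = 3. The code below implements this root-based calculation.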
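
Java:

    // TreeNode is assumed to be the usual node class with fields val, left, and right.
    class Solution {
        // Returns the lowest common ancestor of the nodes holding a and b.
        // Assumes node values are unique.
        private TreeNode findLCA(TreeNode root, int a, int b) {
            if (root == null || root.val == a || root.val == b) {
                return root;
            }
            TreeNode left = findLCA(root.left, a, b);
            TreeNode right = findLCA(root.right, a, b);
            if (left != null && right != null) {
                return root; // a and b lie in different subtrees
            }
            return left != null ? left : right;
        }

        // Returns the level of target (the starting node is at `level`), or -1 if absent.
        private int findLevel(TreeNode root, int target, int level) {
            if (root == null) {
                return -1;
            }
            if (root.val == target) {
                return level;
            }
            int left = findLevel(root.left, target, level + 1);
            if (left != -1) {
                return left;
            }
            return findLevel(root.right, target, level + 1);
        }

        public int findDistance(TreeNode root, int a, int b) {
            TreeNode lca = findLCA(root, a, b);
            int levelA = findLevel(root, a, 0);
            int levelB = findLevel(root, b, 0);
            if (lca == null || levelA == -1 || levelB == -1) {
                return -1; // at least one value is missing from the tree
            }
            int levelLCA = findLevel(root, lca.val, 0);
            return levelA + levelB - 2 * levelLCA;
        }
    }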
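
Python:

    from typing import Optional

    class TreeNode:
        # Minimal node type, assumed here so the snippet runs on its own.
        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

    class Solution:
        def find_lca(self, root: Optional[TreeNode], a: int, b: int) -> Optional[TreeNode]:
            # Returns the lowest common ancestor; assumes node values are unique.
            if not root or root.val == a or root.val == b:
                return root
            left = self.find_lca(root.left, a, b)
            right = self.find_lca(root.right, a, b)
            if left and right:
                return root  # a and b lie in different subtrees
            return left if left else right

        def find_level(self, root: Optional[TreeNode], target: int, level: int) -> int:
            # Returns the level of target (the starting node is at `level`), or -1 if absent.
            if not root:
                return -1
            if root.val == target:
                return level
            left = self.find_level(root.left, target, level + 1)
            if left != -1:
                return left
            return self.find_level(root.right, target, level + 1)

        def find_distance(self, root: Optional[TreeNode], a: int, b: int) -> int:
            lca = self.find_lca(root, a, b)
            level_a = self.find_level(root, a, 0)
            level_b = self.find_level(root, b, 0)
            if lca is None or level_a == -1 or level_b == -1:
                return -1  # at least one value is missing from the tree
            level_lca = self.find_level(root, lca.val, 0)
            return level_a + level_b - 2 * level_lca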
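
To check the numbers from the example (nodes 6 and 4 with LCA 5) on a concrete tree, here is a small self-contained sketch. The tree shape and the root value 3 are assumptions, since the original does not show the tree; they are chosen so that distance(root, 6) = 2, distance(root, 4) = 3, and the LCA of 4 and 6 is 5. Instead of repeating the recursive helpers, this sketch swaps in a different technique: it records every node's depth and parent in one traversal and finds the LCA by walking upward. The names `index_tree`, `lca_value`, and `distance` are hypothetical, and the code assumes unique values that are all present in the tree.

```python
class TreeNode:
    """Minimal binary tree node (assumed definition for this sketch)."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def index_tree(root):
    """Map each value to its depth and to its parent's value (None for the root)."""
    depth, parent = {}, {}
    stack = [(root, 0, None)] if root is not None else []
    while stack:
        node, d, par = stack.pop()
        depth[node.val] = d
        parent[node.val] = par
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, d + 1, node.val))
    return depth, parent


def lca_value(depth, parent, a, b):
    """Lift the deeper value to the other's depth, then climb both together."""
    while depth[a] > depth[b]:
        a = parent[a]
    while depth[b] > depth[a]:
        b = parent[b]
    while a != b:
        a, b = parent[a], parent[b]
    return a


def distance(root, a, b):
    """Apply depth(a) + depth(b) - 2 * depth(LCA)."""
    depth, parent = index_tree(root)
    lca = lca_value(depth, parent, a, b)
    return depth[a] + depth[b] - 2 * depth[lca]


# Assumed example tree:
#         3
#        / \
#       5   1
#      / \
#     6   2
#        / \
#       7   4
root = TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))), TreeNode(1))
depth, parent = index_tree(root)
print(depth[6], depth[4], depth[5])    # 2 3 1
print(lca_value(depth, parent, 4, 6))  # 5
print(distance(root, 6, 4))            # 3, i.e. 2 + 3 - 2 * 1
```

The depth and parent maps cost one pass over the tree, which can pay off when many pairs are queried on the same tree; for a single query the recursive version is just as suitable.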