Problem
Given a tree where each edge has a weight, compute the length of the longest path in the tree.
Examples
Example 1
      a
     /|\
    b c d
       / \
      e   f
     / \
    g   h
Input: edges = [('a', 'b', 3), ('a', 'c', 5), ('a', 'd', 8), ('d', 'e', 2), ('d', 'f', 4), ('e', 'g', 1), ('e', 'h', 1)]
Output: 17
Explanation:
Weights: a-b: 3, a-c: 5, a-d: 8, d-e: 2, d-f: 4, e-g: 1, e-h: 1
The longest path is c -> a -> d -> f with a length of 5 + 8 + 4 = 17.
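As a sanity check on the example, here is a minimal brute-force sketch that treats the tree as an undirected graph and measures the farthest distance from every node. The names `adj`, `farthest_distance`, and `brute_force_longest` are made up for this illustration and are not part of the solution below; the approach is quadratic and only meant for verifying small inputs.

```python
from collections import defaultdict

edges = [('a', 'b', 3), ('a', 'c', 5), ('a', 'd', 8), ('d', 'e', 2),
         ('d', 'f', 4), ('e', 'g', 1), ('e', 'h', 1)]

# Undirected adjacency list: a path may climb to a parent and descend again.
adj = defaultdict(list)
for u, v, w in edges:
    adj[u].append((v, w))
    adj[v].append((u, w))

def farthest_distance(start):
    """Return the largest total weight of any path starting at `start`."""
    best = 0
    stack = [(start, None, 0)]  # (node, node we came from, distance so far)
    while stack:
        node, parent, dist = stack.pop()
        best = max(best, dist)
        for nxt, w in adj[node]:
            if nxt != parent:
                stack.append((nxt, node, dist + w))
    return best

# The longest path in the tree is the best result over all start nodes.
brute_force_longest = max(farthest_distance(node) for node in list(adj))
print(brute_force_longest)  # 17
```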
Solution
Method 1 - DFS
To find the longest path in a weighted tree, we can use Depth First Search (DFS). The idea is to track the maximum length paths in the tree by considering:
- The longest path that passes through each node (using two longest paths in its subtree).
- The longest path in any subtree without passing through the current node.
We will maintain a global variable to keep track of the maximum path length encountered during our traversal.
Approach
- Use a helper function to perform DFS on the tree; it returns the length of the longest downward path starting at the given node.
- In the DFS function:
  - For each child, compute the longest downward path in its subtree and add the weight of the connecting edge.
  - Track the two longest such paths among the current node's children.
  - Update the global maximum path length with the longest path passing through the current node (the sum of those two lengths).
- After the traversal, return the global maximum path length.
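The "two longest paths" step can be sketched in isolation. The helper name `two_longest` is hypothetical and used only here; it starts both values at 0, so a node with fewer than two children contributes 0 for each missing branch.

```python
def two_longest(lengths):
    """Return the two largest values in `lengths` (0 where fewer exist)."""
    max1 = max2 = 0
    for length in lengths:
        if length > max1:
            # New best: the previous best becomes the runner-up.
            max1, max2 = length, max1
        elif length > max2:
            max2 = length
    return max1, max2

# Downward path lengths out of node 'a' in Example 1:
# via b: 3, via c: 5, via d: 8 + 4 (continuing d -> f) = 12
max1, max2 = two_longest([3, 5, 12])
print(max1, max2, max1 + max2)  # 12 5 17
```

At node `a` the two best branches are the ones through `d` and `c`, which is exactly the path c -> a -> d -> f from the example.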
Code
Java
import java.util.ArrayList;
import java.util.List;

class Solution {
    // A tree node holding weighted edges to its children.
    static class TreeNode {
        char val;
        List<Edge> children = new ArrayList<>();
        TreeNode(char x) { val = x; }
    }

    // A weighted edge pointing to a child node.
    static class Edge {
        TreeNode node;
        int weight;
        Edge(TreeNode n, int w) { node = n; weight = w; }
    }

    private int maxLen = 0;

    public int longestPath(TreeNode root) {
        maxLen = 0; // reset so repeated calls start fresh
        dfs(root);
        return maxLen;
    }

    // Returns the length of the longest downward path starting at node.
    private int dfs(TreeNode node) {
        if (node == null) return 0;
        int max1 = 0, max2 = 0; // two longest downward paths through children
        for (Edge edge : node.children) {
            int length = dfs(edge.node) + edge.weight;
            if (length > max1) {
                max2 = max1;
                max1 = length;
            } else if (length > max2) {
                max2 = length;
            }
        }
        // Longest path that passes through this node.
        maxLen = Math.max(maxLen, max1 + max2);
        return max1;
    }

    // Example usage:
    public static void main(String[] args) {
        TreeNode a = new TreeNode('a');
        TreeNode b = new TreeNode('b');
        TreeNode c = new TreeNode('c');
        TreeNode d = new TreeNode('d');
        TreeNode e = new TreeNode('e');
        TreeNode f = new TreeNode('f');
        TreeNode g = new TreeNode('g');
        TreeNode h = new TreeNode('h');
        a.children.add(new Edge(b, 3));
        a.children.add(new Edge(c, 5));
        a.children.add(new Edge(d, 8));
        d.children.add(new Edge(e, 2));
        d.children.add(new Edge(f, 4));
        e.children.add(new Edge(g, 1));
        e.children.add(new Edge(h, 1));
        Solution sol = new Solution();
        System.out.println(sol.longestPath(a)); // Output: 17
    }
}
Python
from typing import Dict, List, Tuple


class Solution:
    class TreeNode:
        def __init__(self, val: str):
            self.val: str = val
            self.children: List['Solution.Edge'] = []

    class Edge:
        def __init__(self, node: 'Solution.TreeNode', weight: int):
            self.node: 'Solution.TreeNode' = node
            self.weight: int = weight

    def __init__(self):
        self.max_len = 0

    def longestPath(self, nodes: List[str], edges: List[Tuple[str, str, int]]) -> int:
        if not nodes:
            return 0
        self.max_len = 0  # reset so repeated calls start fresh
        node_map: Dict[str, 'Solution.TreeNode'] = {val: self.TreeNode(val) for val in nodes}
        for parent, child, weight in edges:
            node_map[parent].children.append(self.Edge(node_map[child], weight))
        # nodes[0] is treated as the root of the tree.
        self.dfs(node_map[nodes[0]])
        return self.max_len

    def dfs(self, node: 'Solution.TreeNode') -> int:
        """Return the length of the longest downward path starting at node."""
        if not node:
            return 0
        max1, max2 = 0, 0  # two longest downward paths through children
        for edge in node.children:
            length = self.dfs(edge.node) + edge.weight
            if length > max1:
                max2 = max1
                max1 = length
            elif length > max2:
                max2 = length
        # Longest path that passes through this node.
        self.max_len = max(self.max_len, max1 + max2)
        return max1


# Example usage:
nodes = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
edges = [('a', 'b', 3), ('a', 'c', 5), ('a', 'd', 8), ('d', 'e', 2), ('d', 'f', 4), ('e', 'g', 1), ('e', 'h', 1)]
sol = Solution()
print(sol.longestPath(nodes, edges))  # Output: 17
Complexity
- ⏰ Time complexity: O(n), where n is the number of nodes in the tree. Each node and its edges are processed once.
- 🧺 Space complexity: O(n) for the recursion stack in the worst case, when the tree is skewed.
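Because the recursion depth grows with the height of the tree, a heavily skewed input can exceed Python's default recursion limit. One possible workaround, shown here as a sketch rather than as part of the solution above, is to compute the same values iteratively. The function name `longest_path_iterative` is made up for this illustration, and it assumes the edges are given as (parent, child, weight) tuples as in Example 1.

```python
from collections import defaultdict

def longest_path_iterative(root, edges):
    """Longest path length in a rooted weighted tree, without recursion."""
    children = defaultdict(list)
    for parent, child, weight in edges:
        children[parent].append((child, weight))

    # Pre-order traversal; its reverse visits every child before its parent.
    order = []
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        for child, _ in children[node]:
            stack.append(child)

    down = {}  # down[node] = longest downward path starting at node
    best = 0
    for node in reversed(order):
        max1 = max2 = 0
        for child, weight in children[node]:
            length = down[child] + weight
            if length > max1:
                max1, max2 = length, max1
            elif length > max2:
                max2 = length
        down[node] = max1
        best = max(best, max1 + max2)
    return best

example_edges = [('a', 'b', 3), ('a', 'c', 5), ('a', 'd', 8), ('d', 'e', 2),
                 ('d', 'f', 4), ('e', 'g', 1), ('e', 'h', 1)]
print(longest_path_iterative('a', example_edges))  # 17
```

The time and space bounds stay at O(n); the explicit stack and the `down` dictionary simply replace the call stack.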