Problem

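Given a tree where each edge has a weight, compute the length of the longest path in the tree, where the length of a path is the sum of the weights of its edges and the path may connect any two nodes.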

Examples

Example 1

    a
   /|\
  b c d
     / \
    e   f
   / \
  g   h 
Input: edges = [('a', 'b', 3), ('a', 'c', 5), ('a', 'd', 8), ('d', 'e', 2), ('d', 'f', 4), ('e', 'g', 1), ('e', 'h', 1)]
Output: 17
Explanation:
Weights: a-b: 3, a-c: 5, a-d: 8, d-e: 2, d-f: 4, e-g: 1, e-h: 1
The longest path is c -> a -> d -> f, with length 5 + 8 + 4 = 17.
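One way to double-check the example is brute force: treat the edges as undirected, run a DFS from every node, and keep the largest distance reached. This is a sketch for validating small inputs only (it is O(n²)), and the function name brute_force_longest_path is hypothetical.

```python
from collections import defaultdict


def brute_force_longest_path(edges):
    """Longest path by trying every start node. O(n^2): for checking small cases only."""
    # Undirected adjacency list: node -> list of (neighbor, weight).
    adj = defaultdict(list)
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))

    best = 0
    for start in adj:
        # Iterative DFS from `start`; in a tree the path to each node is unique,
        # so skipping the node we came from is enough to avoid revisits.
        stack = [(start, None, 0)]
        while stack:
            node, parent, dist = stack.pop()
            best = max(best, dist)
            for nxt, w in adj[node]:
                if nxt != parent:
                    stack.append((nxt, node, dist + w))
    return best


edges = [('a', 'b', 3), ('a', 'c', 5), ('a', 'd', 8), ('d', 'e', 2),
         ('d', 'f', 4), ('e', 'g', 1), ('e', 'h', 1)]
print(brute_force_longest_path(edges))  # 17
```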

Solution

Method 1 - DFS

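To find the longest path in a weighted tree, we can use depth-first search (DFS). Root the tree anywhere: every path then has a unique highest node, where it stops going up and starts going down. So it is enough to consider, for each node, the longest path whose highest node is that node:

  • Such a path joins the two longest downward branches of the node, taken through two different children (or one branch, or none, if the node has fewer children).
  • A path that does not pass through the current node lies entirely inside one child's subtree, so it is handled when the DFS visits that subtree.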

We maintain a single running maximum (a global variable, or an instance field in the code below) that is updated at every node and holds the answer once the traversal finishes.
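The combine step at a single node can be looked at in isolation. In this sketch, the hypothetical helper best_path_through takes the branch lengths leaving a node (edge weight plus the longest downward path from that child) and returns the longest path whose highest node is that node. The branch values for the example tree are worked out by hand in the comments.

```python
import heapq


def best_path_through(branches):
    # branches[i] = weight of the edge to child i + longest downward path from child i.
    # A path turning at this node can use at most two children, so take the two
    # largest branches. The two padding zeros let the path stop at the node itself
    # when it has fewer than two children (or only negative branches).
    return sum(heapq.nlargest(2, list(branches) + [0, 0]))


# Longest downward paths in the example tree, computed by hand:
#   from e: max(1, 1) = 1          from d: max(2 + 1, 4) = 4
# Branches leaving a: a-b = 3 + 0, a-c = 5 + 0, a-d = 8 + 4
print(best_path_through([3, 5, 12]))  # 17  (c -> a -> d -> f)
print(best_path_through([2 + 1, 4]))  # 7   (g -> e -> d -> f)
```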

Approach

  1. Use a helper function dfs(node) that returns the length of the longest downward path starting at node (0 for a leaf).
  2. In the DFS function:
    • For each child, compute dfs(child) plus the weight of the edge to that child.
    • Keep the two largest of these values, max1 ≥ max2 (both start at 0).
    • Update the global maximum with max1 + max2, the longest path whose highest node is the current node.
    • Return max1 to the parent.
  3. After the DFS from the root finishes, return the global maximum path length.
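The same steps can be run without recursion, which matters in Python because a deep, chain-like tree exceeds the default recursion limit (1000 frames in CPython). This is a sketch under two assumptions: the input is the undirected edge list from the example, and the tree is rooted at the first endpoint of the first edge. The name longest_path_iterative is hypothetical. Nodes are processed in reverse DFS discovery order, so every child is finished before its parent.

```python
from collections import defaultdict


def longest_path_iterative(edges):
    """Longest path in a weighted tree, using an explicit stack instead of recursion."""
    if not edges:
        return 0
    # Undirected adjacency list, so any node can serve as the root.
    adj = defaultdict(list)
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))

    root = edges[0][0]
    parent = {root: None}
    order = []                  # discovery order: each parent appears before its children
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        for nxt, _ in adj[node]:
            if nxt != parent[node]:
                parent[nxt] = node
                stack.append(nxt)

    down = {}                   # down[x] = longest downward path starting at x
    best = 0
    for node in reversed(order):    # children are handled before their parent
        max1 = max2 = 0             # two largest child branches, max1 >= max2
        for nxt, w in adj[node]:
            if nxt == parent[node]:
                continue
            length = down[nxt] + w
            if length > max1:
                max1, max2 = length, max1
            elif length > max2:
                max2 = length
        down[node] = max1
        best = max(best, max1 + max2)   # longest path whose highest node is `node`
    return best


edges = [('a', 'b', 3), ('a', 'c', 5), ('a', 'd', 8), ('d', 'e', 2),
         ('d', 'f', 4), ('e', 'g', 1), ('e', 'h', 1)]
print(longest_path_iterative(edges))  # 17
```

Because it never recurses, a chain of 100,000 edges is handled without touching sys.setrecursionlimit.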

Code

Java
import java.util.ArrayList;
import java.util.List;

class Solution {
    static class TreeNode {
        char val;
        List<Edge> children = new ArrayList<>();
        TreeNode(char x) { val = x; }
    }

    static class Edge {
        TreeNode node;   // child at the far end of the edge
        int weight;      // weight of the parent -> child edge
        Edge(TreeNode n, int w) { node = n; weight = w; }
    }

    // Longest path length found so far over all nodes visited by dfs.
    private int maxLen = 0;

    public int longestPath(TreeNode root) {
        maxLen = 0;   // reset so the same Solution can be reused
        dfs(root);
        return maxLen;
    }

    // Returns the longest downward path starting at node, and updates maxLen
    // with the longest path whose highest node is node.
    private int dfs(TreeNode node) {
        if (node == null) return 0;

        int max1 = 0, max2 = 0;   // two largest child branches, max1 >= max2

        for (Edge edge : node.children) {
            int length = dfs(edge.node) + edge.weight;
            if (length > max1) {
                max2 = max1;
                max1 = length;
            } else if (length > max2) {
                max2 = length;
            }
        }

        maxLen = Math.max(maxLen, max1 + max2);
        return max1;
    }

    // Example usage:
    public static void main(String[] args) {
        TreeNode a = new TreeNode('a');
        TreeNode b = new TreeNode('b');
        TreeNode c = new TreeNode('c');
        TreeNode d = new TreeNode('d');
        TreeNode e = new TreeNode('e');
        TreeNode f = new TreeNode('f');
        TreeNode g = new TreeNode('g');
        TreeNode h = new TreeNode('h');

        a.children.add(new Edge(b, 3));
        a.children.add(new Edge(c, 5));
        a.children.add(new Edge(d, 8));
        d.children.add(new Edge(e, 2));
        d.children.add(new Edge(f, 4));
        e.children.add(new Edge(g, 1));
        e.children.add(new Edge(h, 1));

        Solution sol = new Solution();
        System.out.println(sol.longestPath(a));  // Output: 17
    }
}
Python
from typing import Dict, List, Tuple


class Solution:
    class TreeNode:
        def __init__(self, val: str):
            self.val: str = val
            self.children: List['Solution.Edge'] = []

    class Edge:
        def __init__(self, node: 'Solution.TreeNode', weight: int):
            self.node: 'Solution.TreeNode' = node   # child at the far end of the edge
            self.weight: int = weight               # weight of the parent -> child edge

    def __init__(self):
        self.max_len = 0

    def longestPath(self, nodes: List[str], edges: List[Tuple[str, str, int]]) -> int:
        if not nodes:
            return 0

        node_map: Dict[str, 'Solution.TreeNode'] = {val: self.TreeNode(val) for val in nodes}
        for parent, child, weight in edges:
            node_map[parent].children.append(self.Edge(node_map[child], weight))

        # Edges point parent -> child, so the root is the only node that is never a child.
        children = {child for _, child, _ in edges}
        root = next(val for val in nodes if val not in children)

        self.max_len = 0   # reset so the same Solution can be reused
        self.dfs(node_map[root])
        return self.max_len

    def dfs(self, node: 'Solution.TreeNode') -> int:
        """Return the longest downward path from node, updating max_len on the way."""
        if not node:
            return 0

        max1, max2 = 0, 0   # two largest child branches, max1 >= max2

        for edge in node.children:
            length = self.dfs(edge.node) + edge.weight
            if length > max1:
                max2 = max1
                max1 = length
            elif length > max2:
                max2 = length

        # Longest path whose highest node is `node`.
        self.max_len = max(self.max_len, max1 + max2)
        return max1


# Example usage:
nodes = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
edges = [('a', 'b', 3), ('a', 'c', 5), ('a', 'd', 8), ('d', 'e', 2), ('d', 'f', 4), ('e', 'g', 1), ('e', 'h', 1)]

sol = Solution()
print(sol.longestPath(nodes, edges))  # Output: 17

Complexity

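  • ⏰ Time complexity: O(n), where n is the number of nodes in the tree. Each node is visited once and each of the n - 1 edges is examined once.
  • 🧺 Space complexity: O(h) for the recursion stack, where h is the height of the tree; this is O(n) in the worst case, when the tree is skewed into a chain. In Python, such a chain deeper than the default recursion limit (1000) raises RecursionError.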
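For comparison, a well-known alternative when every weight is non-negative is the two-pass ("double sweep") technique: find the node farthest from an arbitrary start, and the farthest distance from that node is the answer. It is also O(n) time. This is a different technique from Method 1 and can give wrong answers if any weight is negative; the function names farthest and longest_path_two_pass are hypothetical.

```python
from collections import defaultdict


def farthest(adj, start):
    """Return (node, distance) for the node farthest from start, via iterative DFS."""
    best_node, best_dist = start, 0
    stack = [(start, None, 0)]
    while stack:
        node, parent, dist = stack.pop()
        if dist > best_dist:
            best_node, best_dist = node, dist
        for nxt, w in adj[node]:
            if nxt != parent:
                stack.append((nxt, node, dist + w))
    return best_node, best_dist


def longest_path_two_pass(edges):
    # Assumes every weight is non-negative; the two-pass trick is not valid otherwise.
    if not edges:
        return 0
    adj = defaultdict(list)
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))
    end1, _ = farthest(adj, edges[0][0])   # pass 1: one endpoint of a longest path
    _, length = farthest(adj, end1)         # pass 2: farthest distance from that endpoint
    return length


edges = [('a', 'b', 3), ('a', 'c', 5), ('a', 'd', 8), ('d', 'e', 2),
         ('d', 'f', 4), ('e', 'g', 1), ('e', 'h', 1)]
print(longest_path_two_pass(edges))  # 17
```

On the example, the first pass from a reaches f at distance 12, and the second pass from f reaches c at distance 17.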
