Problem

Given a binary tree and a positive integer k, return all nodes that are at distance exactly k from some leaf node. A leaf node is a node with no children.

OR

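Given a binary tree and a positive integer k, print all nodes that are at distance exactly k from some leaf node. A leaf node is a node with no children.

In both variants, distance is measured upward along the path from a leaf toward the root, so a node qualifies only if some leaf in its subtree lies exactly k edges below it.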

Examples

Example 1

graph TD;
    A1[1] 
    A1 --> B2[2] 
    A1 --> B3[3]
    B2 --> C4[4]
    B2 --> C5[5]
    C4 --> D7[7]
    C4 --> D8[8]
    B3 --> C6[6]
    C6 --> D9[9]
    C6 --> D10[10]

    style A1 fill:#f9f,stroke:#333,stroke-width:4px
    style B2 fill:#f9f,stroke:#333,stroke-width:4px
    style B3 fill:#f9f,stroke:#333,stroke-width:4px
  
Input: root = [1, 2, 3, 4, 5, null, 6, 7, 8, null, null, 9, 10], k = 2
Output: [1, 2, 3]
Explanation:
The leaf nodes are 7, 8, 5, 9, and 10. The nodes exactly 2 edges above a leaf are 1, 2, and 3.
- Leaf 7: 7 -> 4 -> 2, so node 2
- Leaf 8: 8 -> 4 -> 2, so node 2
- Leaf 5: 5 -> 2 -> 1, so node 1
- Leaf 9: 9 -> 6 -> 3, so node 3
- Leaf 10: 10 -> 6 -> 3, so node 3

Example 2

Input: root = [1, 2, 3], k = 1
Output: [1]
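
The inputs above are written as level-order lists with null for a missing child. To try the examples locally, one possible way to turn such a list into a tree is sketched below. Both `TreeNode` and `build_tree` are illustrative names assumed here, not part of the problem statement.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal node type (assumed shape: val, left, right)."""

    def __init__(self, val: int):
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        parent = queue.popleft()
        # The next two entries describe the left and right child of `parent`.
        if i < len(values) and values[i] is not None:
            parent.left = TreeNode(values[i])
            queue.append(parent.left)
        i += 1
        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            queue.append(parent.right)
        i += 1
    return root


# Example 1 input; Python uses None where the listing above uses null.
root = build_tree([1, 2, 3, 4, 5, None, 6, 7, 8, None, None, 9, 10])
```

Missing nodes do not get child entries of their own in this encoding, which is why node 3 ends up with only a right child (6), matching the diagram in Example 1.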

Similar Problem

  • Nodes at Distance k from root in Binary Tree
  • All Nodes Distance K in Binary Tree

Solution

Method 1 - DFS

Here are key points:

  • Use Depth First Search (DFS) to traverse the tree.
  • Track the path from the root to each node using a list or stack.
  • Identify the leaf nodes (nodes with no children).
  • From each leaf node, trace back k steps along the path to find the ancestor that is exactly k edges above it.
  • Use a set to avoid duplicates, since multiple leaf nodes can share the same ancestor at distance k.
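
The "trace back k steps" point comes down to one index calculation on the root-to-leaf path. The sketch below shows that calculation on the paths of Example 1, using plain values instead of nodes for brevity. `ancestor_k_above` is a hypothetical helper name introduced only for illustration.

```python
from typing import List, Optional


def ancestor_k_above(path: List[int], k: int) -> Optional[int]:
    """Return the entry k edges above the last entry of `path`, or None if the path is too short."""
    # path[-1] is the leaf, so the ancestor k edges above it is at index len(path) - 1 - k.
    if len(path) > k:
        return path[len(path) - 1 - k]
    return None


# Root-to-leaf paths from Example 1, with k = 2.
paths = [[1, 2, 4, 7], [1, 2, 4, 8], [1, 2, 5], [1, 3, 6, 9], [1, 3, 6, 10]]
found = {ancestor_k_above(p, 2) for p in paths}
found.discard(None)  # drop the marker for paths shorter than k + 1
```

Leaves 7 and 8 both map to node 2, and leaves 9 and 10 both map to node 3, which is why a set is needed. The real solutions store node objects rather than values in the set, so that two different nodes with equal values are not merged.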

Approach

  1. Use DFS to traverse the tree and maintain a path.
  2. When a leaf node is encountered, record the node that is k steps back up the path, provided the path has more than k nodes.
  3. Avoid duplicates by using a set.

Code

Java
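import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Assumes a TreeNode class with fields val, left, and right.
public class Solution {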
    public List<Integer> nodesAtDistanceKFromLeaf(TreeNode root, int k) {
        List<Integer> ans = new ArrayList<>();
        Set<TreeNode> uniqueNodes = new HashSet<>();
        List<TreeNode> path = new ArrayList<>();
        dfs(root, k, path, uniqueNodes);
        for (TreeNode node : uniqueNodes) {
            ans.add(node.val);
        }
        return ans;
    }

    private void dfs(TreeNode node, int k, List<TreeNode> path, Set<TreeNode> uniqueNodes) {
        if (node == null) {
            return;
        }

        // Add current node to the path
        path.add(node);

        // Check if leaf node
        if (node.left == null && node.right == null) {
            if (path.size() > k) {
                uniqueNodes.add(path.get(path.size() - k - 1));
            }
        }

        // Recur for left and right children
        dfs(node.left, k, path, uniqueNodes);
        dfs(node.right, k, path, uniqueNodes);

        // Remove current node from the path on return
        path.remove(path.size() - 1);
    }
}
Python
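from typing import List, Optional, Set

# Assumes a TreeNode class with attributes val, left, and right.
class Solution:
    def nodesAtDistanceKFromLeaf(self, root: Optional[TreeNode], k: int) -> List[int]:
        def dfs(node: Optional[TreeNode], k: int, path: List[TreeNode], unique_nodes: Set[TreeNode]) -> None: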
            if not node:
                return

            # Add current node to the path
            path.append(node)

            # Check if leaf node
            if not node.left and not node.right:
                if len(path) > k:
                    unique_nodes.add(path[-k-1])

            # Recur for left and right children
            dfs(node.left, k, path, unique_nodes)
            dfs(node.right, k, path, unique_nodes)

            # Remove current node from the path on return
            path.pop()

        unique_nodes: Set[TreeNode] = set()
        dfs(root, k, [], unique_nodes)
        # Set iteration order is arbitrary, so the order of the result is not guaranteed.
        return [node.val for node in unique_nodes]
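
As a quick sanity check, the following self-contained sketch restates the same DFS as a standalone function and runs it on both examples. The `TreeNode` constructor and the function name `nodes_at_distance_k_from_leaf` are assumptions made for this check, and the result is sorted only to make the comparison deterministic.

```python
from typing import List, Optional, Set


class TreeNode:
    """Minimal node type (assumed shape: val, left, right)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def nodes_at_distance_k_from_leaf(root: Optional[TreeNode], k: int) -> List[int]:
    found: Set[TreeNode] = set()   # node objects, hashed by identity
    path: List[TreeNode] = []      # current root-to-node path

    def dfs(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        path.append(node)
        # At a leaf, record the ancestor k edges above it, if the path is long enough.
        if node.left is None and node.right is None and len(path) > k:
            found.add(path[-k - 1])
        dfs(node.left)
        dfs(node.right)
        path.pop()

    dfs(root)
    return sorted(node.val for node in found)


# Example 1: the tree from the diagram, k = 2.
example1 = TreeNode(
    1,
    TreeNode(2, TreeNode(4, TreeNode(7), TreeNode(8)), TreeNode(5)),
    TreeNode(3, None, TreeNode(6, TreeNode(9), TreeNode(10))),
)
result1 = nodes_at_distance_k_from_leaf(example1, 2)

# Example 2: root = [1, 2, 3], k = 1.
example2 = TreeNode(1, TreeNode(2), TreeNode(3))
result2 = nodes_at_distance_k_from_leaf(example2, 1)
```

With these inputs, `result1` contains the values 1, 2, and 3, and `result2` contains only 1, in line with the expected outputs above.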

Complexity

  • ⏰ Time complexity: O(n), where n is the number of nodes in the tree, because each node is visited once and each visit does a constant amount of work (one append, one indexed lookup at a leaf, one removal).
  • 🧺 Space complexity: O(h + m), where h is the height of the tree (for the recursion stack and the path list) and m is the number of distinct nodes at distance k from a leaf (for the result set).
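
The O(h) recursion stack matters in practice for very skewed trees: in Python, for instance, the default recursion limit is around 1000 frames. The sketch below is an iterative variant of the same idea, not the method used above. It replaces recursion with an explicit stack and trims the path to the current depth before each visit. `TreeNode` and `nodes_k_above_leaf_iterative` are names assumed for this illustration.

```python
from typing import List, Optional, Set, Tuple


class TreeNode:
    """Minimal node type (assumed shape: val, left, right)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def nodes_k_above_leaf_iterative(root: Optional[TreeNode], k: int) -> List[int]:
    if root is None:
        return []
    found: Set[TreeNode] = set()
    path: List[TreeNode] = []
    stack: List[Tuple[TreeNode, int]] = [(root, 0)]  # (node, depth of node)
    while stack:
        node, depth = stack.pop()
        # Discard entries left over from the previously explored branch.
        del path[depth:]
        path.append(node)
        if node.left is None and node.right is None and len(path) > k:
            found.add(path[-k - 1])
        # Push the right child first so the left child is visited first.
        if node.right is not None:
            stack.append((node.right, depth + 1))
        if node.left is not None:
            stack.append((node.left, depth + 1))
    return sorted(node.val for node in found)


# A left-skewed chain 0 -> 1 -> ... -> 4999, deeper than the default recursion limit.
chain = TreeNode(4999)
for value in range(4998, -1, -1):
    chain = TreeNode(value, left=chain)
deep_result = nodes_k_above_leaf_iterative(chain, 2)
```

The time and space bounds are the same as for the recursive version. The only change is that the O(h) bookkeeping lives in ordinary lists instead of the call stack.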