Problem

Given the root of a binary tree, return the same tree where every subtree (of the given tree) not containing a 1 has been removed.

The subtree of a node consists of that node plus every node that is a descendant of it.
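Read literally, the definition means that asking whether a subtree "contains a 1" requires checking the node itself and every node below it. Here is a minimal Python sketch of that check. It assumes a LeetCode-style `TreeNode` with `val`, `left`, and `right` fields, and `subtree_contains_one` is a hypothetical helper name:

```python
from typing import Optional


class TreeNode:
    # Assumed LeetCode-style node: an integer value and two optional children.
    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def subtree_contains_one(node: Optional[TreeNode]) -> bool:
    """Return True if `node` or any of its descendants has value 1."""
    if node is None:
        return False  # an empty subtree contains nothing
    return (node.val == 1
            or subtree_contains_one(node.left)
            or subtree_contains_one(node.right))


# Right subtree of Example 1: a 0 whose children are 0 and 1.
right = TreeNode(0, TreeNode(0), TreeNode(1))
print(subtree_contains_one(right))       # True: the descendant 1 counts
print(subtree_contains_one(right.left))  # False: this lone 0 gets removed
```

Calling a check like this from every node would cost O(n^2) in the worst case. That cost is why the solution folds the test into a single post-order pass.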

Examples

Example 1:

graph LR;
	subgraph Original
	    A[1] ~~~ N1:::hidden
	    A --> B[0] 
	    B --> C[0]:::prune
	    B --> D[1]
    end

    subgraph After["After Pruning"]
	    A2[1] ~~~ N3:::hidden
	    A2 --> B2[0] 
	    B2 ~~~ N2[0]:::hidden
	    B2 --> D2[1]
    end

classDef prune fill:#FFA500,stroke:#333,stroke-width:2px;
classDef hidden display: none;

    
Original --> After
  
Input: root = [1,null,0,0,1]
Output: [1,null,0,null,1]
Explanation:
Only the highlighted (orange) node is the root of a subtree that contains no 1, so only that node is removed.
The "After Pruning" diagram on the right shows the answer.
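The `Input`/`Output` arrays use LeetCode's level-order encoding. Nodes are listed breadth-first, `null` marks a missing child, and children are listed only for nodes that exist. The following sketch decodes and re-encodes that format, assuming Python `None` stands in for `null`. The names `build_tree` and `to_level_order` are hypothetical helpers:

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Decode a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Each existing node consumes the next two slots: left, then right.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Encode a tree breadth-first, dropping trailing None entries."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


root = build_tree([1, None, 0, 0, 1])
print(root.right.left.val, root.right.right.val)  # 0 1
print(to_level_order(root))  # [1, None, 0, 0, 1]
```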

Example 2:

graph LR;
subgraph Original
    A[1] --> B[0]:::prune
    A --> E[1]
    B --> C[0]:::prune
    B --> D[0]:::prune
    E --> F[0]:::prune
    E --> G[1]
    end

    subgraph After["After Pruning"]
    A2[1] ~~~ N3:::hidden
    A2 --> E2[1]
    E2 ~~~ N4:::hidden
    E2 --> G2[1]
    end

classDef prune fill:#FFA500,stroke:#333,stroke-width:2px;
classDef hidden display: none;

    
Original --> After
  
Input: root = [1,0,1,0,0,0,1]
Output: [1,null,1,null,1]

Example 3:

graph LR;
subgraph Original
    A[1] --> B[1] & E[0]
    B --> C[1] & D[1]
    E --> F[0]:::prune & G[1]
    C --> H[0]:::prune
    C ~~~ N12:::hidden
end

subgraph After["After Pruning"]
    A2[1] --> B2[1] & E2[0]
    B2 --> C2[1] & D2[1]
    E2 ~~~ N4[0]:::hidden
    E2 --> G2[1]
    C2 ~~~ N5[0]:::hidden
end

classDef prune fill:#FFA500,stroke:#333,stroke-width:2px;
classDef hidden display: none;


Original --> After
  
Input: root = [1,1,0,1,1,0,1,0]
Output: [1,1,0,1,1,null,1]

Solution

Method 1 - Post-order Traversal

Here is how we can do this:

  • Use a recursive, post-order traversal, so both children are processed before the node itself.
  • For each node, recursively prune its left and right subtrees and assign the results back to its left and right pointers.
  • After both children have been pruned, if the current node has value 0 and both children are now null, its subtree contains no 1, so prune this node by returning null.
  • Otherwise, return the node.
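The same post-order idea can also be phrased with a helper that reports whether a subtree contains a 1 and detaches each child whose subtree does not. This is a minimal Python sketch of that alternative formulation, assuming a LeetCode-style `TreeNode`. The names `contains_one` and `prune_tree` are hypothetical, and the decision rule is the same post-order check described above:

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def contains_one(node: Optional[TreeNode]) -> bool:
    """Prune 1-free children of `node` in place; report whether a 1 remains."""
    if node is None:
        return False
    # Post-order: settle both children before deciding about this node.
    left_has_one = contains_one(node.left)
    right_has_one = contains_one(node.right)
    if not left_has_one:
        node.left = None
    if not right_has_one:
        node.right = None
    return node.val == 1 or left_has_one or right_has_one


def prune_tree(root: Optional[TreeNode]) -> Optional[TreeNode]:
    # The root itself is dropped only when the whole tree has no 1.
    return root if contains_one(root) else None


# Example 1: root 1, right child 0 with children 0 and 1.
root = prune_tree(TreeNode(1, None, TreeNode(0, TreeNode(0), TreeNode(1))))
print(root.right.left is None, root.right.right.val)  # True 1
```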

Approach

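  1. Write a recursive function that prunes a subtree and returns its new root, or null if the whole subtree is removed.
  2. Prune the left and right subtrees first. A node is removed only if its value is 0 and neither pruned child survived, which means its subtree contains no 1.
  3. Call the function on the root and return its result, which is null when the tree contains no 1 at all.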
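The steps above rely on recursion, and the recursion depth equals the tree height. A very deep, skewed tree can therefore exceed CPython's default recursion limit of 1000 frames. As an alternative that is not part of the original method, the same post-order pruning can be driven by an explicit stack. This is a minimal sketch, assuming a LeetCode-style `TreeNode`. The names `is_zero_leaf` and `prune_tree_iterative` are hypothetical:

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_zero_leaf(node: Optional[TreeNode]) -> bool:
    # True for an existing node with value 0 and no children.
    return (node is not None and node.val == 0
            and node.left is None and node.right is None)


def prune_tree_iterative(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Post-order pruning driven by an explicit stack instead of recursion."""
    if root is None:
        return None
    # Each entry is (node, children_done); a node is pushed back before its
    # children, so it is finished only after both subtrees are pruned.
    stack: List[Tuple[TreeNode, bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            # Both children are already pruned, so a child that is now a
            # 0-valued leaf had no 1 anywhere in its subtree: detach it.
            if is_zero_leaf(node.left):
                node.left = None
            if is_zero_leaf(node.right):
                node.right = None
    # The root is removed only if the whole tree collapsed to a single 0.
    return None if is_zero_leaf(root) else root


# A left-leaning chain of 5001 zeros, deeper than CPython's default
# recursion limit, which the recursive version would exceed.
deep = TreeNode(0)
for _ in range(5000):
    deep = TreeNode(0, deep)
print(prune_tree_iterative(deep))  # None
```

The trade-off is that the explicit stack still uses O(h) memory. It moves that memory from the call stack to the heap.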

Code

Java
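// TreeNode is the standard LeetCode definition (int val; TreeNode left, right).
public class Solution {
    // Returns the root of the pruned subtree, or null if the subtree contains no 1.
    public TreeNode pruneTree(TreeNode root) {
        if (root == null) {
            return null;
        }

        // Post-order: prune both children first and reattach the results.
        root.left = pruneTree(root.left);
        root.right = pruneTree(root.right);

        // A 0-valued node with no surviving children has no 1 in its subtree.
        if (root.left == null && root.right == null && root.val == 0) {
            return null;
        }

        return root;
    }
}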
Python
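from typing import Optional

# TreeNode is the standard LeetCode definition (val, left, right).
class Solution:
    def pruneTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        # Returns the root of the pruned subtree, or None if it contains no 1.
        if not root:
            return None

        # Post-order: prune both children first and reattach the results.
        root.left = self.pruneTree(root.left)
        root.right = self.pruneTree(root.right)

        # A 0-valued node with no surviving children has no 1 in its subtree.
        if root.left is None and root.right is None and root.val == 0:
            return None

        return root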
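One way to check the Python solution against the three examples is a small harness that decodes each `Input` list, prunes the tree, and re-encodes the result. This sketch assumes `None` stands in for `null`. The functions `build_tree` and `to_level_order` are hypothetical helpers for LeetCode's level-order format, and `Solution` repeats the Python solution so the block runs on its own:

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    # Same post-order pruning as the Python solution.
    def pruneTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        if not root:
            return None
        root.left = self.pruneTree(root.left)
        root.right = self.pruneTree(root.right)
        if root.left is None and root.right is None and root.val == 0:
            return None
        return root


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    # Level-order decoding: every existing node takes the next two slots.
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue, i = deque([root]), 1
    while queue and i < len(values):
        node = queue.popleft()
        for side in ("left", "right"):
            if i < len(values) and values[i] is not None:
                child = TreeNode(values[i])
                setattr(node, side, child)
                queue.append(child)
            i += 1
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    # Breadth-first encoding with trailing None entries trimmed.
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        out.append(None if node is None else node.val)
        if node is not None:
            queue.extend((node.left, node.right))
    while out and out[-1] is None:
        out.pop()
    return out


examples = [
    ([1, None, 0, 0, 1], [1, None, 0, None, 1]),
    ([1, 0, 1, 0, 0, 0, 1], [1, None, 1, None, 1]),
    ([1, 1, 0, 1, 1, 0, 1, 0], [1, 1, 0, 1, 1, None, 1]),
]
for given, expected in examples:
    result = to_level_order(Solution().pruneTree(build_tree(given)))
    print(result, result == expected)
```

Each printed line should show the example's expected `Output` array followed by `True`.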

Complexity

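  • ⏰ Time complexity: O(n), where n is the number of nodes in the binary tree, since each node is visited exactly once and does O(1) work.
  • 🧺 Space complexity: O(h), where h is the height of the tree, due to the recursion stack. This is O(log n) for a balanced tree and O(n) in the worst case for a skewed tree.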
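To see where the O(h) bound comes from, you can measure how deep the recursion goes. The sketch below is an instrumented copy of the recursive pruning, and the `depth`/`max_depth` counters are an assumption added only for measurement. The deepest active call is one level below the lowest node, because the recursion also descends into its `None` children:

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prune_with_depth(root: Optional[TreeNode]) -> int:
    """Run the recursive pruning and return the deepest call level reached."""
    max_depth = 0

    def prune(node: Optional[TreeNode], depth: int) -> Optional[TreeNode]:
        nonlocal max_depth
        max_depth = max(max_depth, depth)
        if node is None:
            return None
        node.left = prune(node.left, depth + 1)
        node.right = prune(node.right, depth + 1)
        if node.left is None and node.right is None and node.val == 0:
            return None
        return node

    prune(root, 1)
    return max_depth


def chain(length: int) -> TreeNode:
    # Left-leaning chain of `length` nodes; the bottom node is 1 so nothing is pruned.
    node = TreeNode(1)
    for _ in range(length - 1):
        node = TreeNode(0, node)
    return node


def balanced(levels: int) -> Optional[TreeNode]:
    # Perfect tree with 2**levels - 1 nodes, all valued 1.
    if levels == 0:
        return None
    return TreeNode(1, balanced(levels - 1), balanced(levels - 1))


print(prune_with_depth(chain(100)))   # 101: depth grows with n when h == n
print(prune_with_depth(balanced(7)))  # 8: 127 nodes, but only 7 levels deep
```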