Problem

Given the root of a binary tree, where each node has a distinct value, and a list of values to_delete:

After deleting all nodes with a value in to_delete, we are left with a forest (a disjoint union of trees).

Return the roots of the trees in the remaining forest. You may return the result in any order.

Examples

Example 1:

        1
      /   \
     2     3
    / \   / \
   4   5 6   7

Input: root = [1,2,3,4,5,6,7], to_delete = [3,5]
Output: [[1,2,null,4],[6],[7]]

Example 2:

Input: root = [1,2,4,null,3], to_delete = [3]
Output: [[1,2,4]]
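The examples use LeetCode's usual level-order array notation: nodes are listed level by level, left to right, null marks a missing child, and trailing nulls are dropped. That is why the first tree in Example 1's output is written [1,2,null,4]: node 1 lost its right child (3) and node 2 lost its right child (5). The sketch below converts between the two forms so the examples can be reproduced locally. It assumes a minimal TreeNode standing in for the one LeetCode provides; the helper class LevelOrder and its methods are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

// Minimal stand-in for LeetCode's TreeNode (an assumption; LeetCode supplies its own).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

// Hypothetical helper: converts between level-order arrays and trees.
class LevelOrder {

    // Builds a tree from a level-order array in which null marks a missing child.
    static TreeNode build(Integer[] values) {
        if (values.length == 0 || values[0] == null) return null;
        TreeNode root = new TreeNode(values[0]);
        Queue<TreeNode> queue = new ArrayDeque<>();
        queue.add(root);
        int i = 1;
        while (!queue.isEmpty() && i < values.length) {
            TreeNode parent = queue.poll();
            // The next slot is the left child, the one after it the right child.
            if (i < values.length && values[i] != null) {
                parent.left = new TreeNode(values[i]);
                queue.add(parent.left);
            }
            i++;
            if (i < values.length && values[i] != null) {
                parent.right = new TreeNode(values[i]);
                queue.add(parent.right);
            }
            i++;
        }
        return root;
    }

    // Serializes a tree back to level order, dropping trailing nulls.
    static String serialize(TreeNode root) {
        List<String> out = new ArrayList<>();
        Queue<TreeNode> queue = new LinkedList<>(); // LinkedList allows null entries; ArrayDeque does not
        queue.add(root);
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            if (node == null) {
                out.add("null");
                continue;
            }
            out.add(String.valueOf(node.val));
            queue.add(node.left);
            queue.add(node.right);
        }
        // Remove the placeholders for missing children at the end of the last level.
        while (!out.isEmpty() && out.get(out.size() - 1).equals("null")) {
            out.remove(out.size() - 1);
        }
        return "[" + String.join(",", out) + "]";
    }
}
```

For instance, LevelOrder.serialize(LevelOrder.build(new Integer[]{1, 2, null, 4})) yields "[1,2,null,4]", the first tree of Example 1's output.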

Solution

Method 1 - Recursion

We can use a recursive traversal to do two things at once:

  1. Deletion: decide, while visiting each node, whether it must be removed.
  2. Forest Formation:
    • If a node is deleted, each of its children that exists and is kept becomes the root of a new tree in the forest.
    • If a node is kept and it is the original root, add it to the result list as well. In other words, a kept node is a root of the result exactly when it has no surviving parent.

So, here are the steps we can follow:

  1. Convert the to_delete list into a set for O(1) average-time look-ups.
  2. Define a recursive function that processes the tree bottom-up (post-order) and returns what should remain in the node's place:
    • If the current node is null, return null.
    • Recursively process the left and right children first, assigning the results back to node.left and node.right so that deleted children are detached.
    • If the current node is in the to_delete set:
      • Add its remaining (non-null) children to the forest, since they no longer have a parent.
      • Return null so that the parent drops this node.
    • Otherwise, return the current node.
  3. Handle the Root: after the traversal, if the root was not deleted (the call returned a non-null node), add it to the forest.
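These steps describe a post-order traversal: a node's children are pruned before the node itself is examined, and the root is handled last. Below is a minimal Java sketch of that version. It assumes a LeetCode-style TreeNode (redefined here so the snippet compiles on its own); the class name PostOrderSolution and the helper prune are hypothetical. It differs from the solution in the Code section, which passes an isRoot flag down the tree instead, but both visit each node exactly once.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Minimal stand-in for LeetCode's TreeNode (an assumption; LeetCode supplies its own).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Post-order version of the steps (hypothetical class name).
class PostOrderSolution {

    public List<TreeNode> delNodes(TreeNode root, int[] toDelete) {
        // Step 1: a set gives O(1) average-time membership checks.
        Set<Integer> toDeleteSet = new HashSet<>();
        for (int val : toDelete) {
            toDeleteSet.add(val);
        }

        List<TreeNode> forest = new ArrayList<>();
        // Step 2: prune the tree bottom-up.
        root = prune(root, toDeleteSet, forest);
        // Step 3: a surviving root is the root of its own tree.
        if (root != null) {
            forest.add(root);
        }
        return forest;
    }

    // Returns what should occupy the node's slot in its parent: the node, or null if deleted.
    private TreeNode prune(TreeNode node, Set<Integer> toDeleteSet, List<TreeNode> forest) {
        if (node == null) {
            return null;
        }
        // Children first, so deleted descendants are already detached.
        node.left = prune(node.left, toDeleteSet, forest);
        node.right = prune(node.right, toDeleteSet, forest);

        if (toDeleteSet.contains(node.val)) {
            // The (already pruned) children of a deleted node become new roots.
            if (node.left != null) forest.add(node.left);
            if (node.right != null) forest.add(node.right);
            return null;
        }
        return node;
    }
}
```

For Example 1 (to_delete = [3, 5]) this returns the roots 6, 7 and 1 in that order: children of deleted nodes are collected on the way back up, and the surviving original root is added last. The problem accepts any order.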


Code

Java
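import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Assumes LeetCode's TreeNode class (fields val, left, right) is available.
public class Solution {

	public List<TreeNode> delNodes(TreeNode root, int[] to_delete) {
		// Store the values to delete in a set for O(1) average-time look-ups.
		Set<Integer> toDeleteSet = new HashSet<>();

		for (int val: to_delete) {
			toDeleteSet.add(val);
		}

		List<TreeNode> forest = new ArrayList<>();
		// The original root has no parent, so it starts out as a candidate root.
		deleteNodes(root, toDeleteSet, forest, true);

		return forest;
	}

	// Returns what should occupy the node's slot in its parent: the node itself, or null if deleted.
	// isRoot is true when the node has no surviving parent (the original root or a child of a deleted node).
	private TreeNode deleteNodes(TreeNode node, Set<Integer> toDeleteSet, List<TreeNode> forest, boolean isRoot) {
		if (node == null) {
			return null;
		}

		boolean isDeleted = toDeleteSet.contains(node.val);

		// A kept node without a surviving parent starts a new tree in the forest.
		if (isRoot && !isDeleted) {
			forest.add(node);
		}

		// Children of a deleted node become candidate roots; deleted children are detached (set to null).
		node.left = deleteNodes(node.left, toDeleteSet, forest, isDeleted);
		node.right = deleteNodes(node.right, toDeleteSet, forest, isDeleted);

		return isDeleted ? null : node;
	}
}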

Complexity

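  • ⏰ Time complexity: O(n + m), where n is the number of nodes and m is the length of to_delete: building the set takes O(m), and each node is visited once with an O(1) average-time set look-up.
  • 🧺 Space complexity: O(h + m), where h is the height of the tree: the recursion stack goes h levels deep and the set holds m values (the output list is not counted). For a skewed tree, h = n.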
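Because the recursion depth equals the tree height, a skewed tree makes the call stack as deep as the tree has nodes, and a very deep recursion can throw StackOverflowError in Java. As an alternative technique (not the recursive method above), the same deletion can be done breadth-first with an explicit queue: the original root is a result root if it is kept, and every kept child of a deleted node starts a new tree. This is a minimal sketch assuming a LeetCode-style TreeNode (redefined here); the class name IterativeSolution is hypothetical.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;

// Minimal stand-in for LeetCode's TreeNode (an assumption; LeetCode supplies its own).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Breadth-first alternative to the recursive method (hypothetical class name).
class IterativeSolution {

    public List<TreeNode> delNodes(TreeNode root, int[] toDelete) {
        Set<Integer> toDeleteSet = new HashSet<>();
        for (int val : toDelete) {
            toDeleteSet.add(val);
        }

        List<TreeNode> forest = new ArrayList<>();
        if (root == null) {
            return forest;
        }
        // The original root has no parent, so it is a result root if it is kept.
        if (!toDeleteSet.contains(root.val)) {
            forest.add(root);
        }

        Queue<TreeNode> queue = new ArrayDeque<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            boolean nodeDeleted = toDeleteSet.contains(node.val);

            if (node.left != null) {
                TreeNode child = node.left;
                queue.add(child);
                if (toDeleteSet.contains(child.val)) {
                    node.left = null;      // detach a deleted child from its parent
                } else if (nodeDeleted) {
                    forest.add(child);     // a kept child of a deleted node starts a new tree
                }
            }
            if (node.right != null) {
                TreeNode child = node.right;
                queue.add(child);
                if (toDeleteSet.contains(child.val)) {
                    node.right = null;     // detach a deleted child from its parent
                } else if (nodeDeleted) {
                    forest.add(child);     // a kept child of a deleted node starts a new tree
                }
            }
        }
        return forest;
    }
}
```

The queue never holds more than a couple of levels' worth of nodes, so the extra space is proportional to the widest level of the tree plus O(m) for the set; that is still O(n) in the worst case, but no call stack is involved. For Example 1 it returns the roots in the order 1, 6, 7.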