Problem

Flatten a binary tree into a linked list in place, following the in-order traversal sequence. The right pointer of each node serves as the next pointer of the list.

Examples

Example 1:

         1
        / \
       2   5
      / \   \
     3   4   6

The flattened tree should look like:

 3 -> 2 -> 4 -> 1 -> 5 -> 6
Input: root = [1,2,5,3,4,null,6]
Output: [3,null,2,null,4,null,1,null,5,null,6]

Note that the left child of all nodes should be NULL.
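The following Python sketch makes the expected in-order sequence concrete. `build_tree` and `inorder_values` are hypothetical helpers that are not part of the problem statement. They read the LeetCode-style level-order input, where `null` becomes `None`, and list the values in the order the flattened list should follow.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed shape, matching the examples)."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list, where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Attach the left child, if present
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # Attach the right child, if present
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Return node values in in-order (left, root, right)."""
    if root is None:
        return []
    return inorder_values(root.left) + [root.val] + inorder_values(root.right)


root = build_tree([1, 2, 5, 3, 4, None, 6])
print(inorder_values(root))  # [3, 2, 4, 1, 5, 6]
```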

Solution

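Transforming a binary tree (BT) into a singly linked list (SLL) is conceptually simple. You traverse the tree in in-order and append each node to the list as it is visited. In in-order, the head of the list is the leftmost node, not the root. The traversal can be recursive or iterative. An explicit stack makes the iterative in-order traversal easy to follow. The recursive versions take a bit more effort to grasp, because the list's tail (or the previously visited node) has to be carried across calls. In this post, I start with the recursive approaches and then cover an iterative stack-based version and a constant-space version.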
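As a reference point, the "traverse, then link" idea can be written directly with a stack-based in-order traversal. This is a baseline sketch, not one of the methods below. `flatten_by_collecting` and `to_values` are hypothetical names. The baseline uses O(n) extra space for the node list, which the in-place methods avoid.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed shape)."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_by_collecting(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Gather nodes in in-order with an explicit stack, then relink them.

    Returns the new head (the leftmost node), or None for an empty tree.
    """
    nodes: List[TreeNode] = []
    stack: List[TreeNode] = []
    current = root
    while current or stack:
        # Walk as far left as possible, remembering the path
        while current:
            stack.append(current)
            current = current.left
        current = stack.pop()
        nodes.append(current)  # visited in in-order
        current = current.right
    # Relink only after traversal: left becomes None, right becomes "next"
    for i, node in enumerate(nodes):
        node.left = None
        node.right = nodes[i + 1] if i + 1 < len(nodes) else None
    return nodes[0] if nodes else None


def to_values(head: Optional[TreeNode]) -> List[int]:
    """Follow right pointers from head and collect the values."""
    values = []
    while head:
        values.append(head.val)
        head = head.right
    return values


example = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)), TreeNode(5, None, TreeNode(6)))
print(to_values(flatten_by_collecting(example)))  # [3, 2, 4, 1, 5, 6]
```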

Method 1 - Recursive but not in-place

  1. Perform in-order traversal of the binary tree using recursion.
  2. During traversal, create new linked list nodes, appending them in the order of the traversal.
  3. Maintain a pointer to the head of the resulting linked list and another pointer for building the list iteratively (tail).
  4. Return the head of the newly created linked list.

To achieve this, we can pass and update a reference to the tail (current last node of the linked list) during the recursive calls.

Code

Java
class Solution {
    // Assumes the usual TreeNode (val, left, right) and ListNode (val, next) classes
    public ListNode flatten(TreeNode root, ListNode tail) {
        if (root == null) {
            return tail; // Return current tail if root is null
        }

        // Recursively flatten the left subtree, updating the tail
        tail = flatten(root.left, tail);

        // Add current root node to the linked list
        tail.next = new ListNode(root.val);
        tail = tail.next; // Update the tail to the newly added node

        // Recursively flatten the right subtree, updating the tail
        return flatten(root.right, tail);
    }

    // Public method to start flattening with a dummy head node
    public ListNode convertToLinkedList(TreeNode root) {
        ListNode dummy = new ListNode(-1); // Create a dummy head node for simplicity
        flatten(root, dummy);              // Pass dummy node as the initial tail
        return dummy.next;                 // Return the linked list starting after dummy node
    }
}
Python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next


class Solution:
    def flatten(self, root, tail):
        if not root:
            return tail  # Return current tail if root is None

        # Flatten the left subtree and update the tail
        tail = self.flatten(root.left, tail)

        # Add current root node to linked list
        tail.next = ListNode(root.val)
        tail = tail.next  # Update tail to the newly created node

        # Flatten the right subtree and update the tail
        return self.flatten(root.right, tail)

    def convert_to_linked_list(self, root):
        dummy = ListNode(-1)  # Create a dummy head for simplicity
        self.flatten(root, dummy)  # Start flattening with the dummy node as the tail
        return dummy.next  # Return the linked list starting from dummy.next
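Threading the tail through return values works. In Python, the same idea can also keep the tail in the enclosing scope with `nonlocal`, so the recursive helper does not need to return anything. This standalone `convert_to_linked_list` function is a hypothetical variant, not the method above, and it assumes minimal `TreeNode` and `ListNode` classes. Like Method 1, it copies values into new list nodes and leaves the tree untouched.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next


def convert_to_linked_list(root: Optional[TreeNode]) -> Optional[ListNode]:
    """Copy the tree's values into a new list in in-order; the tree is untouched."""
    dummy = ListNode(-1)  # sentinel head, as in Method 1
    tail = dummy

    def visit(node: Optional[TreeNode]) -> None:
        nonlocal tail  # the tail lives here instead of being returned
        if node is None:
            return
        visit(node.left)
        tail.next = ListNode(node.val)  # append a copy of the current node
        tail = tail.next
        visit(node.right)

    visit(root)
    return dummy.next


root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)), TreeNode(5, None, TreeNode(6)))
node = convert_to_linked_list(root)
values = []
while node:
    values.append(node.val)
    node = node.next
print(values)  # [3, 2, 4, 1, 5, 6]
```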

Complexity

  • ⏰ Time complexity: O(n). Each node is visited once during traversal.
  • 🧺 Space complexity: O(n). The recursion stack uses O(h), where h is the tree height (O(n) for a skewed tree), and the returned list needs O(n) new nodes.

Method 2 - Recursive and in place

In in-order traversal, nodes are visited in the sequence left → root → right. Therefore, to flatten the binary tree:

  1. Recursively traverse the left subtree in in-order.
  2. Link the previously visited node (prev, the in-order predecessor) to the root through its right pointer.
  3. Then, recursively flatten the right subtree in in-order.

For this approach:

  • Use a variable (prev) to keep track of the previously visited node during traversal.
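  • Set the right pointer of prev to the current node, and set the current node's left pointer to null (its left subtree has already been flattened).
  • Record the first node visited (the leftmost node) as the head of the list. In in-order the root is generally not the first element, so the caller needs this head.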

Approach

  1. Use a recursive function to traverse the tree in in-order.
  2. Maintain a prev variable to hold the previous node in the flattened sequence.
  3. During traversal:
    • Disconnect the left pointer of the current node.
    • Connect the right pointer of the prev node to the current node.
    • Move prev to the current node.
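  4. Recurse on the left subtree before processing the node and on the right subtree after it. Reading the node's right pointer for that second call is safe, because it is only overwritten later, when the node's in-order successor is linked.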

Code

Java
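class Solution {
    private TreeNode prev = null; // Last node appended to the flattened list
    private TreeNode head = null; // First node in in-order (the new head)

    // Flattens the tree in in-order, in place, and returns the new head
    public TreeNode flatten(TreeNode root) {
        prev = null;
        head = null;
        inorder(root);
        return head;
    }

    private void inorder(TreeNode node) {
        if (node == null) return;

        // Flatten the left subtree
        inorder(node.left);

        // Process the current node: its left subtree is done, so clear left
        node.left = null;
        if (prev == null) {
            head = node;       // The leftmost node starts the list
        } else {
            prev.right = node; // Link the previous node to this one
        }
        prev = node;

        // Flatten the right subtree
        inorder(node.right);
    }
}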
Python
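class Solution:
    def __init__(self):
        self.prev = None  # Last node appended to the flattened list
        self.head = None  # First node in in-order (the new head)

    def flatten(self, root):
        """Flatten the tree in in-order, in place, and return the new head."""
        self.prev = self.head = None
        self._inorder(root)
        return self.head

    def _inorder(self, node):
        if not node:
            return

        # Flatten the left subtree
        self._inorder(node.left)

        # Process the current node: its left subtree is done, so clear left
        node.left = None
        if self.prev is None:
            self.head = node        # The leftmost node starts the list
        else:
            self.prev.right = node  # Link the previous node to this one
        self.prev = node

        # Flatten the right subtree
        self._inorder(node.right)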
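The sketch below shows the `prev` mechanism on Example 1. It runs the same left, root, right recursion, keeps `prev` and the list head in the enclosing scope via `nonlocal`, and records each `prev -> current` link it creates. `flatten_with_trace` is a hypothetical helper, and `TreeNode` is the usual minimal node class.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_with_trace(root: Optional[TreeNode]) -> Tuple[Optional[TreeNode], List[Tuple[int, int]]]:
    """In-order flatten with a prev pointer, also recording each (prev, current) link."""
    links: List[Tuple[int, int]] = []
    prev: Optional[TreeNode] = None
    head: Optional[TreeNode] = None

    def inorder(node: Optional[TreeNode]) -> None:
        nonlocal prev, head
        if node is None:
            return
        inorder(node.left)
        node.left = None               # left subtree is finished
        if prev is None:
            head = node                # leftmost node starts the list
        else:
            prev.right = node          # link predecessor -> current
            links.append((prev.val, node.val))
        prev = node
        inorder(node.right)            # node.right is still the original child here

    inorder(root)
    return head, links


root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)), TreeNode(5, None, TreeNode(6)))
head, links = flatten_with_trace(root)
print(links)  # [(3, 2), (2, 4), (4, 1), (1, 5), (5, 6)]
```

Each link points from a node to its in-order successor. The first node visited (3) never receives an incoming link, which is why the head has to be recorded separately.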

Complexity

  • ⏰ Time complexity: O(n). Each node is visited once during traversal.
  • 🧺 Space complexity: O(h). Space is used by the recursive call stack, where h is the height of the tree.
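The O(h) bound is easy to observe. The sketch below uses hypothetical helpers `max_recursion_depth`, `balanced`, and `left_skewed` to record the deepest call-stack depth that a left, root, right recursion reaches. That depth is about log2(n) for a balanced tree, but n for a skewed one. In Python this matters in practice, because CPython's default recursion limit is typically 1000 frames (see `sys.getrecursionlimit()`).

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_recursion_depth(root: Optional[TreeNode]) -> int:
    """Deepest stack depth reached by an in-order recursion (counting calls on real nodes)."""
    deepest = 0

    def inorder(node: Optional[TreeNode], depth: int) -> None:
        nonlocal deepest
        if node is None:
            return
        deepest = max(deepest, depth)
        inorder(node.left, depth + 1)
        inorder(node.right, depth + 1)

    inorder(root, 1)
    return deepest


def balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Height-balanced tree holding lo..hi in in-order."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, balanced(lo, mid - 1), balanced(mid + 1, hi))


def left_skewed(n: int) -> Optional[TreeNode]:
    """Chain of n nodes linked only through left pointers (built iteratively)."""
    root = None
    for val in range(n, 0, -1):
        root = TreeNode(val, left=root)
    return root


print(max_recursion_depth(balanced(1, 127)))  # 7
print(max_recursion_depth(left_skewed(500)))  # 500
```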

Method 3 - Iterative Using Stack

The recursion can be replaced by an explicit stack while the list is still built in place, without a separate structure to hold it. When a node is popped in in-order, its left subtree has been fully visited, so its left pointer is free. Its right pointer is read once, to descend into the right subtree, and is only overwritten later, when its in-order successor is linked. These pointers can therefore be repurposed as the links of the list. The right pointer serves as the next link, and the left pointer is set to null.

Approach

  1. Walk left from the current node, pushing every node onto the stack, so the leftmost unvisited node ends up on top.
  2. Pop a node. It is the next node in in-order, and its left subtree has already been linked, so set its left pointer to null.
  3. If there is a previously popped node (prev), set prev.right to the popped node. Otherwise the popped node is the head of the list. Then update prev.
  4. Move to the popped node's right child and repeat until both the current pointer and the stack are empty.

Code

Java
import java.util.ArrayDeque;
import java.util.Deque;

class Solution {
    // Definition for binary tree node
    public static class TreeNode {
        int val;
        TreeNode left;
        TreeNode right;
        TreeNode(int val) {
            this.val = val;
        }
    }

    // Flattens the tree in in-order and returns the head (leftmost node)
    public TreeNode flatten(TreeNode root) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode current = root;
        TreeNode prev = null;
        TreeNode head = null;

        while (current != null || !stack.isEmpty()) {
            // Push the path to the leftmost unvisited node
            while (current != null) {
                stack.push(current);
                current = current.left;
            }

            current = stack.pop();    // Next node in in-order
            current.left = null;      // Its left subtree is already linked

            if (prev == null) {
                head = current;       // First node becomes the head
            } else {
                prev.right = current; // Link previous node to current
            }
            prev = current;

            current = current.right;  // Continue with the right subtree
        }
        return head;
    }
}
Python
from typing import Optional


# Definition for binary tree node
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def flatten(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        """Flatten in in-order using an explicit stack; return the new head."""
        stack = []
        current, prev, head = root, None, None

        while current or stack:
            # Push the path to the leftmost unvisited node
            while current:
                stack.append(current)
                current = current.left

            current = stack.pop()      # Next node in in-order
            current.left = None        # Its left subtree is already linked

            if prev is None:
                head = current         # First node becomes the head
            else:
                prev.right = current   # Link previous node to current
            prev = current

            current = current.right    # Continue with the right subtree

        return head
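Because the stack lives on the heap, this version is not bound by Python's recursion limit. The sketch below flattens a left-skewed chain deeper than `sys.getrecursionlimit()`, which would make the recursive Method 2 raise `RecursionError` in CPython. `flatten_iterative` is a standalone, hypothetical copy of the loop above, so that the block runs on its own.

```python
import sys
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_iterative(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Stack-based in-order flatten (same loop as Method 3); returns the new head."""
    stack: List[TreeNode] = []
    current, prev, head = root, None, None
    while current or stack:
        while current:              # descend left, saving the path
            stack.append(current)
            current = current.left
        current = stack.pop()
        current.left = None         # left subtree already linked
        if prev is None:
            head = current
        else:
            prev.right = current
        prev = current
        current = current.right
    return head


# A left-skewed chain deeper than the interpreter's recursion limit
depth = sys.getrecursionlimit() + 1000
root = None
for val in range(depth, 0, -1):
    root = TreeNode(val, left=root)   # root is 1; its left chain runs 2, 3, ..., depth

head = flatten_iterative(root)
print(head.val == depth)  # True: the deepest node comes first in in-order
```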

Complexity

  • ⏰ Time complexity: O(n). Each node is processed exactly once.
  • 🧺 Space complexity: O(h), where h is the height of the tree, which bounds the stack size. In the worst case (a skewed tree), this is O(n).

Method 4 - Constant Space, Iterative (Morris Traversal)

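Tree traversal can also be done without recursion or an explicit stack, in O(1) extra space, using the threading idea from the Morris Inorder Traversal method discussed earlier. Classic Morris traversal creates a temporary link from each node's in-order predecessor (the rightmost node of its left subtree) back to the node, and removes that link on the second visit. For flattening, that link is exactly the next pointer we want, so it can be kept permanently. For each node that has a left child, link its predecessor to it, clear its left pointer, and move its left subtree up into its place. Repeating this until no node has a left child leaves the tree as a right-leaning chain in in-order, in O(n) time and O(1) extra space.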
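For comparison, here is the classic read-only Morris in-order traversal in Python. It creates a temporary thread from each predecessor back to its node and removes the thread on the second visit, so the tree is unchanged at the end. `morris_inorder` is a hypothetical name, and it only collects values.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def morris_inorder(root: Optional[TreeNode]) -> List[int]:
    """Classic Morris in-order traversal: O(1) extra space, tree restored afterwards."""
    result: List[int] = []
    current = root
    while current:
        if current.left is None:
            result.append(current.val)   # no left subtree: visit and go right
            current = current.right
        else:
            # Predecessor: rightmost node of the left subtree (stop at an existing thread)
            pred = current.left
            while pred.right and pred.right is not current:
                pred = pred.right
            if pred.right is None:
                pred.right = current      # first visit: create a temporary thread
                current = current.left
            else:
                pred.right = None         # second visit: remove the thread
                result.append(current.val)
                current = current.right
    return result


root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)), TreeNode(5, None, TreeNode(6)))
print(morris_inorder(root))  # [3, 2, 4, 1, 5, 6]
```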

Code

Java
class Solution {
    // Assumes the usual TreeNode class with a TreeNode(int val) constructor.
    // Returns the head of the in-order list (the leftmost node).
    public TreeNode flatten(TreeNode root) {
        TreeNode dummy = new TreeNode(0); // Sentinel: dummy.right leads to the head
        dummy.right = root;
        TreeNode parent = dummy;          // parent.right is always current
        TreeNode current = root;

        while (current != null) {
            if (current.left != null) {
                // Find predecessor (rightmost node in the left subtree)
                TreeNode predecessor = current.left;
                while (predecessor.right != null) {
                    predecessor = predecessor.right;
                }

                // Permanent link: predecessor -> current (the thread is kept)
                predecessor.right = current;

                // Lift the left subtree into current's place
                TreeNode left = current.left;
                current.left = null;
                parent.right = left;
                current = left;
            } else {
                // No left subtree: current is in its final position
                parent = current;
                current = current.right;
            }
        }
        return dummy.right;
    }
}
Python
class Solution:
    def flatten(self, root):
        """Flatten in in-order with O(1) extra space; return the new head.

        Assumes the usual TreeNode(val, left, right) class.
        """
        dummy = TreeNode(0)            # Sentinel: dummy.right leads to the head
        dummy.right = root
        parent, current = dummy, root  # parent.right is always current

        while current:
            if current.left:
                # Find predecessor (rightmost node in the left subtree)
                predecessor = current.left
                while predecessor.right:
                    predecessor = predecessor.right

                # Permanent link: predecessor -> current (the thread is kept)
                predecessor.right = current

                # Lift the left subtree into current's place
                left = current.left
                current.left = None
                parent.right = left
                current = left
            else:
                # No left subtree: current is in its final position
                parent, current = current, current.right

        return dummy.right

Complexity

  • ⏰ Time complexity: O(n). Each node becomes current at most twice (once while it still has a left child and once after), and each node is stepped over by a predecessor search at most once, so the total work is linear.
  • 🧺 Space complexity: O(1). No extra data structure is used.
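The linear bound holds because each node is stepped over by a predecessor search at most once. After the search, that node is reachable from the sentinel through right pointers only, so it never lies inside a left subtree again. The sketch below instruments the same permanent-thread loop with a step counter and checks it on a random binary search tree, whose in-order sequence is simply the sorted keys. `flatten_counting` and `bst_insert` are hypothetical helpers.

```python
import random
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_counting(root: Optional[TreeNode]) -> Tuple[Optional[TreeNode], int]:
    """Permanent-thread in-order flatten, also counting predecessor-search steps."""
    dummy = TreeNode(0, right=root)   # sentinel; parent.right is always current
    parent, current = dummy, root
    steps = 0
    while current:
        if current.left:
            pred = current.left
            while pred.right:         # walk to the rightmost node of the left subtree
                pred = pred.right
                steps += 1
            pred.right = current      # permanent link: pred -> current
            left = current.left
            current.left = None
            parent.right = left       # the left subtree moves up into current's place
            current = left
        else:
            parent, current = current, current.right
    return dummy.right, steps


def bst_insert(root: Optional[TreeNode], key: int) -> TreeNode:
    """Iterative BST insert; returns the (possibly new) root."""
    node = TreeNode(key)
    if root is None:
        return node
    cur = root
    while True:
        if key < cur.val:
            if cur.left is None:
                cur.left = node
                return root
            cur = cur.left
        else:
            if cur.right is None:
                cur.right = node
                return root
            cur = cur.right


rng = random.Random(0)
keys = rng.sample(range(10_000), 2_000)
root = None
for key in keys:
    root = bst_insert(root, key)
head, steps = flatten_counting(root)
print(steps < len(keys))  # True
```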