Problem

You are given a doubly linked list, which contains nodes that have a next pointer, a previous pointer, and an additional child pointer. This child pointer may or may not point to a separate doubly linked list, also containing these special nodes. These child lists may have one or more children of their own, and so on, to produce a multilevel data structure as shown in the example below.

Given the head of the first level of the list, flatten the list so that all the nodes appear in a single-level, doubly linked list. Let curr be a node with a child list. The nodes in the child list should appear after curr and before curr.next in the flattened list.

Return the head of the flattened list. The nodes in the list must have all of their child pointers set to null.

Examples

Example 1:

Input: head = [1,2,3,4,5,6,null,null,null,7,8,9,10,null,null,11,12]
Output: [1,2,3,7,8,11,12,9,10,4,5,6]
Explanation: The multilevel linked list in the input is shown.

Example 2:

Input: head = [1,2,null,3]
Output: [1,3,2]
Explanation: The multilevel linked list in the input is shown first, followed by the flattened result.

 1---2
 |
 3

 1---3---2

Example 3:

Input: head = []
Output: []

Explanation: The input may be an empty list.

Constraints:

  • The number of Nodes will not exceed 1000.
  • 1 <= Node.val <= 10^5

How the multilevel linked list is represented in test cases:

We use the multilevel linked list from Example 1 above:

 1---2---3---4---5---6--NULL
         |
         7---8---9---10--NULL
             |
             11--12--NULL

The serialization of each level is as follows:

[1,2,3,4,5,6,null]
[7,8,9,10,null]
[11,12,null]

To serialize all levels together, we will add nulls in each level to signify no node connects to the upper node of the previous level. The serialization becomes:

[1,    2,    3, 4, 5, 6, null]
             |
[null, null, 7,    8, 9, 10, null]
                   |
[            null, 11, 12, null]

Merging the serialization of each level and removing trailing nulls we obtain:

[1,2,3,4,5,6,null,null,null,7,8,9,10,null,null,11,12]
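
The serialization rule above can be turned into a small builder for experimenting locally. This is only a sketch under the assumption that the input is well formed: `build_multilevel` is a hypothetical helper name (not part of the problem statement), and it reads the nulls that precede a level as "skip that many nodes of the level above" before attaching the child.

```python
class Node:
    def __init__(self, val, prev=None, next=None, child=None):
        self.val = val
        self.prev = prev
        self.next = next
        self.child = child


def build_multilevel(values):
    """Hypothetical helper: rebuild the multilevel list from the merged
    serialization and return its head (None for an empty input)."""
    head = None
    prev_level = []  # nodes of the level directly above the one being read
    i, n = 0, len(values)
    while i < n:
        # Nulls in front of a level say how many nodes of the level
        # above to skip before attaching the child.
        offset = 0
        while i < n and values[i] is None:
            offset += 1
            i += 1
        # Read one level and link its nodes with next/prev.
        level = []
        while i < n and values[i] is not None:
            node = Node(values[i])
            if level:
                level[-1].next = node
                node.prev = level[-1]
            level.append(node)
            i += 1
        if not level:
            break
        if head is None:
            head = level[0]
        else:
            prev_level[offset].child = level[0]
        prev_level = level
        # Skip the null that terminates this level, if present.
        if i < n:
            i += 1
    return head


head = build_multilevel(
    [1, 2, 3, 4, 5, 6, None, None, None, 7, 8, 9, 10, None, None, 11, 12]
)
print(head.next.next.child.val)             # 7, the child of node 3
print(head.next.next.child.next.child.val)  # 11, the child of node 8
```

For [1,2,null,3] the same builder attaches 3 as the child of node 1, which matches Example 2.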

Solution

Method 1 - Using Stack

Here is the approach:

  1. Use a Stack:
    • Use a stack to hold the nodes that still have to be visited. It starts with the head.
  2. Traverse and Flatten:
    • Pop a node, then push its next node (if it exists) followed by its child (if it exists). Because the child is pushed last, it is popped first, so the whole child list is visited before the saved next node.
    • Set the child pointer to null once the child has been pushed.
  3. Connecting Nodes:
    • Link each popped node to the previously popped one by updating the prev and next pointers, so the result remains a consistent doubly linked list. A dummy node in front of the head avoids a special case for the first node.
  4. Edge Cases:
    • Handle cases where the list is empty or only contains one level without any children.

Code

Java
import java.util.Stack;

public class Solution {
    static class Node {
        public int val;
        public Node prev;
        public Node next;
        public Node child;
    }

    public Node flatten(Node head) {
        if (head == null)
            return head;

        // The dummy node removes the special case for the first node
        Node dummy = new Node();
        dummy.next = head;
        Node prev = dummy;

        Stack<Node> stack = new Stack<>();
        stack.push(head);

        while (!stack.isEmpty()) {
            Node curr = stack.pop();

            // Push next first so that the child is popped before it
            if (curr.next != null) {
                stack.push(curr.next);
            }

            if (curr.child != null) {
                stack.push(curr.child);
                // The child pointer is no longer needed
                curr.child = null;
            }

            // Append curr to the flattened list
            prev.next = curr;
            curr.prev = prev;
            prev = curr;
        }

        // Detach the dummy node from the resulting list
        dummy.next.prev = null;
        return dummy.next;
    }
}
Python
class Node:
    def __init__(self, val, prev=None, next=None, child=None):
        self.val = val
        self.prev = prev
        self.next = next
        self.child = child


class Solution:
    def flatten(self, head: "Node") -> "Node":
        if not head:
            return None

        dummy = Node(0)
        dummy.next = head
        stack = [head]
        prev = dummy

        while stack:
            curr = stack.pop()

            # push the next node and child node to the stack if they exist
            if curr.next:
                stack.append(curr.next)
            if curr.child:
                stack.append(curr.child)
                # we do not need the child pointer anymore
                curr.child = None

            prev.next = curr
            curr.prev = prev

            prev = curr

        # Detach the dummy node from the resultant list
        dummy.next.prev = None
        return dummy.next
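
As a quick check of the stack-based method, the sketch below builds Example 2 by hand and flattens it. It is self-contained, so `flatten` restates the method as a plain function; `to_list` is a hypothetical helper introduced here to collect the values while verifying the prev links and cleared child pointers.

```python
class Node:
    def __init__(self, val, prev=None, next=None, child=None):
        self.val = val
        self.prev = prev
        self.next = next
        self.child = child


def flatten(head):
    """Function form of the stack-based method."""
    if not head:
        return None
    dummy = Node(0)
    dummy.next = head
    prev = dummy
    stack = [head]
    while stack:
        curr = stack.pop()
        # Push next before child so the child list is visited first.
        if curr.next:
            stack.append(curr.next)
        if curr.child:
            stack.append(curr.child)
            curr.child = None
        prev.next = curr
        curr.prev = prev
        prev = curr
    dummy.next.prev = None
    return dummy.next


def to_list(head):
    """Hypothetical helper: collect values and verify the links."""
    values, prev, node = [], None, head
    while node:
        assert node.prev is prev    # doubly linked structure is intact
        assert node.child is None   # child pointers were cleared
        values.append(node.val)
        prev, node = node, node.next
    return values


# Example 2: 1---2 on the first level, 3 is the child of 1.
one, two, three = Node(1), Node(2), Node(3)
one.next, two.prev = two, one
one.child = three

print(to_list(flatten(one)))  # [1, 3, 2]
```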

Complexity

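  • Time: O(n), where n is the number of nodes in the list. Each node is pushed onto the stack and popped exactly once.
  • Space: O(d), where d is the maximum depth of the children lists. The stack holds at most one saved next node per level above the current one, plus the current node's next and child. In the worst case d is proportional to n.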
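
To see the space bound in practice, the following sketch runs the same traversal while recording the largest stack size. `max_stack_size`, `make_level`, and `make_nested` are hypothetical helpers written for this illustration. The nested input is an assumed test shape: every level has two nodes and the first node of each level owns the next level.

```python
class Node:
    def __init__(self, val, prev=None, next=None, child=None):
        self.val = val
        self.prev = prev
        self.next = next
        self.child = child


def max_stack_size(head):
    """Flatten with the stack-based traversal and return the peak stack size."""
    if head is None:
        return 0
    stack, peak, prev = [head], 1, None
    while stack:
        curr = stack.pop()
        if curr.next:
            stack.append(curr.next)
        if curr.child:
            stack.append(curr.child)
            curr.child = None
        if prev is not None:
            prev.next = curr
            curr.prev = prev
        prev = curr
        peak = max(peak, len(stack))
    return peak


def make_level(values):
    """Hypothetical helper: build one flat doubly linked level."""
    head = prev = None
    for v in values:
        node = Node(v)
        if prev is None:
            head = node
        else:
            prev.next = node
            node.prev = prev
        prev = node
    return head


def make_nested(depth):
    """Hypothetical helper: `depth` levels of two nodes each, where the
    first node of every level has the next level as its child."""
    top = make_level([1, 2])
    level = top
    for k in range(1, depth):
        level.child = make_level([2 * k + 1, 2 * k + 2])
        level = level.child
    return top


print(max_stack_size(make_level(range(1, 51))))  # 1: a flat list never stacks up
print(max_stack_size(make_nested(10)))           # 10: grows with the depth
```

A flat list of 50 nodes never holds more than one entry, while the nested input keeps one saved next node per level, so the peak follows the depth rather than the node count.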