Problem

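Given a linked list in which each node has a next pointer and a child pointer that may point to a separate, nested list (which can itself contain child lists), flatten the list so that all nodes appear in a single level in depth-first order: each node's child list is placed immediately after that node and before the node's original next node.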

Examples

Example 1:

 Input: head = [1, [2, [5, 6] ], 3, 4]
  1 -> 2 -> 3 -> 4
       |
       5 -> 6
 Output: 1 -> 2 -> 5 -> 6 -> 3 -> 4

Example 2:

 Input: head = [1, [2, [5, [6, [8 ], 7] ], 3, 4]
  1 -> 2 -> 3 -> 4
       |
       5 -> 6 -> 7
            |
            8
 Output: 1 -> 2 -> 5 -> 6 -> 8 -> 7 -> 3 -> 4

Solution

Method 1 - Using a Stack

Here are the steps:

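  1. Initialize a Stack:
    • Create a dummy (sentinel) node to serve as the tail of the result, and push head onto an explicit stack that drives the depth-first traversal.
  2. Traverse the List:
    • Pop a node and append it to the result by pointing the current tail's next pointer at it.
    • Push the node's next node (if any) first, then its child (if any), so the child sub-list is popped and processed before the rest of the current level.
  3. Flatten the List:
    • Clear each node's child pointer once the child has been pushed, and repeat until the stack is empty; dummy.next is then the head of the single-level list.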

Code

Java
/** A node of a multi-level singly linked list. */
class MultiLevelListNode {
    // Value stored in this node.
    int val;
    // Following node on the same level, or null at the end of the level.
    MultiLevelListNode next;
    // Head of a nested sub-list attached to this node, or null if there is none.
    MultiLevelListNode child;

    /** Creates a node holding {@code value} with no next or child link. */
    MultiLevelListNode(int value) {
        this.val = value;
    }
}

public class Solution {
    /**
     * Flattens a multi-level list in place into depth-first order.
     *
     * <p>Each node's child list is spliced in between the node and its original
     * {@code next} node, and every {@code child} pointer is cleared.
     *
     * @param head first node of the top level; may be {@code null}
     * @return head of the flattened list, or {@code null} if {@code head} is {@code null}
     */
    public MultiLevelListNode flatten(MultiLevelListNode head) {
        if (head == null)
            return null;

        // Sentinel so the first node is linked with the same code as the rest.
        MultiLevelListNode dummy = new MultiLevelListNode(0);
        MultiLevelListNode prev = dummy;
        // Fully qualified because this snippet has no import block. ArrayDeque is the
        // recommended LIFO stack; java.util.Stack is a legacy synchronized Vector.
        java.util.Deque<MultiLevelListNode> stack = new java.util.ArrayDeque<>();
        stack.push(head);

        while (!stack.isEmpty()) {
            MultiLevelListNode curr = stack.pop();
            prev.next = curr;
            prev = curr;

            // Push next before child so the child sub-list is popped (visited) first.
            // The last node popped has no next (otherwise the stack would not be
            // empty), so the flattened list ends with null.
            if (curr.next != null) {
                stack.push(curr.next);
            }

            if (curr.child != null) {
                stack.push(curr.child);
                curr.child = null;
            }
        }

        return dummy.next;
    }

    /**
     * Builds Example 1: 1 -> 2 -> 3 -> 4 with the sub-list 5 -> 6 as node 2's child.
     *
     * @return the head node (value 1)
     */
    public static MultiLevelListNode createExampleList() {
        MultiLevelListNode node1 = new MultiLevelListNode(1);
        MultiLevelListNode node2 = new MultiLevelListNode(2);
        MultiLevelListNode node3 = new MultiLevelListNode(3);
        MultiLevelListNode node4 = new MultiLevelListNode(4);
        MultiLevelListNode node5 = new MultiLevelListNode(5);
        MultiLevelListNode node6 = new MultiLevelListNode(6);

        node1.next = node2;
        node2.next = node3;
        node3.next = node4;
        node2.child = node5;
        node5.next = node6;

        return node1;
    }

    /**
     * Prints the nodes reachable from {@code head} through {@code next} pointers
     * (child pointers are not followed), e.g. {@code 1 -> 2 -> None}.
     *
     * @param head first node to print; may be {@code null}
     */
    public void printLinkedList(MultiLevelListNode head) {
        MultiLevelListNode current = head;
        while (current != null) {
            System.out.print(current.val + " -> ");
            current = current.next;
        }
        System.out.println("None");
    }

    /** Demo: prints the top level of the example list, flattens it, and prints the result. */
    public static void main(String[] args) {
        Solution solution = new Solution();

        MultiLevelListNode head = Solution.createExampleList();

        System.out.println("Original List:");
        solution.printLinkedList(head);

        MultiLevelListNode flattenedHead = solution.flatten(head);

        System.out.println("Flattened List:");
        solution.printLinkedList(flattenedHead);
    }
}
Python
class MultiLevelListNode:
    """Node of a multi-level singly linked list.

    ``next`` points to the following node on the same level and ``child``
    points to the head of a nested sub-list; either may be ``None``.
    """

    def __init__(self, val=0, next=None, child=None):
        # Links first, payload last; the assignments are independent.
        self.child = child
        self.next = next
        self.val = val


class Solution:
    def flatten(self, head: MultiLevelListNode) -> MultiLevelListNode:
        """Flatten a multi-level list in place into depth-first order.

        Each node's child list is spliced in between the node and its original
        ``next`` node, and every ``child`` pointer is cleared. If ``head`` is
        ``None``, ``None`` is returned; otherwise ``head`` itself is returned as
        the start of the flattened list.
        """
        if not head:
            return None

        sentinel = MultiLevelListNode(0)
        tail = sentinel
        # Explicit LIFO of nodes still to visit, so no recursion is needed.
        pending = [head]

        while pending:
            node = pending.pop()
            tail.next = node
            tail = node

            successor, nested = node.next, node.child
            # Push the successor first so the nested list is popped (visited) before it.
            if successor:
                pending.append(successor)
            if nested:
                pending.append(nested)
                node.child = None

        return sentinel.next


# Helper function to create a multi-level list example for testing
def create_example_list():
    node1 = MultiLevelListNode(1)
    node2 = MultiLevelListNode(2)
    node3 = MultiLevelListNode(3)
    node4 = MultiLevelListNode(4)
    node5 = MultiLevelListNode(5)
    node6 = MultiLevelListNode(6)

    node1.next = node2
    node2.next = node3
    node3.next = node4
    node2.child = node5
    node5.next = node6

    return node1


def print_linked_list(head):
    """Print the nodes reachable from ``head`` through ``next`` pointers.

    The output looks like ``1 -> 2 -> None``; ``child`` pointers are not
    followed. A ``None`` head prints just ``None``.
    """
    pieces = []
    node = head
    while node:
        pieces.append(f"{node.val} -> ")
        node = node.next
    print("".join(pieces) + "None")


# Example usage. It runs only when this file is executed as a script, so
# importing the module (e.g. from a test) does not build and print the example.
if __name__ == "__main__":
    head = create_example_list()
    solution = Solution()

    print("Original List:")
    print_linked_list(head)

    # flatten() relinks the existing nodes in place and clears their child pointers.
    flattened_head = solution.flatten(head)

    print("Flattened List:")
    print_linked_list(flattened_head)

Complexity

  • ⏰ Time complexity: O(n), where n is the total number of nodes in the multi-level list. Each node is visited exactly once.
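  • 🧺 Space complexity: O(n) in the worst case, for the stack used during the traversal. More precisely, the stack holds at most one pending node per nesting level plus the node being descended into, so it is O(d) for maximum nesting depth d, which can be as large as n.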