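Problem
Given a linked list in which each node has a `next` pointer and a `child` pointer that may point to a separate, nested list, flatten the list so that all nodes appear in a single level, in depth-first order: each node's child list (itself flattened) comes immediately after that node and before the node's original `next` successor.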
Examples
Example 1:
1
2
3
4
5
Input: head = [1, [2, [5, 6]], 3, 4]   (a pair [x, [...]] denotes node x together with its child list)
1 -> 2 -> 3 -> 4
     |
     5 -> 6
Output: 1 -> 2 -> 5 -> 6 -> 3 -> 4
Example 2:
1
2
3
4
5
6
7
Input: head = [1, [2, [5, [6, [8]], 7]], 3, 4]
1 -> 2 -> 3 -> 4
     |
     5 -> 6 -> 7
          |
          8
Output: 1 -> 2 -> 5 -> 6 -> 8 -> 7 -> 3 -> 4
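Solution
Method 1 - Using a Stack
Here are the steps:
Initialize a stack:
Push the head onto a stack. The stack holds the nodes that still have to be appended to the flattened list; the node to visit next is on top.
Traverse the list:
Pop a node and append it to the flattened list by pointing the previously appended node's next pointer at it.
If the popped node has a next node, push it; then, if it has a child, push the child and clear the node's child pointer. Because the child is pushed last, it is popped first, so the whole child list is emitted before the node's original next node.
Finish:
Repeat until the stack is empty. The last node popped has neither a next node nor a child, so the flattened, single-level list ends with a null next pointer.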
Code#
Java
Python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
/**
 * Node of a multi-level linked list: an integer value, a successor on the
 * same level, and an optional pointer to the head of a nested child list.
 */
class MultiLevelListNode {
    int val;                   // value stored in the node
    MultiLevelListNode next;   // next node on the same level, or null
    MultiLevelListNode child;  // head of the nested child list, or null

    MultiLevelListNode(int value) {
        this.val = value;
    }
}
public class Solution {
    /**
     * Flattens a multi-level linked list in place, in depth-first order.
     *
     * <p>Each node's child list is spliced in directly after the node and before
     * the node's original {@code next} successor, and every {@code child} pointer
     * is cleared. Existing nodes are relinked; none are copied.
     *
     * @param head first node of the top-level list, may be {@code null}
     * @return head of the flattened list, or {@code null} if {@code head} is {@code null}
     */
    public MultiLevelListNode flatten(MultiLevelListNode head) {
        if (head == null) {
            return null;
        }
        MultiLevelListNode dummy = new MultiLevelListNode(0);
        MultiLevelListNode prev = dummy;
        // ArrayDeque is the recommended LIFO stack (push/pop work on the front of
        // the deque). Fully qualified so this class needs no import statement.
        java.util.Deque<MultiLevelListNode> stack = new java.util.ArrayDeque<>();
        stack.push(head);
        while (!stack.isEmpty()) {
            MultiLevelListNode curr = stack.pop();
            prev.next = curr;
            prev = curr;
            // Push next before child so the child list is popped (emitted) first.
            // curr.next is read here, before a later iteration overwrites it.
            if (curr.next != null) {
                stack.push(curr.next);
            }
            if (curr.child != null) {
                stack.push(curr.child);
                curr.child = null;
            }
        }
        return dummy.next;
    }

    /**
     * Builds the multi-level list from Example 1 for testing:
     * 1 -> 2 -> 3 -> 4, where node 2 has the child list 5 -> 6.
     *
     * @return the head (node 1) of the example list
     */
    public static MultiLevelListNode createExampleList() {
        MultiLevelListNode node1 = new MultiLevelListNode(1);
        MultiLevelListNode node2 = new MultiLevelListNode(2);
        MultiLevelListNode node3 = new MultiLevelListNode(3);
        MultiLevelListNode node4 = new MultiLevelListNode(4);
        MultiLevelListNode node5 = new MultiLevelListNode(5);
        MultiLevelListNode node6 = new MultiLevelListNode(6);
        // Top level: 1 -> 2 -> 3 -> 4
        node1.next = node2;
        node2.next = node3;
        node3.next = node4;
        // Node 2's child list: 5 -> 6
        node2.child = node5;
        node5.next = node6;
        return node1;
    }

    /**
     * Prints the values reachable from {@code head} by following {@code next}
     * pointers only, e.g. {@code 1 -> 2 -> None}. Child lists are not printed.
     *
     * @param head first node to print, may be {@code null}
     */
    public void printLinkedList(MultiLevelListNode head) {
        StringBuilder line = new StringBuilder();
        for (MultiLevelListNode current = head; current != null; current = current.next) {
            line.append(current.val).append(" -> ");
        }
        line.append("None");
        System.out.println(line.toString());
    }

    public static void main(String[] args) {
        Solution solution = new Solution();
        MultiLevelListNode head = Solution.createExampleList();
        System.out.println("Original List:");
        solution.printLinkedList(head);
        MultiLevelListNode flattenedHead = solution.flatten(head);
        System.out.println("Flattened List:");
        solution.printLinkedList(flattenedHead);
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class MultiLevelListNode:
    """Node of a multi-level linked list.

    Besides the usual ``next`` pointer, a node may carry a ``child`` pointer
    to the head of a separate, nested list.
    """

    def __init__(self, val=0, next=None, child=None):
        self.val = val  # value stored in the node
        self.next = next  # successor on the same level, or None
        self.child = child  # head of the nested child list, or None

    def __repr__(self):
        # Show only the value: following next/child could recurse through
        # the whole structure.
        return f"{type(self).__name__}({self.val!r})"
class Solution:
    def flatten(self, head: "MultiLevelListNode") -> "MultiLevelListNode":
        """Flatten a multi-level linked list in place, in depth-first order.

        Each node's child list is spliced in directly after the node and
        before the node's original ``next`` successor; every ``child``
        pointer is cleared. No nodes are allocated: the existing nodes are
        relinked.

        Args:
            head: First node of the top-level list, or ``None``.

        Returns:
            The head of the flattened list (always ``head`` itself), or
            ``None`` when ``head`` is ``None``.
        """
        if head is None:
            return None
        stack = [head]  # pending nodes; the next one to visit is on top
        prev = None  # last node appended to the flattened list
        while stack:
            curr = stack.pop()
            if prev is not None:
                prev.next = curr
            prev = curr
            # Push next before child so the child list is popped (and thus
            # emitted) before the rest of the current level. curr.next is
            # read here, before a later iteration overwrites it.
            if curr.next is not None:
                stack.append(curr.next)
            if curr.child is not None:
                stack.append(curr.child)
                curr.child = None
        # The first node popped is always head, so it heads the result.
        return head
# Helper function to create a multi-level list example for testing
def create_example_list():
    """Build the multi-level list from Example 1 and return its head.

    Shape::

        1 -> 2 -> 3 -> 4
             |
             5 -> 6
    """
    node1 = MultiLevelListNode(1)
    node2 = MultiLevelListNode(2)
    node3 = MultiLevelListNode(3)
    node4 = MultiLevelListNode(4)
    node5 = MultiLevelListNode(5)
    node6 = MultiLevelListNode(6)
    # Top level: 1 -> 2 -> 3 -> 4
    node1.next = node2
    node2.next = node3
    node3.next = node4
    # Node 2's child list: 5 -> 6
    node2.child = node5
    node5.next = node6
    return node1
# Helper function to print a linked list by following its next pointers
def print_linked_list(head):
    """Print the values reachable from ``head`` via ``next`` pointers.

    The output looks like ``1 -> 2 -> None``; an empty list prints ``None``.
    ``child`` pointers are not followed, so nested lists are shown only
    after flattening.
    """
    parts = []
    current = head
    while current is not None:
        parts.append(f"{current.val} -> ")
        current = current.next
    parts.append("None")
    # Build the line once instead of issuing one print call per node.
    print("".join(parts))
# Example usage: runs only when this file is executed as a script, so
# importing it (e.g. to reuse Solution) has no printing side effects.
if __name__ == "__main__":
    head = create_example_list()
    solution = Solution()
    print("Original List:")
    print_linked_list(head)
    flattened_head = solution.flatten(head)
    print("Flattened List:")
    print_linked_list(flattened_head)
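Complexity
⏰ Time complexity: O(n), where n is the total number of nodes in the multi-level list. Each node is pushed onto and popped from the stack exactly once.
🧺 Space complexity: O(n) in the worst case, for the stack used during the traversal. More precisely, the stack holds at most one pending node per nesting level, so its size is O(d), where d ≤ n is the maximum nesting depth.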