Problem

Design a max stack that supports push, pop, top, peekMax and popMax.

  1. push(x) – Push element x onto stack.
  2. pop() – Remove the element on top of the stack and return it.
  3. top() – Get the element on the top.
  4. peekMax() – Retrieve the maximum element in the stack.
  5. popMax() – Retrieve the maximum element in the stack, and remove it. If there is more than one maximum element, only remove the top-most one.

Examples

Example 1:

MaxStack stack = new MaxStack();
stack.push(5);
stack.push(1);
stack.push(5);
stack.top(); -> 5
stack.popMax(); -> 5
stack.top(); -> 1
stack.peekMax(); -> 5
stack.pop(); -> 1
stack.top(); -> 5

Solution

Method 1 - Using 2 Stacks

This is not all that difficult if you think about it. Regular stacks support push() and pop(). We need to add a read-only view of the maximum. You can’t just add a single variable to track the maximum, because when the current maximum is popped you wouldn’t know the new maximum without scanning every item in the stack. That would require popping and re-pushing each item, which is downright ugly. So we need to keep some auxiliary record of the stack items so we can update the maximum after we pop.

A clever way to do that is to use another stack that stores, for each level of the original stack, the maximum of all items at or below that level. The maximum is calculated when items are pushed, and when items are popped from the stack we also pop from the second stack to reveal the new maximum value.

Let’s call this stack max_stack. This stack will always have the maximum value at the top. Modify ‘push’ and ‘pop’ operations as follows:

push:

  • If max_stack is empty, or the element being pushed is greater than or equal to the top element of max_stack, then push it on ‘max_stack’ as well.
  • Else, push the top element of ‘max_stack’ again on max_stack.

pop:

  • Every time you pop an element from the original stack, pop from max_stack as well.

Example Run

Suppose, elements are pushed in the following order: 7 3 5 8 9 1 2

original_stack        max_stack
2                        9
1                        9
9                        9
8                        8
5                        7
3                        7
7                        7

You can see that at any stage, the max_stack can be queried for the maximum element in the stack.
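The bookkeeping can be checked with a small trace. This is only a sketch: the class name MaxTrackingTrace and the helper runningMax are made up for illustration, and it uses ArrayDeque rather than Stack. Each push stores the larger of the new element and the current top of max_stack, and the helper records that top after every push.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical helper class, not part of the solution below.
class MaxTrackingTrace {
    // Returns the value on top of max_stack after each push, in push order.
    static List<Integer> runningMax(int[] pushes) {
        Deque<Integer> maxStack = new ArrayDeque<>();
        List<Integer> trace = new ArrayList<>();
        for (int x : pushes) {
            // New top of max_stack: the larger of x and the previous maximum.
            int newMax = maxStack.isEmpty() ? x : Math.max(maxStack.peek(), x);
            maxStack.push(newMax);
            trace.add(newMax);
        }
        return trace;
    }

    public static void main(String[] args) {
        System.out.println(runningMax(new int[] {7, 3, 5, 8, 9, 1, 2}));
        // prints [7, 7, 7, 8, 9, 9, 9]
    }
}
```

Read left to right, the printed list is max_stack from bottom to top, so the last entry is the maximum of the whole stack.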

Code

Java
import java.util.Stack;

class MaxStack {
    Stack<Integer> stack;
    Stack<Integer> max;

    /** initialize your data structure here. */
    public MaxStack() {
        stack = new Stack<>();
        max = new Stack<>();
    }
    
    public void push(int x) {
        stack.push(x);
        int maxVal = max.isEmpty() ? x : Math.max(max.peek(), x);
        max.push(maxVal);
    }
    
    public int pop() {
        max.pop();
        return stack.pop();
    }
    
    public int top() {
        return stack.peek();
    }
    
    public int peekMax() {
        return max.peek();
    }
    
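    // O(n) in the worst case: elements above the max are moved aside and pushed back.
    public int popMax() {
        int maxVal = peekMax();
        Stack<Integer> temp = new Stack<>();
        // Set aside everything above the top-most occurrence of the maximum.
        while (top() != maxVal) temp.push(pop());
        pop(); // remove the maximum itself

        // Re-push the set-aside elements; push() rebuilds the max stack entries.
        while (!temp.isEmpty()) push(temp.pop());

        return maxVal;
    }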
}

Method 2 - Using TreeMap and Doubly Linked List

We can combine two structures:

  • A TreeMap that stores <Integer, List of Nodes>, which gives O(log N) insert, delete and find MAX
  • A DoubleLinkedList class that gives O(1) removal of a node once we hold a reference to it
  • popMax then becomes finding the top-most node that holds the maximum value and unlinking it from the DoubleLinkedList
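To see why the TreeMap finds the top-most maximum, here is a minimal sketch that stores push positions instead of list nodes. The names TreeMapMaxDemo, index and removeTopMostMax are made up for illustration and are not part of the solution below. Positions are appended in push order, so the last entry under the largest key is the occurrence closest to the top.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

// Hypothetical demo class: push positions stand in for linked-list nodes.
class TreeMapMaxDemo {
    // Maps each value to the push positions where it occurs, oldest first.
    static TreeMap<Integer, List<Integer>> index(int[] pushes) {
        TreeMap<Integer, List<Integer>> map = new TreeMap<>();
        for (int pos = 0; pos < pushes.length; pos++) {
            map.computeIfAbsent(pushes[pos], k -> new ArrayList<>()).add(pos);
        }
        return map;
    }

    // Removes and returns the position of the top-most occurrence of the maximum.
    static int removeTopMostMax(TreeMap<Integer, List<Integer>> map) {
        int max = map.lastKey();                          // largest key
        List<Integer> positions = map.get(max);
        int pos = positions.remove(positions.size() - 1); // most recent push of max
        if (positions.isEmpty()) map.remove(max);         // drop exhausted keys
        return pos;
    }

    public static void main(String[] args) {
        TreeMap<Integer, List<Integer>> map = index(new int[] {5, 1, 5});
        System.out.println(removeTopMostMax(map)); // prints 2
        System.out.println(map.lastKey());         // prints 5
    }
}
```

With pushes 5, 1, 5 the first removal returns position 2 (the 5 on top), and the maximum is still 5 because the earlier 5 remains in the map.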

Code

Java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

class MaxStack {
    TreeMap<Integer, List<Node>> map;
    DoubleLinkedList dll;

    public MaxStack() {
        map = new TreeMap<>();
        dll = new DoubleLinkedList();
    }

    // O(log N): the list append is O(1), but the TreeMap lookup/insert is O(log N)
    public void push(int x) {
        Node node = dll.add(x);
        map.putIfAbsent(x, new ArrayList<Node>());
        map.get(x).add(node);
    }

    // O(log N): unlinking the tail is O(1), but removeFromMap is O(log N)
    public int pop() {
        int val = dll.pop();
        removeFromMap(val);
        return val;
    }

    // O(1)
    public int top() {
        return dll.peek();
    }

    // O(logN)
    public int peekMax() {
        return map.lastKey();
    }

    // O(log N): two TreeMap operations plus an O(1) unlink
    public int popMax() {
        int max = peekMax();
        Node node = removeFromMap(max);
        dll.unlink(node);
        return max;
    }
    
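    // Find val in the map, remove its most recently added node from the list,
    // and remove the list itself if it becomes empty.
    // O(log N) because of the TreeMap lookup and removal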
    private Node removeFromMap(int val) {
        List<Node> list = map.get(val);
        Node node = list.remove(list.size() - 1);
        if (list.isEmpty()) map.remove(val);
        return node;
    }
    
    // Define DoubleLinkedList class
    class DoubleLinkedList {
        Node head, tail;

        public DoubleLinkedList() {
            head = new Node(0);
            tail = new Node(0);
            head.next = tail;
            tail.prev = head;
        }

        public Node add(int val) {
            Node x = new Node(val);
            x.next = tail;
            x.prev = tail.prev;
            tail.prev = tail.prev.next = x; // append to tail
            return x;
        }

        public int pop() {
            return unlink(tail.prev).val;
        }

        public int peek() {
            return tail.prev.val;
        }

        public Node unlink(Node node) {
            node.prev.next = node.next;
            node.next.prev = node.prev;
            return node;
        }
    }

    class Node {
        int val;
        Node prev, next;
        public Node(int v) {val = v;}
    }
}