Problem

You are given an integer array nums. Your task is to remove all elements from the array by performing one of the following operations at each step until nums is empty:

  • Choose any two elements from the first three elements of nums and remove them. The cost of this operation is the maximum of the two elements removed.
  • If fewer than three elements remain in nums, remove all the remaining elements in a single operation. The cost of this operation is the maximum of the remaining elements.

Return the minimum cost required to remove all the elements.

Example 1

1
2
3
4
5
6
7
8
Input: nums = [6,2,8,4]
Output: 12
Explanation:
Initially, `nums = [6, 2, 8, 4]`.
* In the first operation, remove `nums[0] = 6` and `nums[2] = 8` with a cost of `max(6, 8) = 8`. Now, `nums = [2, 4]`.
* In the second operation, remove the remaining elements with a cost of `max(2, 4) = 4`.
The cost to remove all elements is `8 + 4 = 12`. This is the minimum cost to
remove all elements in `nums`. Hence, the output is 12.

Example 2

1
2
3
4
5
6
7
8
Input: nums = [2,1,3,3]
Output: 5
Explanation:
Initially, `nums = [2, 1, 3, 3]`.
* In the first operation, remove `nums[0] = 2` and `nums[1] = 1` with a cost of `max(2, 1) = 2`. Now, `nums = [3, 3]`.
* In the second operation remove the remaining elements with a cost of `max(3, 3) = 3`.
The cost to remove all elements is `2 + 3 = 5`. This is the minimum cost to
remove all elements in `nums`. Hence, the output is 5.

Constraints

  • 1 <= nums.length <= 1000
  • 1 <= nums[i] <= 10^6

Examples

Solution

Method 1 – Dynamic Programming on (Leftover, Position) States

Intuition

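At each step we may remove any two of the first three elements, paying the larger of the two; once fewer than three elements remain, we remove them all and pay their maximum. The key observation is that the remaining array always has the form `[nums[last]] + nums[i..n-1]`: exactly one element survives from the part already processed (the one we did not pick from the first three), followed by an untouched suffix. Initially `last = 0` and `i = 1`, and every operation advances `i` by two. So although `nums` can have up to 1000 elements (not 16), there are only O(n²) distinct states `(last, i)`, each with three possible moves. That makes a DP over these states fast enough.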

Approach

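  1. Let `f(last, i)` be the minimum cost when the remaining array is `nums[last]` followed by `nums[i..n-1]` (with `last < i`). The answer is `f(0, 1)`.
  2. If three or more elements remain (`i + 1 < n`), try the three pairs among `nums[last]`, `nums[i]` and `nums[i+1]`, then take the minimum:
    • remove `nums[i]` and `nums[i+1]`: `max(nums[i], nums[i+1]) + f(last, i + 2)`;
    • remove `nums[last]` and `nums[i]`: `max(nums[last], nums[i]) + f(i + 1, i + 2)`;
    • remove `nums[last]` and `nums[i+1]`: `max(nums[last], nums[i+1]) + f(i, i + 2)`.
  3. If only one or two elements remain (`i >= n - 1`), remove them all and pay their maximum.
  4. Cache the result for each state. Keying the cache on the two indices, or filling a table bottom-up, costs O(1) per state. Because position `i` depends only on position `i + 2`, a single array indexed by `last` suffices when filling bottom-up. Keying the cache on a copy of the whole remaining array also works, but adds O(n) work and memory per state.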

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution {
public:
    // Returns the minimum total cost to remove every element of nums.
    //
    // Each operation removes two of the first three elements and costs the
    // larger of the two; once fewer than three elements remain they are all
    // removed at the cost of their maximum.
    //
    // After every operation the remaining array is nums[last] followed by the
    // untouched suffix nums[i..]: i starts at 1 and advances by two, and
    // `last` is whichever of the first three elements was kept. dp[last]
    // holds the optimal cost of state (last, i). The table is filled from the
    // back of nums toward the front, in place: O(n^2) time, O(n) extra space.
    // Returns 0 for an empty array.
    int minCost(vector<int>& nums) {
        const int n = static_cast<int>(nums.size());
        if (n == 0) return 0;
        // Suffix start of the terminal state: one element is left when n is
        // odd, two when n is even.
        const int start = (n % 2 == 1) ? n : n - 1;
        vector<int> dp(n, 0);
        for (int last = 0; last < start; ++last)
            dp[last] = (start == n) ? nums[last] : max(nums[last], nums[start]);
        for (int i = start - 2; i >= 1; i -= 2) {
            const int a = nums[i];
            const int b = nums[i + 1];
            const int removePair = max(a, b);
            // States at position i + 2 reached by removing the leftover
            // together with a (leftover becomes b) or with b (leftover
            // becomes a). The loop below only writes indices < i, so these
            // are still the position-(i + 2) values.
            const int afterA = dp[i + 1];
            const int afterB = dp[i];
            for (int last = 0; last < i; ++last) {
                const int x = nums[last];
                dp[last] = min({removePair + dp[last],
                                max(x, a) + afterA,
                                max(x, b) + afterB});
            }
        }
        return dp[0];
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// minCost returns the minimum total cost to remove every element of nums.
//
// Each operation removes two of the first three elements and costs the larger
// of the two; once fewer than three elements remain they are all removed at
// the cost of their maximum. It returns 0 for an empty slice.
//
// After every operation the remaining array is nums[last] followed by the
// untouched suffix nums[i:], where i starts at 1 and advances by two, so
// dp[last] holds the optimal cost of state (last, i). The table is filled from
// the back of nums toward the front, in place: O(n^2) time, O(n) extra space.
func minCost(nums []int) int {
    n := len(nums)
    if n == 0 {
        return 0
    }
    larger := func(a, b int) int {
        if a > b {
            return a
        }
        return b
    }
    // Suffix start of the terminal state: one element is left when n is odd,
    // two when n is even.
    start := n - 1
    if n%2 == 1 {
        start = n
    }
    dp := make([]int, n)
    for last := 0; last < start; last++ {
        if start == n {
            dp[last] = nums[last] // only the leftover remains
        } else {
            dp[last] = larger(nums[last], nums[start]) // leftover plus the final element
        }
    }
    for i := start - 2; i >= 1; i -= 2 {
        a, b := nums[i], nums[i+1]
        removePair := larger(a, b)
        // States at position i+2 reached by removing the leftover together
        // with a (leftover becomes b) or with b (leftover becomes a). The loop
        // below only writes indices < i, so these are still valid.
        afterA, afterB := dp[i+1], dp[i]
        for last := 0; last < i; last++ {
            x := nums[last]
            best := removePair + dp[last]
            if c := larger(x, a) + afterA; c < best {
                best = c
            }
            if c := larger(x, b) + afterB; c < best {
                best = c
            }
            dp[last] = best
        }
    }
    return dp[0]
}
// min returns the smaller of a and b.
func min(a, b int) int {
    if b < a {
        return b
    }
    return a
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution {
    /**
     * Returns the minimum total cost to remove every element of {@code nums}.
     *
     * <p>Each operation removes two of the first three elements and costs the larger of the two;
     * once fewer than three elements remain they are all removed at the cost of their maximum.
     *
     * <p>After every operation the remaining array is {@code nums[last]} followed by the untouched
     * suffix {@code nums[i..]}: {@code i} starts at 1 and advances by two, and {@code last} is
     * whichever of the first three elements was kept. {@code dp[last]} holds the optimal cost of
     * state {@code (last, i)}. The table is filled from the back of {@code nums} toward the front,
     * in place: O(n^2) time and O(n) extra space.
     *
     * @param nums the elements to remove
     * @return the minimum total cost; 0 for an empty array
     */
    public int minCost(int[] nums) {
        int n = nums.length;
        if (n == 0) return 0;
        // Suffix start of the terminal state: one element is left when n is odd, two when even.
        int start = (n % 2 == 1) ? n : n - 1;
        int[] dp = new int[n];
        for (int last = 0; last < start; last++) {
            dp[last] = (start == n) ? nums[last] : Math.max(nums[last], nums[start]);
        }
        for (int i = start - 2; i >= 1; i -= 2) {
            int a = nums[i];
            int b = nums[i + 1];
            int removePair = Math.max(a, b);
            // States at position i + 2 reached by removing the leftover together with a
            // (leftover becomes b) or with b (leftover becomes a). The loop below only writes
            // indices < i, so these are still the position-(i + 2) values.
            int afterA = dp[i + 1];
            int afterB = dp[i];
            for (int last = 0; last < i; last++) {
                int x = nums[last];
                int best = removePair + dp[last];
                best = Math.min(best, Math.max(x, a) + afterA);
                best = Math.min(best, Math.max(x, b) + afterB);
                dp[last] = best;
            }
        }
        return dp[0];
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class Solution {
    /**
     * Returns the minimum total cost to remove every element of [nums].
     *
     * Each operation removes two of the first three elements and costs the larger of the two;
     * once fewer than three elements remain, they are removed together at the cost of their maximum.
     *
     * After every operation the remaining array is `nums[last]` followed by the untouched suffix
     * `nums[i..]`: `i` starts at 1 and advances by two per operation, and `last` is whichever of
     * the first three elements was kept. `dp[last]` holds the optimal cost of state `(last, i)`.
     * The table is filled from the end of [nums] toward the front, in place, giving O(n²) time and
     * O(n) extra space.
     *
     * @return the minimum total cost; 0 for an empty array.
     */
    fun minCost(nums: IntArray): Int {
        val n = nums.size
        if (n == 0) return 0
        // Suffix start of the terminal state: one element left when n is odd, two when n is even.
        val start = if (n % 2 == 1) n else n - 1
        val dp = IntArray(n) { last ->
            when {
                last >= start -> 0 // never read
                start == n -> nums[last] // only the leftover remains
                else -> maxOf(nums[last], nums[start]) // leftover plus the final element
            }
        }
        for (i in start - 2 downTo 1 step 2) {
            val a = nums[i]
            val b = nums[i + 1]
            val removePair = maxOf(a, b)
            // States at position i + 2 reached by removing the leftover together with a or b.
            // Read them up front; the loop below only writes indices < i anyway.
            val afterA = dp[i + 1]
            val afterB = dp[i]
            for (last in 0 until i) {
                val x = nums[last]
                dp[last] = minOf(
                    removePair + dp[last], // remove a and b; nums[last] stays the leftover
                    maxOf(x, a) + afterA, // remove nums[last] and a; b becomes the leftover
                    maxOf(x, b) + afterB, // remove nums[last] and b; a becomes the leftover
                )
            }
        }
        return dp[0]
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
class Solution:
    def minCost(self, nums: list[int]) -> int:
        """Return the minimum total cost to remove every element of ``nums``.

        Each operation removes two of the first three elements and costs the
        larger of the two; once fewer than three elements remain they are all
        removed at the cost of their maximum.

        After every operation the remaining array is ``nums[last]`` followed
        by the untouched suffix ``nums[i:]``: ``i`` starts at 1 and advances by
        two, and ``last`` is whichever of the first three elements was kept.
        ``dp[last]`` holds the optimal cost of state ``(last, i)``. The table is
        filled from the back of ``nums`` toward the front, in place, giving
        O(n^2) time and O(n) extra space.

        Args:
            nums: the elements to remove.

        Returns:
            The minimum total cost; 0 for an empty list.
        """
        n = len(nums)
        if n == 0:
            return 0
        # Suffix start of the terminal state: one element is left when n is
        # odd, two when n is even.
        start = n if n % 2 == 1 else n - 1
        if start == n:
            dp = list(nums)  # only the leftover remains
        else:
            dp = [max(x, nums[start]) for x in nums]  # leftover plus the final element
        for i in range(start - 2, 0, -2):
            a, b = nums[i], nums[i + 1]
            remove_pair = max(a, b)
            # States at position i + 2 reached by removing the leftover with a
            # (leftover becomes b) or with b (leftover becomes a). The loop
            # below only writes indices < i, so these are still valid.
            after_a = dp[i + 1]
            after_b = dp[i]
            for last in range(i):
                x = nums[last]
                dp[last] = min(
                    remove_pair + dp[last],
                    max(x, a) + after_a,
                    max(x, b) + after_b,
                )
        return dp[0]
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
impl Solution {
    /// Returns the minimum total cost to remove every element of `nums`.
    ///
    /// Each operation removes two of the first three elements and costs the
    /// larger of the two; once fewer than three elements remain they are all
    /// removed at the cost of their maximum. Returns 0 for an empty vector.
    ///
    /// After every operation the remaining array is `nums[last]` followed by
    /// the untouched suffix `nums[i..]`: `i` starts at 1 and advances by two,
    /// and `last` is whichever of the first three elements was kept.
    /// `dp[last]` holds the optimal cost of state `(last, i)`. The table is
    /// filled from the back of `nums` toward the front, in place, giving
    /// O(n^2) time and O(n) extra space.
    pub fn min_cost(nums: Vec<i32>) -> i32 {
        let n = nums.len();
        if n == 0 {
            return 0;
        }
        // Suffix start of the terminal state: one element is left when n is
        // odd, two when n is even.
        let start = if n % 2 == 1 { n } else { n - 1 };
        let mut dp: Vec<i32> = if start == n {
            nums.clone() // only the leftover remains
        } else {
            nums.iter().map(|&x| x.max(nums[start])).collect() // leftover plus the final element
        };
        let mut i = start;
        while i > 1 {
            i -= 2;
            let (a, b) = (nums[i], nums[i + 1]);
            let remove_pair = a.max(b);
            // States at position i + 2 reached by removing the leftover with a
            // (leftover becomes b) or with b (leftover becomes a); copied out
            // before the mutable borrow below, which only touches dp[..i].
            let (after_a, after_b) = (dp[i + 1], dp[i]);
            for (slot, &x) in dp[..i].iter_mut().zip(&nums[..i]) {
                *slot = (remove_pair + *slot)
                    .min(x.max(a) + after_a)
                    .min(x.max(b) + after_b);
            }
        }
        dp[0]
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class Solution {
    /**
     * Returns the minimum total cost to remove every element of `nums`.
     *
     * Each operation removes two of the first three elements and costs the larger of the two;
     * once fewer than three elements remain they are all removed at the cost of their maximum.
     *
     * After every operation the remaining array is `nums[last]` followed by the untouched suffix
     * `nums[i..]`: `i` starts at 1 and advances by two, and `last` is whichever of the first three
     * elements was kept. `dp[last]` holds the optimal cost of state `(last, i)`. The table is
     * filled from the back of `nums` toward the front, in place: O(n^2) time, O(n) extra space.
     *
     * @param nums the elements to remove
     * @returns the minimum total cost; 0 for an empty array
     */
    minCost(nums: number[]): number {
        const n = nums.length;
        if (n === 0) return 0;
        // Suffix start of the terminal state: one element is left when n is odd, two when even.
        const start = n % 2 === 1 ? n : n - 1;
        const dp = nums.map((x) => (start === n ? x : Math.max(x, nums[start])));
        for (let i = start - 2; i >= 1; i -= 2) {
            const a = nums[i];
            const b = nums[i + 1];
            const removePair = Math.max(a, b);
            // States at position i + 2 reached by removing the leftover together with a
            // (leftover becomes b) or with b (leftover becomes a). The loop below only writes
            // indices < i, so these are still the position-(i + 2) values.
            const afterA = dp[i + 1];
            const afterB = dp[i];
            for (let last = 0; last < i; last++) {
                const x = nums[last];
                dp[last] = Math.min(
                    removePair + dp[last],
                    Math.max(x, a) + afterA,
                    Math.max(x, b) + afterB,
                );
            }
        }
        return dp[0];
    }
}

Complexity

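  • ⏰ Time complexity: O(n²), where n is the length of nums. The state `(last, i)` takes O(n²) values, and each is resolved in O(1) from its three transitions. Memoizing on a string or tuple copy of the remaining array instead of the two indices adds an O(n) factor per state, i.e. O(n³). That is too slow for n = 1000.
  • 🧺 Space complexity: O(n) when the table is filled bottom-up, because the values for position `i` depend only on those for position `i + 2`. A top-down memo keyed by the index pair needs O(n²); keys that are copies of the remaining array need O(n³).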