Problem

There is an m x n binary grid matrix with all values initially set to 0. Design an algorithm that randomly picks an index (i, j) where matrix[i][j] == 0 and flips it to 1. All indices (i, j) where matrix[i][j] == 0 should be equally likely to be returned.

Optimize your algorithm to minimize the number of calls made to the built-in random function of your language and optimize the time and space complexity.

Implement the Solution class:

  • Solution(int m, int n) Initializes the object with the size of the binary matrix m and n.
  • int[] flip() Returns a random index [i, j] of the matrix where matrix[i][j] == 0 and flips it to 1.
  • void reset() Resets all the values of the matrix to be 0.

Examples

Example 1:

Input: 
["Solution", "flip", "flip", "flip", "reset", "flip"]
[[3, 1], [], [], [], [], []]
Output: 
[null, [1, 0], [2, 0], [0, 0], null, [2, 0]]

Explanation
Solution solution = new Solution(3, 1);
solution.flip();  // return [1, 0]; [0,0], [1,0], and [2,0] should be equally likely to be returned.
solution.flip();  // return [2, 0]; since [1,0] was returned, [2,0] and [0,0] should be equally likely to be returned.
solution.flip();  // return [0, 0]; based on the previously returned indices, only [0,0] can be returned.
solution.reset(); // All the values are reset to 0 and can be returned.
solution.flip();  // return [2, 0]; [0,0], [1,0], and [2,0] should be equally likely to be returned.

Solution

Method 1 - Treat 2D Array as 1D Array and Use Fisher-Yates

We model the matrix as a 1D array of rows * cols cells in row-major order, so cell (i, j) has index i * cols + j. The first total slots of this array hold the cells that are still 0. For each flip, pick a random slot in [0, total) and reduce total by 1:

int r = rand.nextInt(total--);

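Next, move the element stored in the tail slot (index total, after the decrement) into the picked slot, and record this in a Map (key: picked slot, value: the element that was stored in the tail slot). The picked element is thereby dropped from the live range, so it can never be chosen again: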

map.put(r, map.getOrDefault(total, total));

Before that update, read the element actually stored at the picked slot: slots overwritten by earlier flips are in the map, and every other slot still holds its own index:

int x = map.getOrDefault(r, r);
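For comparison, the same Fisher-Yates step can be sketched over an explicit int[] holding all m * n cell indices. The class name DenseFlip and the constructor that takes a Random are illustrative assumptions. The HashMap in the solution stores only the slots that this array would change, so it uses space proportional to the number of flips instead of m * n, and reset() only has to clear the map rather than rewrite every entry.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical dense variant, for illustration only: keeps every cell index in
// an int[] and swaps the picked slot with the last live slot (Fisher-Yates step).
class DenseFlip {
    private final int cols;
    private final int[] cells;   // slots [0, total) hold the cells that are still 0
    private final Random rand;
    private int total;

    DenseFlip(int m, int n, Random rand) {
        this.cols = n;
        this.cells = new int[m * n];
        this.rand = rand;
        reset();
    }

    int[] flip() {
        int r = rand.nextInt(total--);   // pick a live slot in [0, total), then shrink
        int x = cells[r];                // the cell currently stored at slot r
        cells[r] = cells[total];         // move the last live cell into slot r
        cells[total] = x;                // park the flipped cell outside the live range
        return new int[]{x / cols, x % cols};
    }

    void reset() {
        // O(m * n): this is the cost the HashMap version avoids
        for (int i = 0; i < cells.length; i++) cells[i] = i;
        total = cells.length;
    }

    public static void main(String[] args) {
        DenseFlip d = new DenseFlip(2, 3, new Random(42));
        for (int k = 0; k < 6; k++) {
            System.out.println(Arrays.toString(d.flip()));   // six distinct cells
        }
    }
}
```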

Code

Java
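import java.util.HashMap;
import java.util.Map;
import java.util.Random;

class Solution {
    // slot -> cell index for slots touched by earlier flips; untouched slots hold their own index
    private Map<Integer, Integer> map;
    private int rows, cols, total;
    private Random rand;

    public Solution(int m, int n) {
        map = new HashMap<>();
        rand = new Random();
        rows = m;
        cols = n;
        total = m * n;   // number of cells that are still 0
    }

    public int[] flip() {
        // pick a live slot in [0, total), then shrink the live range by one
        int r = rand.nextInt(total--);
        // the cell stored at slot r (r itself, unless an earlier flip overwrote it)
        int x = map.getOrDefault(r, r);
        // move the cell from the tail slot (index total) into slot r
        map.put(r, map.getOrDefault(total, total));
        // convert the row-major 1D index back to [row, col]
        return new int[]{x / cols, x % cols};
    }

    public void reset() {
        // forget all moves; every cell is 0 again
        map.clear();
        total = rows * cols;
    }
}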
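To exercise the class on Example 1's call sequence, the sketch below copies it as a nested class and adds a seeded constructor, an assumption made only so runs are repeatable; it is not part of the required API. The exact cells depend on the seed. What the algorithm guarantees is that three flips on a 3 x 1 grid return three distinct cells and that flip() works again after reset().

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;

// Replays Example 1 against a copy of the Solution above.
// The (int, int, long) constructor is an assumed addition for reproducible runs.
class Example1 {
    static class Solution {
        private final Map<Integer, Integer> map = new HashMap<>();
        private final int rows, cols;
        private final Random rand;
        private int total;

        Solution(int m, int n, long seed) {
            rows = m;
            cols = n;
            total = m * n;
            rand = new Random(seed);
        }

        int[] flip() {
            int r = rand.nextInt(total--);
            int x = map.getOrDefault(r, r);
            map.put(r, map.getOrDefault(total, total));
            return new int[]{x / cols, x % cols};
        }

        void reset() {
            map.clear();
            total = rows * cols;
        }
    }

    public static void main(String[] args) {
        Solution solution = new Solution(3, 1, 2024L);
        Set<String> seen = new HashSet<>();
        for (int k = 0; k < 3; k++) {
            int[] cell = solution.flip();
            seen.add(Arrays.toString(cell));
            System.out.println(Arrays.toString(cell));
        }
        // every cell of the 3 x 1 grid has now been returned exactly once
        System.out.println("distinct after 3 flips: " + seen.size());   // prints 3
        solution.reset();
        System.out.println("after reset: " + Arrays.toString(solution.flip()));
    }
}
```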

Dry Run and Explanation

Each flip falls into one of three scenarios.

For instance, take rows = 2, cols = 3, so total = 6 and the live slots are 0 to 5:

  1. The randomly generated number is the last live slot:

    • Live slots: 0 1 2 3 4 5
    • Random value: r = 5
    • total is now 5
    • 5 is not in the map, thus x = 5 (cell [1, 2])
    • Update map with map[5] = 5
    • This is the simple case: future values are generated in 0-4, so 5 is never picked again.
    • Chosen values: {5}
  2. The randomly generated number hasn’t appeared before:

    • Live slots: 0 1 2 3 4
    • Random value: r = 2
    • total is now 4
    • 2 is not in the map, thus x = 2 (cell [0, 2])
    • Update map with map[2] = 4, so slot 2 now holds the former tail element 4; the map is map[5] = 5 and map[2] = 4
    • Chosen values: {5, 2}
  3. The randomly generated number has appeared before:

    • Live slots: 0 1 2 3, where slot 2 now holds element 4
    • Random value: r = 2, which repeats
    • total is now 3
    • Since map[2] = 4, x = 4 (cell [1, 1])
    • Update map with map[2] = 3; the map is map[5] = 5 and map[2] = 3
    • Chosen values: {5, 2, 4}
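The dry run can be replayed deterministically with a small tracing variant. flipWith(int r) is a hypothetical helper that performs the same steps as flip() but takes r as an argument instead of calling Random, so the three scenarios (r = 5, 2, 2 on a 2 x 3 grid) can be reproduced exactly.

```java
import java.util.HashMap;
import java.util.Map;

// Tracing variant for the dry run; flipWith is a hypothetical helper, not part of the API.
class FlipTrace {
    private final Map<Integer, Integer> map = new HashMap<>();
    private final int cols;
    private int total;

    FlipTrace(int m, int n) {
        cols = n;
        total = m * n;
    }

    // Same steps as flip(); the caller supplies r, which must lie in [0, total).
    int flipWith(int r) {
        total--;                                      // shrink the live range
        int x = map.getOrDefault(r, r);               // element currently at slot r
        map.put(r, map.getOrDefault(total, total));   // move the tail element into slot r
        return x;
    }

    public static void main(String[] args) {
        FlipTrace t = new FlipTrace(2, 3);
        for (int r : new int[]{5, 2, 2}) {
            int x = t.flipWith(r);
            System.out.println("r=" + r + " x=" + x
                    + " cell=[" + x / t.cols + ", " + x % t.cols + "] map=" + t.map);
        }
        // r=5 x=5 cell=[1, 2] map={5=5}
        // r=2 x=2 cell=[0, 2] map={2=4, 5=5}
        // r=2 x=4 cell=[1, 1] map={2=3, 5=5}
    }
}
```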