Problem
There is an m x n binary grid matrix with all the values set to 0 initially. Design an algorithm to randomly pick an index (i, j) where matrix[i][j] == 0 and flip it to 1. All the indices (i, j) where matrix[i][j] == 0 should be equally likely to be returned.
Optimize your algorithm to minimize the number of calls made to the built-in random function of your language and optimize the time and space complexity.
Implement the Solution class:
- Solution(int m, int n) Initializes the object with the size of the binary matrix m and n.
- int[] flip() Returns a random index [i, j] of the matrix where matrix[i][j] == 0 and flips it to 1.
- void reset() Resets all the values of the matrix to be 0.
Examples
Example 1:
Input:
["Solution", "flip", "flip", "flip", "reset", "flip"]
[[3, 1], [], [], [], [], []]
Output:
[null, [1, 0], [2, 0], [0, 0], null, [2, 0]]
Explanation
Solution solution = new Solution(3, 1);
solution.flip(); // return [1, 0], [0,0], [1,0], and [2,0] should be equally likely to be returned.
solution.flip(); // return [2, 0], Since [1,0] was returned, [2,0] and [0,0] should be equally likely to be returned.
solution.flip(); // return [0, 0], Based on the previously returned indices, only [0,0] can be returned.
solution.reset(); // All the values are reset to 0 and can be returned.
solution.flip(); // return [2, 0], [0,0], [1,0], and [2,0] should be equally likely to be returned.
Solution
Method 1 - Treat 2D Array as 1D Array and Use Fisher-Yates
We model the matrix as a 1D array of rows * cols cells in row-major order, so cell (i, j) has index i * cols + j. The variable total counts the cells that are still 0, and they occupy slots 0 to total - 1. For each flip, pick a random slot in that range, flip the cell it holds, and shrink the range by 1:
int r = rand.nextInt(total--);
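Row-major numbering is what lets a single random integer stand for a cell. A small sketch of the round trip for a 2 x 3 grid; toIndex and toCell are hypothetical helper names, and toCell uses the same x / cols, x % cols arithmetic that flip() uses to build its return value:

```java
// Hypothetical helpers illustrating row-major flattening of an m x n grid.
class IndexMapping {
    // (i, j) -> 1D index in row-major order
    static int toIndex(int i, int j, int cols) {
        return i * cols + j;
    }

    // 1D index -> (i, j); the inverse of toIndex
    static int[] toCell(int idx, int cols) {
        return new int[]{idx / cols, idx % cols};
    }

    public static void main(String[] args) {
        int rows = 2, cols = 3;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                int idx = toIndex(i, j, cols);
                int[] cell = toCell(idx, cols);
                System.out.println("(" + i + ", " + j + ") -> " + idx
                        + " -> [" + cell[0] + ", " + cell[1] + "]");
            }
        }
    }
}
```

main prints the six cells with indices 0 through 5 in order, and each one maps back to the cell it came from.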
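Next, move the element in the tail slot (index total, after the decrement) into the chosen slot. This is a swap whose other half can be skipped, because the tail slot has just left the range. Instead of storing the whole array, a Map records only the slots that have been overwritten (key: slot index, value: the cell index now stored there):
map.put(r, map.getOrDefault(total, total));
In subsequent flips, if the same slot is picked again, the map gives the cell actually stored there; a slot that was never overwritten still holds its own index:
int x = map.getOrDefault(r, r);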
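For comparison, here is a sketch of the partial Fisher-Yates shuffle that the Map simulates, written with an explicit array. It is not the method used in this solution: it needs O(rows * cols) memory up front, whereas the Map gains at most one entry per flip. The class name ArrayFlipper and its constructor taking a Random are hypothetical, chosen so the sketch can be seeded.

```java
import java.util.Random;

// Hypothetical explicit-array version of the same idea, for comparison only.
class ArrayFlipper {
    private final int[] cells; // cells[slot] = 1D cell index stored in that slot
    private final int cols;
    private int total;         // slots [0, total - 1] hold cells that are still 0
    private final Random rand;

    ArrayFlipper(int m, int n, Random rand) {
        cells = new int[m * n];
        for (int k = 0; k < cells.length; k++) {
            cells[k] = k;      // initially slot k holds cell k
        }
        cols = n;
        total = m * n;
        this.rand = rand;
    }

    int[] flip() {
        int r = rand.nextInt(total--); // uniform slot among the live ones
        int x = cells[r];              // the cell being flipped
        cells[r] = cells[total];       // full swap with the last live slot
        cells[total] = x;
        return new int[]{x / cols, x % cols};
    }

    void reset() {
        // cells is still a permutation of all indices, so reopening the range is enough
        total = cells.length;
    }
}
```

Because this version performs the full swap, the array always stays a permutation of every cell, which is why reset() can simply restore total. The Map version only does the one-way move, so it has to clear the map instead.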
Code
Java
```java
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

class Solution {
    // slot index -> cell index, only for slots that have been overwritten
    private Map<Integer, Integer> map;
    private int rows, cols, total;
    private Random rand;

    public Solution(int m, int n) {
        map = new HashMap<>();
        rand = new Random();
        rows = m;
        cols = n;
        total = m * n;
    }

    public int[] flip() {
        // pick a slot in [0, total - 1], then shrink the live range by 1
        int r = rand.nextInt(total--);
        // cell stored in slot r (a slot never overwritten holds its own index)
        int x = map.getOrDefault(r, r);
        // move the cell from the last live slot (now index total) into slot r
        map.put(r, map.getOrDefault(total, total));
        // convert the 1D index back to [row, col]
        return new int[]{x / cols, x % cols};
    }

    public void reset() {
        map.clear();
        total = rows * cols;
    }
}
```
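The requirement that every zero cell be equally likely can be spot-checked empirically. This harness is an assumption-laden sketch, not part of the original solution: SeededSolution copies the map-based flip logic but its constructor also takes a Random (a hypothetical signature, added so runs are reproducible), and UniformityCheck plays out all three flips of a 3 x 1 grid, as in Example 1, many times. Each of the 3! = 6 flip orders should appear with frequency close to 1/6.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

// Copy of the map-based flip logic with an injectable Random (hypothetical constructor).
class SeededSolution {
    private final Map<Integer, Integer> map = new HashMap<>();
    private final int rows, cols;
    private int total;
    private final Random rand;

    SeededSolution(int m, int n, Random rand) {
        rows = m;
        cols = n;
        total = m * n;
        this.rand = rand;
    }

    int[] flip() {
        int r = rand.nextInt(total--);
        int x = map.getOrDefault(r, r);
        map.put(r, map.getOrDefault(total, total));
        return new int[]{x / cols, x % cols};
    }

    void reset() {
        map.clear();
        total = rows * cols;
    }
}

class UniformityCheck {
    // Flips every cell of a 3 x 1 grid, records the order, then resets; repeats `trials` times.
    // counts[a * 9 + b * 3 + c] = number of runs whose flips returned rows a, b, c in that order.
    static int[] orderCounts(int trials, long seed) {
        SeededSolution s = new SeededSolution(3, 1, new Random(seed));
        int[] counts = new int[27];
        for (int t = 0; t < trials; t++) {
            int a = s.flip()[0];
            int b = s.flip()[0];
            int c = s.flip()[0];
            counts[a * 9 + b * 3 + c]++;
            s.reset();
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] counts = orderCounts(60000, 7L);
        for (int k = 0; k < counts.length; k++) {
            if (counts[k] > 0) {
                // print each observed order "a b c" with its count
                System.out.println((k / 9) + " " + (k / 3 % 3) + " " + (k % 3) + ": " + counts[k]);
            }
        }
    }
}
```

With 60,000 trials each order is expected about 10,000 times, with a standard deviation of roughly 90, so a count far from 10,000 would point to a bias.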
Dry Run and Explanation
Each draw falls into one of only 3 scenarios. For instance, take rows = 2 and cols = 3, so total = 6 and the cells are numbered 0 to 5:
The randomly generated number is the last one:
- Remaining numbers: 0 1 2 3 4 5
- Random value: r = 5, and total is now 5
- 5 is not in the map, thus x = 5
- Update the map with map[5] = 5 (slot 5 is now outside the range, so this entry is never read)
- This case is straightforward: future numbers are generated between 0 and 4, so 5 cannot be repeated.
- Chosen values: {5}
The randomly generated number hasn't appeared before:
- Remaining numbers: 0 1 2 3 4
- Random value: r = 2, and total is now 4
- 2 is not in the map, thus x = 2
- Update the map with map[2] = 4; it now holds map[5] = 5 and map[2] = 4
- Chosen values: {5, 2}
The randomly generated number has appeared before:
- Remaining numbers (by slot): 0 1 4 3, since slot 2 now holds cell 4
- Random value: r = 2 again, and total is now 3
- Since map[2] = 4, x = 4
- Update the map with map[2] = 3; it now holds map[5] = 5 and map[2] = 3
- Chosen values: {5, 2, 4}
- Remaining numbers: 0 1 3
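A dry run like this one can be reproduced without randomness by scripting the draws. A minimal sketch, assuming a hypothetical FlipTrace.replay helper that feeds a fixed sequence of slot numbers through the same map logic and prints the state after each step; with rows = 2, cols = 3 and draws 5, 2, 2 it returns the cells 5, 2, 4.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical tracer: the map-based flip with scripted draws instead of Random.
class FlipTrace {
    // draws[k] must lie in [0, remaining - 1] at step k, as rand.nextInt would guarantee.
    static List<Integer> replay(int rows, int cols, int[] draws) {
        Map<Integer, Integer> map = new HashMap<>();
        int total = rows * cols;
        List<Integer> chosen = new ArrayList<>();
        for (int r : draws) {
            if (r < 0 || r >= total) {
                throw new IllegalArgumentException("draw " + r + " outside [0, " + (total - 1) + "]");
            }
            total--;                                    // same effect as rand.nextInt(total--)
            int x = map.getOrDefault(r, r);             // cell stored in slot r
            map.put(r, map.getOrDefault(total, total)); // move the tail cell into slot r
            chosen.add(x);
            System.out.println("r = " + r + ", total = " + total + ", x = " + x
                    + " -> [" + (x / cols) + ", " + (x % cols) + "], map = " + map);
        }
        return chosen;
    }

    public static void main(String[] args) {
        replay(2, 3, new int[]{5, 2, 2});
    }
}
```

The third draw repeats slot 2, yet the returned cell is 4 because the map redirects the lookup, which is exactly the "has appeared before" case.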