You are given a 0-indexed array of positive integers w where w[i] describes the weight of the i-th index.
You need to implement the function pickIndex(), which randomly picks an index in the range [0, w.length - 1] (inclusive) and returns it. The probability of picking an index i is w[i] / sum(w).
For example, if w = [1, 3], the probability of picking index 0 is 1 / (1 + 3) = 0.25 (i.e., 25%), and the probability of picking index 1 is 3 / (1 + 3) = 0.75 (i.e., 75%).
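As a quick check of the formula, here is a small Python sketch that computes w[i] / sum(w) exactly with the standard `fractions.Fraction` type. The helper name `pick_probabilities` is hypothetical and not part of the problem's API.

```python
from fractions import Fraction


def pick_probabilities(w: list[int]) -> list[Fraction]:
    """Return the exact probability w[i] / sum(w) for every index i."""
    total = sum(w)
    return [Fraction(weight, total) for weight in w]


print(pick_probabilities([1, 3]))  # [Fraction(1, 4), Fraction(3, 4)]
```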
**Input**
["Solution","pickIndex"]
[[[1]],[]]

**Output**
[null,0]

**Explanation**
Solution solution = new Solution([1]);
solution.pickIndex(); // return 0. The only option is to return 0 since there is only one element in w.
**Input**
["Solution","pickIndex","pickIndex","pickIndex","pickIndex","pickIndex"]
[[[1,3]],[],[],[],[],[]]

**Output**
[null,1,1,1,1,0]

**Explanation**
Solution solution = new Solution([1, 3]);
solution.pickIndex(); // return 1. It is returning the second element (index = 1) that has a probability of 3/4.
solution.pickIndex(); // return 1
solution.pickIndex(); // return 1
solution.pickIndex(); // return 1
solution.pickIndex(); // return 0. It is returning the first element (index = 0) that has a probability of 1/4.
Since this is a randomization problem, multiple answers are allowed. All of the following outputs can be considered correct:
[null,1,1,1,1,0]
[null,1,1,1,1,1]
[null,1,1,1,0,0]
[null,1,1,1,0,1]
[null,1,0,1,0,0]
... and so on.
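Because many output sequences are valid, a test for this problem can only check frequencies, not exact results. The sketch below assumes Python's standard `random.choices` (which accepts a `weights` argument) as a stand-in weighted sampler; the helper name `empirical_frequencies` and the tolerance of 0.01 are arbitrary choices made here for illustration.

```python
import random
from collections import Counter


def empirical_frequencies(w: list[int], trials: int, seed: int = 0) -> list[float]:
    """Sample indices with probability w[i] / sum(w) and return observed frequencies."""
    rng = random.Random(seed)  # seeded so the run is reproducible
    picks = rng.choices(range(len(w)), weights=w, k=trials)
    counts = Counter(picks)
    return [counts[i] / trials for i in range(len(w))]


freqs = empirical_frequencies([1, 3], trials=100_000)
# Each frequency should be close to w[i] / sum(w): 0.25 and 0.75 here.
assert abs(freqs[0] - 0.25) < 0.01 and abs(freqs[1] - 0.75) < 0.01
```

The same frequency check works for any pickIndex() implementation: call it many times and compare the observed shares with w[i] / sum(w).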
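We can think of the weights as consecutive segments on a straight line, where segment i has length w[i]. Suppose w = [1,2,3]: the total weight is 6, so index 0 should be picked with probability 1/6, index 1 with probability 2/6 and index 2 with probability 3/6. Put differently, over many calls to pickIndex(), index 0 should come up about once per 6 calls, index 1 about twice and index 2 about 3 times; any particular run of 6 calls can deviate from this.
The segments are: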
Index 0: [0,1)
Index 1: [1,3)
Index 2: [3,6)
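The segment boundaries are just the running (prefix) sums of the weights. A short sketch that reproduces the segments listed above; the function name `segments` is hypothetical.

```python
from itertools import accumulate


def segments(w: list[int]) -> list[tuple[int, int]]:
    """Return the half-open interval [start, end) owned by each index."""
    ends = list(accumulate(w))   # prefix sums: [1, 3, 6] for [1, 2, 3]
    starts = [0] + ends[:-1]     # each segment starts where the previous one ended
    return list(zip(starts, ends))


print(segments([1, 2, 3]))  # [(0, 1), (1, 3), (3, 6)]
```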
Think of each segment as a bucket (or bar) on the line whose length is w[i]. We draw a number uniformly at random between 0 and the total weight, and the bucket it falls into is our answer.
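Before reaching for binary search, the bucket idea can be written as a plain linear scan. This is a baseline sketch for illustration only (O(n) work per pick), not the approach this article settles on; `bucket_of` and `pick_index_linear` are hypothetical names.

```python
import random


def bucket_of(x: int, w: list[int]) -> int:
    """Return the index whose segment [start, end) contains x, for 0 <= x < sum(w)."""
    end = 0
    for i, weight in enumerate(w):
        end += weight  # end of segment i
        if x < end:
            return i
    raise ValueError("x must be in [0, sum(w))")


def pick_index_linear(w: list[int]) -> int:
    """Draw x uniformly from [0, sum(w)) and return the bucket it lands in."""
    return bucket_of(random.randrange(sum(w)), w)


print([bucket_of(x, [1, 2, 3]) for x in range(6)])  # [0, 1, 1, 2, 2, 2]
```

Each index owns exactly w[i] of the possible values of x, which is why a uniform x lands in it with probability w[i] / sum(w).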
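Finding that bucket is a search over the sorted cumulative sums, which takes O(log n) time with binary search or with an ordered map. That is why we use a TreeMap: its ordered lookups (a search in a balanced tree) find the right bucket in O(log n) time.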
The TreeMap maps each cumulative sum of the weights (the end of a segment) to its corresponding index. To pick an index, we generate a random number and ask the TreeMap for the smallest key that covers it; that key's segment (index) is the answer, so each index is picked with probability proportional to its weight.
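A random number in [0,6) falls into these segments in proportion to their lengths. The code below draws an integer x in [1, total] instead and looks up the smallest cumulative sum greater than or equal to x; this is the same mapping shifted by one (x = 1 → index 0, x = 2 or 3 → index 1, x = 4, 5 or 6 → index 2). More on TreeMap: Java TreeMap Class.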
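Python has no built-in TreeMap, but the same ordered lookup can be done on a sorted list of prefix sums with the standard `bisect` module. This is a minimal sketch that swaps the TreeMap for `bisect.bisect_right` and draws x from [0, total), exactly as in the segments above; the class name `PrefixSumPicker` and its methods are hypothetical, and it assumes w is non-empty.

```python
import bisect
import random
from itertools import accumulate


class PrefixSumPicker:
    def __init__(self, w: list[int]) -> None:
        # prefix[i] is the (exclusive) end of index i's segment, e.g. [1, 3, 6].
        self.prefix = list(accumulate(w))
        self.total = self.prefix[-1]  # assumes w is non-empty

    def index_for(self, x: int) -> int:
        # The first prefix sum strictly greater than x ends the segment holding x.
        return bisect.bisect_right(self.prefix, x)

    def pick_index(self) -> int:
        return self.index_for(random.randrange(self.total))  # x in [0, total)


picker = PrefixSumPicker([1, 2, 3])
print([picker.index_for(x) for x in range(6)])  # [0, 1, 1, 2, 2, 2]
```

Using `bisect_right` rather than `bisect_left` is what makes the intervals half-open: a value equal to a boundary, such as 1 or 3, belongs to the next segment.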
// TreeMap equivalent in C++ is std::map (an ordered map, typically a red-black tree)
#include <map>
#include <random>
#include <vector>
using namespace std;

class Solution {
    map<int, int> tree;  // cumulative sum -> index
    int total;
    mt19937 rng{random_device{}()};
    uniform_int_distribution<int> dist;

public:
    Solution(vector<int>& w) {
        int sum = 0;
        for (int i = 0; i < (int)w.size(); ++i) {
            sum += w[i];
            tree[sum] = i;  // index i owns the values (previous sum, sum]
        }
        total = sum;
        dist = uniform_int_distribution<int>(1, total);  // x in [1, total]
    }

    int pickIndex() {
        int x = dist(rng);
        // The first cumulative sum >= x identifies the bucket containing x.
        auto it = tree.lower_bound(x);
        return it->second;
    }
};
import java.util.*;

class Solution {
    private final TreeMap<Integer, Integer> tree = new TreeMap<>();  // cumulative sum -> index
    private final int total;
    private final Random rand = new Random();

    public Solution(int[] w) {
        int sum = 0;
        for (int i = 0; i < w.length; i++) {
            sum += w[i];
            tree.put(sum, i);
        }
        total = sum;
    }

    public int pickIndex() {
        int x = rand.nextInt(total) + 1;  // uniform in [1, total]
        // The smallest key >= x is the end of the segment that contains x.
        return tree.ceilingEntry(x).getValue();
    }
}
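import random


class Solution:
    def __init__(self, w: list[int]) -> None:
        # Python has no built-in TreeMap; a dict maps each cumulative sum to its index.
        self.tree: dict[int, int] = {}
        self.total: int = 0
        for i, num in enumerate(w):
            self.total += num
            self.tree[self.total] = i
        # Cumulative sums are strictly increasing (all weights are positive),
        # so the keys are already sorted; store them once instead of sorting on every call.
        self.keys: list[int] = list(self.tree.keys())

    def pickIndex(self) -> int:
        x = random.randint(1, self.total)  # uniform integer in [1, total]
        # Binary search for the first cumulative sum >= x (like lower_bound / ceilingEntry).
        l, r = 0, len(self.keys) - 1
        while l < r:
            m = (l + r) // 2
            if self.keys[m] < x:
                l = m + 1
            else:
                r = m
        return self.tree[self.keys[l]]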
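All three implementations draw x from [1, total] and take the first cumulative sum greater than or equal to x (`lower_bound` in C++, `ceilingEntry` in Java, the manual binary search in Python). Because that mapping is deterministic, it can be checked without any randomness by enumerating every possible x. A sketch using `bisect.bisect_left`, which finds the same first prefix sum >= x; the helper name `counts_per_index` is hypothetical.

```python
import bisect
from collections import Counter
from itertools import accumulate


def counts_per_index(w: list[int]) -> list[int]:
    """Count how many x in [1, total] map to each index via 'first prefix sum >= x'."""
    prefix = list(accumulate(w))
    hits = Counter(bisect.bisect_left(prefix, x) for x in range(1, prefix[-1] + 1))
    return [hits[i] for i in range(len(w))]


print(counts_per_index([1, 2, 3]))  # [1, 2, 3]
```

Each index receives exactly w[i] of the total possible values of x, so a uniform draw from [1, total] picks index i with probability w[i] / sum(w).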