Skip to content

Commit 99945f3

Browse files
authoredOct 22, 2021
Add the "Weighted Random" algorithm (trekhleb#792)
* Add the link to the Weighted Random algorithm to the main README. * Add Weighted Random implementation and tests. * Add Weighted Random README. * Add Weighted Random README. * Add Weighted Random README.
1 parent d0576a2 commit 99945f3

File tree

4 files changed

+259
-0
lines changed

4 files changed

+259
-0
lines changed
 

‎README.md

+2
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ a set of rules that precisely define a sequence of operations.
156156
* `B` [k-Means](src/algorithms/ml/k-means) - k-Means clustering algorithm
157157
* **Image Processing**
158158
* `B` [Seam Carving](src/algorithms/image-processing/seam-carving) - content-aware image resizing algorithm
159+
* **Statistics**
160+
* `B` [Weighted Random](src/algorithms/statistics/weighted-random) - select the random item from the list based on items' weights
159161
* **Uncategorized**
160162
* `B` [Tower of Hanoi](src/algorithms/uncategorized/hanoi-tower)
161163
* `B` [Square Matrix Rotation](src/algorithms/uncategorized/square-matrix-rotation) - in-place algorithm
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Weighted Random
2+
3+
## What is "Weighted Random"
4+
5+
Let's say you have a list of **items**. Item could be anything. For example, we may have a list of fruits and vegetables that you like to eat: `[ '🍌', '🍎', '🥕' ]`.
6+
7+
The list of **weights** represent the weight (or probability, or importance) of each item. Weights are numbers. For example, the weights like `[3, 7, 1]` would say that:
8+
9+
- you would like to eat `🍎 apples` more often (`7` out of `3 + 7 + 1 = 11` times),
10+
- then you would like to eat `bananas 🍌` less often (only `3` out of `11` times),
11+
- and the `carrots 🥕` you really don't like (want to eat it only `1` out of `11` times).
12+
13+
> If we speak in terms of probabilities than the weights list might be an array of floats that sum up to `1` (i.e. `[0.1, 0.5, 0.2, 0.2]`).
14+
15+
The **Weighted Random** in this case will be the function that will randomly return you the item from the list, and it will take each item's weight into account, so that items with the higher weight will be picked more often.
16+
17+
Example of the function interface:
18+
19+
```javascript
20+
const items = [ '🍌', '🍎', '🥕' ];
21+
const weights = [ 3, 7, 1 ];
22+
23+
function weightedRandom(items, weights) {
24+
// implementation goes here ...
25+
}
26+
27+
const nextSnackToEat = weightedRandom(items, weights); // Could be '🍎'
28+
```
29+
30+
## Applications of Weighted Random
31+
32+
- In [Genetic Algorithm](https://en.wikipedia.org/wiki/Genetic_algorithm) the weighted random is used during the "Selection" phase, when we need to select the fittest/strongest individuums based on their fitness score for mating and for producing the next stronger generation. You may find an **example** in the [Self-Parking Car in 500 Lines of Code](https://trekhleb.dev/blog/2021/self-parking-car-evolution/) article.
33+
- In [Recurrent Neural Networks (RNN)](https://en.wikipedia.org/wiki/Recurrent_neural_network) when trying to decide what letter to choose next (to form the sentence) based on the next letter probability. You may find an **example** in the [Recipe Generation using Recurrent Neural Network (RNN)](https://nbviewer.org/github/trekhleb/machine-learning-experiments/blob/master/experiments/recipe_generation_rnn/recipe_generation_rnn.ipynb) Jupyter notebook.
34+
- In [Nginx Load Balancing](https://docs.nginx.com/nginx/admin-guide/load-balancer/http-load-balancer/) to send HTTP requests more often to the servers with the higher weights.
35+
- And more...
36+
37+
## The Algorithm
38+
39+
The **straightforward approach** would be to:
40+
41+
1. Repeat each item in the list according to its weight.
42+
2. Pick the random item from the list.
43+
44+
For example in our case with fruits and vegetables we could generate the following list of size `3 + 7 + 1 = 11`:
45+
46+
```javascript
47+
const items = [ '🍌', '🍎', '🥕' ];
48+
const weights = [ 3, 7, 1 ];
49+
50+
// Repeating the items based on weights.
51+
const weightedItems = [
52+
'🍌', '🍌', '🍌',
53+
'🍎', '🍎', '🍎', '🍎', '🍎', '🍎', '🍎',
54+
'🥕',
55+
];
56+
57+
// And now just pick the random item from weightedItems array.
58+
```
59+
60+
However, as you may see, this approach may require a lot of memory, in case if the objects are heavy, and in case if we have a lot of them to repeat in `weightedItems` list.
61+
62+
The **more efficient approach** would be to:
63+
64+
1. Prepare the list of cumulative weights for each item (i.e. the `cumulativeWeights` list which will have the same number of elements as the original `weights` list). In our case it will look like this: `cumulativeWeights = [3, 3 + 7, 3 + 7 + 1] = [3, 10, 11]`
65+
2. Generate the random number `randomNumber` from `0` to the highest cumulative weight value. In our case the random number will be in a range of `[0..11]`. Let's say that we have `randomNumber = 8`.
66+
3. Go through the `cumulativeWeights` list from left to right and pick the first element which is higher or equal to the `randomNumber`. The index of such element we will use to pick the item from the `items` array.
67+
68+
The idea behind this approach is that the higher weights will "occupy" more numeric space. Therefore, there is a higher chance that the random number will fall into the "higher weight numeric bucket".
69+
70+
```javascript
71+
const weights = [3, 7, 1 ];
72+
const cumulativeWeights = [3, 10, 11];
73+
74+
// In a pseudo-representation we may think about the cumulativeWeights array like this.
75+
const pseudoCumulativeWeights = [
76+
1, 2, 3, // <-- [3] numbers
77+
4, 5, 6, 7, 8, 9, 10, // <-- [7] numbers
78+
11, // <-- [1] number
79+
];
80+
```
81+
82+
Here is an example of how the `weightedRandom` function might be implemented:
83+
84+
```javascript
85+
/**
86+
* Picks the random item based on its weight.
87+
* The items with higher weight will be picked more often (with a higher probability).
88+
*
89+
* For example:
90+
* - items = ['banana', 'orange', 'apple']
91+
* - weights = [0, 0.2, 0.8]
92+
* - weightedRandom(items, weights) in 80% of cases will return 'apple', in 20% of cases will return
93+
* 'orange' and it will never return 'banana' (because probability of picking the banana is 0%)
94+
*
95+
* @param {any[]} items
96+
* @param {number[]} weights
97+
* @returns {{item: any, index: number}}
98+
*/
99+
export default function weightedRandom(items, weights) {
100+
if (items.length !== weights.length) {
101+
throw new Error('Items and weights must be of the same size');
102+
}
103+
104+
if (!items.length) {
105+
throw new Error('Items must not be empty');
106+
}
107+
108+
// Preparing the cumulative weights array.
109+
// For example:
110+
// - weights = [1, 4, 3]
111+
// - cumulativeWeights = [1, 5, 8]
112+
const cumulativeWeights = [];
113+
for (let i = 0; i < weights.length; i += 1) {
114+
cumulativeWeights[i] = weights[i] + (cumulativeWeights[i - 1] || 0);
115+
}
116+
117+
// Getting the random number in a range of [0...sum(weights)]
118+
// For example:
119+
// - weights = [1, 4, 3]
120+
// - maxCumulativeWeight = 8
121+
// - range for the random number is [0...8]
122+
const maxCumulativeWeight = cumulativeWeights[cumulativeWeights.length - 1];
123+
const randomNumber = maxCumulativeWeight * Math.random();
124+
125+
// Picking the random item based on its weight.
126+
// The items with higher weight will be picked more often.
127+
for (let itemIndex = 0; itemIndex < items.length; itemIndex += 1) {
128+
if (cumulativeWeights[itemIndex] >= randomNumber) {
129+
return {
130+
item: items[itemIndex],
131+
index: itemIndex,
132+
};
133+
}
134+
}
135+
}
136+
```
137+
138+
## Implementation
139+
140+
- Check the [weightedRandom.js](weightedRandom.js) file for the implementation of the `weightedRandom()` function.
141+
- Check the [weightedRandom.test.js](__test__/weightedRandom.test.js) file for the tests-cases.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import weightedRandom from '../weightedRandom';
2+
3+
describe('weightedRandom', () => {
4+
it('should throw an error when the number of weights does not match the number of items', () => {
5+
const getWeightedRandomWithInvalidInputs = () => {
6+
weightedRandom(['a', 'b', 'c'], [10, 0]);
7+
};
8+
expect(getWeightedRandomWithInvalidInputs).toThrow('Items and weights must be of the same size');
9+
});
10+
11+
it('should throw an error when the number of weights or items are empty', () => {
12+
const getWeightedRandomWithInvalidInputs = () => {
13+
weightedRandom([], []);
14+
};
15+
expect(getWeightedRandomWithInvalidInputs).toThrow('Items must not be empty');
16+
});
17+
18+
it('should correctly do random selection based on wights in straightforward cases', () => {
19+
expect(weightedRandom(['a', 'b', 'c'], [1, 0, 0])).toEqual({ index: 0, item: 'a' });
20+
expect(weightedRandom(['a', 'b', 'c'], [0, 1, 0])).toEqual({ index: 1, item: 'b' });
21+
expect(weightedRandom(['a', 'b', 'c'], [0, 0, 1])).toEqual({ index: 2, item: 'c' });
22+
expect(weightedRandom(['a', 'b', 'c'], [0, 1, 1])).not.toEqual({ index: 0, item: 'a' });
23+
expect(weightedRandom(['a', 'b', 'c'], [1, 0, 1])).not.toEqual({ index: 1, item: 'b' });
24+
expect(weightedRandom(['a', 'b', 'c'], [1, 1, 0])).not.toEqual({ index: 2, item: 'c' });
25+
});
26+
27+
it('should correctly do random selection based on wights', () => {
28+
// Number of times we're going to select the random items based on their weights.
29+
const ATTEMPTS_NUM = 1000;
30+
// The +/- delta in the number of times each item has been actually selected.
31+
// I.e. if we want the item 'a' to be selected 300 times out of 1000 cases (30%)
32+
// then 267 times is acceptable since it is bigger that 250 (which is 300 - 50)
33+
// ans smaller than 350 (which is 300 + 50)
34+
const THRESHOLD = 50;
35+
36+
const items = ['a', 'b', 'c']; // The actual items values don't matter.
37+
const weights = [0.1, 0.3, 0.6];
38+
39+
const counter = [];
40+
for (let i = 0; i < ATTEMPTS_NUM; i += 1) {
41+
const randomItem = weightedRandom(items, weights);
42+
if (!counter[randomItem.index]) {
43+
counter[randomItem.index] = 1;
44+
} else {
45+
counter[randomItem.index] += 1;
46+
}
47+
}
48+
49+
for (let itemIndex = 0; itemIndex < items.length; itemIndex += 1) {
50+
/*
51+
i.e. item with the index of 0 must be selected 100 times (ideally)
52+
or with the threshold of [100 - 50, 100 + 50] times.
53+
54+
i.e. item with the index of 1 must be selected 300 times (ideally)
55+
or with the threshold of [300 - 50, 300 + 50] times.
56+
57+
i.e. item with the index of 2 must be selected 600 times (ideally)
58+
or with the threshold of [600 - 50, 600 + 50] times.
59+
*/
60+
expect(counter[itemIndex]).toBeGreaterThan(ATTEMPTS_NUM * weights[itemIndex] - THRESHOLD);
61+
expect(counter[itemIndex]).toBeLessThan(ATTEMPTS_NUM * weights[itemIndex] + THRESHOLD);
62+
}
63+
});
64+
});
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/**
2+
* Picks the random item based on its weight.
3+
* The items with higher weight will be picked more often (with a higher probability).
4+
*
5+
* For example:
6+
* - items = ['banana', 'orange', 'apple']
7+
* - weights = [0, 0.2, 0.8]
8+
* - weightedRandom(items, weights) in 80% of cases will return 'apple', in 20% of cases will return
9+
* 'orange' and it will never return 'banana' (because probability of picking the banana is 0%)
10+
*
11+
* @param {any[]} items
12+
* @param {number[]} weights
13+
* @returns {{item: any, index: number}}
14+
*/
15+
/* eslint-disable consistent-return */
16+
export default function weightedRandom(items, weights) {
17+
if (items.length !== weights.length) {
18+
throw new Error('Items and weights must be of the same size');
19+
}
20+
21+
if (!items.length) {
22+
throw new Error('Items must not be empty');
23+
}
24+
25+
// Preparing the cumulative weights array.
26+
// For example:
27+
// - weights = [1, 4, 3]
28+
// - cumulativeWeights = [1, 5, 8]
29+
const cumulativeWeights = [];
30+
for (let i = 0; i < weights.length; i += 1) {
31+
cumulativeWeights[i] = weights[i] + (cumulativeWeights[i - 1] || 0);
32+
}
33+
34+
// Getting the random number in a range of [0...sum(weights)]
35+
// For example:
36+
// - weights = [1, 4, 3]
37+
// - maxCumulativeWeight = 8
38+
// - range for the random number is [0...8]
39+
const maxCumulativeWeight = cumulativeWeights[cumulativeWeights.length - 1];
40+
const randomNumber = maxCumulativeWeight * Math.random();
41+
42+
// Picking the random item based on its weight.
43+
// The items with higher weight will be picked more often.
44+
for (let itemIndex = 0; itemIndex < items.length; itemIndex += 1) {
45+
if (cumulativeWeights[itemIndex] >= randomNumber) {
46+
return {
47+
item: items[itemIndex],
48+
index: itemIndex,
49+
};
50+
}
51+
}
52+
}

0 commit comments

Comments
 (0)
Please sign in to comment.