
Commit 4623bb9

committed Dec 16, 2020
Add k-nearest neighbors algorithm.
1 parent b13291d commit 4623bb9

File tree

7 files changed: +190 -126 lines changed

 

README.md

+1-1
@@ -143,7 +143,7 @@ a set of rules that precisely define a sequence of operations.
  * `B` [Caesar Cipher](src/algorithms/cryptography/caesar-cipher) - simple substitution cipher
* **Machine Learning**
  * `B` [NanoNeuron](https://github.com/trekhleb/nano-neuron) - 7 simple JS functions that illustrate how machines can actually learn (forward/backward propagation)
- * `B` [KNN](src/algorithms/ML/KNN) - K Nearest Neighbors
+ * `B` [k-NN](src/algorithms/ml/knn) - k-nearest neighbors classification algorithm
* **Uncategorized**
  * `B` [Tower of Hanoi](src/algorithms/uncategorized/hanoi-tower)
  * `B` [Square Matrix Rotation](src/algorithms/uncategorized/square-matrix-rotation) - in-place algorithm

src/algorithms/ML/KNN/README.md

-23
This file was deleted.

src/algorithms/ML/KNN/__test__/knn.test.js

-42
This file was deleted.

src/algorithms/ML/KNN/knn.js

-60
This file was deleted.

src/algorithms/ml/knn/README.md

+41
@@ -0,0 +1,41 @@
# k-Nearest Neighbors Algorithm

The **k-nearest neighbors algorithm (k-NN)** is a supervised machine-learning algorithm. It is a classification algorithm: it determines the class of a sample vector based on labeled sample data.

In k-NN classification, the output is a class membership. An object is classified by a plurality vote of its neighbors, with the object being assigned to the class most common among its `k` nearest neighbors (`k` is a positive integer, typically small). If `k = 1`, the object is simply assigned to the class of its single nearest neighbor.

The idea is to measure the similarity between two data points with a distance metric; [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance) is the most common choice for this task.

![Euclidean distance between two points](https://upload.wikimedia.org/wikipedia/commons/5/55/Euclidean_distance_2d.svg)

_Image source: [Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)_
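For two `n`-dimensional points `p` and `q`, this distance is:

```latex
d(p, q) = \sqrt{\sum_{i=1}^{n} (p_i - q_i)^2}
```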
The algorithm is as follows:

1. Check for errors such as invalid data or labels.
2. Calculate the Euclidean distance from the point to classify to every point in the training data.
3. Sort the points by distance, together with their classes, in ascending order.
4. Take the first `k` classes and find their mode, i.e. the most common class.
5. Report that class as the prediction.
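The steps above can be sketched in JavaScript. The names here (`distance`, `knnClassify`) are illustrative only and differ from the module's actual API:

```javascript
// Step 2 helper: Euclidean distance between two equal-length vectors.
function distance(a, b) {
  return Math.sqrt(a.reduce((sum, ai, i) => sum + (ai - b[i]) ** 2, 0));
}

// Minimal k-NN classifier sketch following steps 2-5.
function knnClassify(dataSet, labels, point, k = 3) {
  // Steps 2-4: distance to every training point, sorted ascending,
  // keeping only the k closest neighbors.
  const neighbors = dataSet
    .map((sample, i) => ({ dist: distance(sample, point), label: labels[i] }))
    .sort((a, b) => a.dist - b.dist)
    .slice(0, k);

  // Steps 4-5: the mode of the k neighbor labels is the predicted class.
  const counts = {};
  let best = neighbors[0].label;
  for (const { label } of neighbors) {
    counts[label] = (counts[label] || 0) + 1;
    if (counts[label] > counts[best]) best = label;
  }
  return best;
}
```

With `k = 1` this reduces to nearest-neighbor classification; larger `k` trades sensitivity to noise for a smoother decision boundary.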
Here is a visualization of k-NN classification for better understanding:

![KNN Visualization 1](https://upload.wikimedia.org/wikipedia/commons/e/e7/KnnClassification.svg)

_Image source: [Wikipedia](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)_

The test sample (green dot) should be classified either as a blue square or as a red triangle. If `k = 3` (solid-line circle) it is assigned to the red triangles, because there are `2` triangles and only `1` square inside the inner circle. If `k = 5` (dashed-line circle) it is assigned to the blue squares (`3` squares vs. `2` triangles inside the outer circle).

Another k-NN classification example:

![KNN Visualization 2](https://media.geeksforgeeks.org/wp-content/uploads/graph2-2.png)

_Image source: [GeeksForGeeks](https://media.geeksforgeeks.org/wp-content/uploads/graph2-2.png)_

Here, as we can see, unknown points are classified by their proximity to already-labeled points.

Note that odd values of `k` are preferred, since they help break ties; `k` is usually taken as `3` or `5`.

## References

- [k-nearest neighbors algorithm on Wikipedia](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)
+71
@@ -0,0 +1,71 @@
import kNN from '../kNN';

describe('kNN', () => {
  it('should throw an error on invalid data', () => {
    expect(() => {
      kNN();
    }).toThrowError('Either dataSet or labels or toClassify were not set');
  });

  it('should throw an error on invalid labels', () => {
    const noLabels = () => {
      kNN([[1, 1]]);
    };
    expect(noLabels).toThrowError('Either dataSet or labels or toClassify were not set');
  });

  it('should throw an error on not giving classification vector', () => {
    const noClassification = () => {
      kNN([[1, 1]], [1]);
    };
    expect(noClassification).toThrowError('Either dataSet or labels or toClassify were not set');
  });

  it('should throw an error on inconsistent vector lengths', () => {
    const inconsistent = () => {
      kNN([[1, 1]], [1], [1]);
    };
    expect(inconsistent).toThrowError('Inconsistent vector lengths');
  });

  it('should find the nearest neighbour', () => {
    let dataSet;
    let labels;
    let toClassify;
    let expectedClass;

    dataSet = [[1, 1], [2, 2]];
    labels = [1, 2];
    toClassify = [1, 1];
    expectedClass = 1;
    expect(kNN(dataSet, labels, toClassify)).toBe(expectedClass);

    dataSet = [[1, 1], [6, 2], [3, 3], [4, 5], [9, 2], [2, 4], [8, 7]];
    labels = [1, 2, 1, 2, 1, 2, 1];
    toClassify = [1.25, 1.25];
    expectedClass = 1;
    expect(kNN(dataSet, labels, toClassify)).toBe(expectedClass);

    dataSet = [[1, 1], [6, 2], [3, 3], [4, 5], [9, 2], [2, 4], [8, 7]];
    labels = [1, 2, 1, 2, 1, 2, 1];
    toClassify = [1.25, 1.25];
    expectedClass = 2;
    expect(kNN(dataSet, labels, toClassify, 5)).toBe(expectedClass);
  });

  it('should find the nearest neighbour with equal distances', () => {
    const dataSet = [[0, 0], [1, 1], [0, 2]];
    const labels = [1, 3, 3];
    const toClassify = [0, 1];
    const expectedClass = 3;
    expect(kNN(dataSet, labels, toClassify)).toBe(expectedClass);
  });

  it('should find the nearest neighbour in 3D space', () => {
    const dataSet = [[0, 0, 0], [0, 1, 1], [0, 0, 2]];
    const labels = [1, 3, 3];
    const toClassify = [0, 0, 1];
    const expectedClass = 3;
    expect(kNN(dataSet, labels, toClassify)).toBe(expectedClass);
  });
});

src/algorithms/ml/knn/kNN.js

+77
@@ -0,0 +1,77 @@
1+
/**
2+
* Calculates calculate the euclidean distance between 2 vectors.
3+
*
4+
* @param {number[]} x1
5+
* @param {number[]} x2
6+
* @returns {number}
7+
*/
8+
function euclideanDistance(x1, x2) {
9+
// Checking for errors.
10+
if (x1.length !== x2.length) {
11+
throw new Error('Inconsistent vector lengths');
12+
}
13+
// Calculate the euclidean distance between 2 vectors and return.
14+
let squaresTotal = 0;
15+
for (let i = 0; i < x1.length; i += 1) {
16+
squaresTotal += (x1[i] - x2[i]) ** 2;
17+
}
18+
return Number(Math.sqrt(squaresTotal).toFixed(2));
19+
}
20+
21+
/**
22+
* Classifies the point in space based on k-nearest neighbors algorithm.
23+
*
24+
* @param {number[][]} dataSet - array of data points, i.e. [[0, 1], [3, 4], [5, 7]]
25+
* @param {number[]} labels - array of classes (labels), i.e. [1, 1, 2]
26+
* @param {number[]} toClassify - the point in space that needs to be classified, i.e. [5, 4]
27+
* @param {number} k - number of nearest neighbors which will be taken into account (preferably odd)
28+
* @return {number} - the class of the point
29+
*/
30+
export default function kNN(
31+
dataSet,
32+
labels,
33+
toClassify,
34+
k = 3,
35+
) {
36+
if (!dataSet || !labels || !toClassify) {
37+
throw new Error('Either dataSet or labels or toClassify were not set');
38+
}
39+
40+
// Calculate distance from toClassify to each point for all dimensions in dataSet.
41+
// Store distance and point's label into distances list.
42+
const distances = [];
43+
for (let i = 0; i < dataSet.length; i += 1) {
44+
distances.push({
45+
dist: euclideanDistance(dataSet[i], toClassify),
46+
label: labels[i],
47+
});
48+
}
49+
50+
// Sort distances list (from closer point to further ones).
51+
// Take initial k values, count with class index
52+
const kNearest = distances.sort((a, b) => {
53+
if (a.dist === b.dist) {
54+
return 0;
55+
}
56+
return a.dist < b.dist ? -1 : 1;
57+
}).slice(0, k);
58+
59+
// Count the number of instances of each class in top k members.
60+
const labelsCounter = {};
61+
let topClass = 0;
62+
let topClassCount = 0;
63+
for (let i = 0; i < kNearest.length; i += 1) {
64+
if (kNearest[i].label in labelsCounter) {
65+
labelsCounter[kNearest[i].label] += 1;
66+
} else {
67+
labelsCounter[kNearest[i].label] = 1;
68+
}
69+
if (labelsCounter[kNearest[i].label] > topClassCount) {
70+
topClassCount = labelsCounter[kNearest[i].label];
71+
topClass = kNearest[i].label;
72+
}
73+
}
74+
75+
// Return the class with highest count.
76+
return topClass;
77+
}
