diff --git a/neighbours/__init__.py b/neighbours/__init__.py index 4a2c27f..1644ecb 100644 --- a/neighbours/__init__.py +++ b/neighbours/__init__.py @@ -1 +1,3 @@ from .knn_classifier import * +from . import distance +from . import kernel diff --git a/neighbours/distance.py b/neighbours/distance.py new file mode 100644 index 0000000..f850648 --- /dev/null +++ b/neighbours/distance.py @@ -0,0 +1,13 @@ +import numpy as np + + +def euclidean(x: np.ndarray, y: np.ndarray): + return np.linalg.norm(x - y) + + +def manhattan(x: np.ndarray, y: np.ndarray): + return np.sum(np.abs(x - y)) + + +def cosine(x: np.ndarray, y: np.ndarray): + return 1 - (x.dot(y) / (np.linalg.norm(x) * np.linalg.norm(y))) diff --git a/neighbours/kernel.py b/neighbours/kernel.py new file mode 100644 index 0000000..d688a07 --- /dev/null +++ b/neighbours/kernel.py @@ -0,0 +1,13 @@ +import numpy as np + + +def rectangular(x): + return (np.abs(x) <= 1) / 2 + + +def gaussian(x): + return 0.3989422804014326779 * np.exp(-2 * x * x) + + +def epanechnikov(x): + return (np.abs(x) <= 1) * 0.75 * (1 - (x * x)) diff --git a/neighbours/knn_classifier.py b/neighbours/knn_classifier.py index 79eec24..1a6048c 100644 --- a/neighbours/knn_classifier.py +++ b/neighbours/knn_classifier.py @@ -5,10 +5,11 @@ from .exceptions import * class KNNClassifier: - def __init__(self, features, trees_count, rpt_m): + def __init__(self, features, classes_count, trees_count, rpt_m): self.features = features self.forest = RPTForest(features, trees_count, rpt_m) self.classes = None + self.classes_count = classes_count def load(self, points, classes): if not isinstance(points, np.ndarray): @@ -29,5 +30,12 @@ class KNNClassifier: self.forest.load(points) - def predict(self, point: np.ndarray): - pass + def predict(self, point: np.ndarray, distance, kernel, h): + nearest_point_indexes = self.forest.get_neighbours(point) + + votes = np.zeros(self.classes_count) + + for point_ix in nearest_point_indexes: + votes[self.classes[point_ix]] += kernel(distance(point, self.forest.get_point(point_ix)) / h) + + return np.argmax(votes) diff --git a/neighbours/rp_neighbours.py b/neighbours/rp_neighbours.py index e92a945..398384b 100644 --- a/neighbours/rp_neighbours.py +++ b/neighbours/rp_neighbours.py @@ -68,7 +68,7 @@ class RPTForest: """ def __init__(self, features, trees_count, m): - """Creates new RPTForest + """Creates new random projection tree forest :param features: number of features in each sample :param trees_count: number of trees in the forest