You've already forked neighbours
minimum viable classifier version
This commit is contained in:
@@ -1 +1,3 @@
|
||||
from .knn_classifier import *
|
||||
from . import distance
|
||||
from . import kernel
|
||||
|
||||
13
neighbours/distance.py
Normal file
13
neighbours/distance.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def euclidean(x: np.ndarray, y: np.ndarray):
|
||||
return np.linalg.norm(x - y)
|
||||
|
||||
|
||||
def manhattan(x: np.ndarray, y: np.ndarray):
|
||||
return np.sum(np.abs(x - y))
|
||||
|
||||
|
||||
def cosine(x: np.ndarray, y: np.ndarray):
|
||||
return 1 - (x.dot(y) / (np.linalg.norm(x) * np.linalg.norm(y)))
|
||||
13
neighbours/kernel.py
Normal file
13
neighbours/kernel.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def rectangular(x):
|
||||
return (np.abs(x) <= 1) / 2
|
||||
|
||||
|
||||
def gaussian(x):
|
||||
return 0.3989422804014326779 * np.exp(-2 * x * x)
|
||||
|
||||
|
||||
def epanechnikov(x):
|
||||
return (np.abs(x) <= 1) * 0.75 * (1 - (x * x))
|
||||
@@ -5,10 +5,11 @@ from .exceptions import *
|
||||
|
||||
|
||||
class KNNClassifier:
|
||||
def __init__(self, features, trees_count, rpt_m):
|
||||
def __init__(self, features, classes_count, trees_count, rpt_m):
|
||||
self.features = features
|
||||
self.forest = RPTForest(features, trees_count, rpt_m)
|
||||
self.classes = None
|
||||
self.classes_count = classes_count
|
||||
|
||||
def load(self, points, classes):
|
||||
if not isinstance(points, np.ndarray):
|
||||
@@ -29,5 +30,12 @@ class KNNClassifier:
|
||||
|
||||
self.forest.load(points)
|
||||
|
||||
def predict(self, point: np.ndarray):
|
||||
pass
|
||||
def predict(self, point: np.ndarray, distance, kernel, h):
|
||||
nearest_point_indexes = self.forest.get_neighbours(point)
|
||||
|
||||
votes = np.zeros(self.classes_count)
|
||||
|
||||
for point_ix in nearest_point_indexes:
|
||||
votes[self.classes[point_ix]] += kernel(distance(point, self.forest.get_point(point_ix)) / h)
|
||||
|
||||
return np.argmax(votes)
|
||||
|
||||
@@ -68,7 +68,7 @@ class RPTForest:
|
||||
"""
|
||||
|
||||
def __init__(self, features, trees_count, m):
|
||||
"""Creates new RPTForest
|
||||
"""Creates new random projection tree forest
|
||||
|
||||
:param features: number of features in each sample
|
||||
:param trees_count: number of trees in the forest
|
||||
|
||||
Reference in New Issue
Block a user