You've already forked neighbours
minimum viable classifier version
This commit is contained in:
@@ -1 +1,3 @@
|
|||||||
from .knn_classifier import *
|
from .knn_classifier import *
|
||||||
|
from . import distance
|
||||||
|
from . import kernel
|
||||||
|
|||||||
13
neighbours/distance.py
Normal file
13
neighbours/distance.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def euclidean(x: np.ndarray, y: np.ndarray):
|
||||||
|
return np.linalg.norm(x - y)
|
||||||
|
|
||||||
|
|
||||||
|
def manhattan(x: np.ndarray, y: np.ndarray):
|
||||||
|
return np.sum(np.abs(x - y))
|
||||||
|
|
||||||
|
|
||||||
|
def cosine(x: np.ndarray, y: np.ndarray):
|
||||||
|
return 1 - (x.dot(y) / (np.linalg.norm(x) * np.linalg.norm(y)))
|
||||||
13
neighbours/kernel.py
Normal file
13
neighbours/kernel.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def rectangular(x):
|
||||||
|
return (np.abs(x) <= 1) / 2
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian(x):
|
||||||
|
return 0.3989422804014326779 * np.exp(-2 * x * x)
|
||||||
|
|
||||||
|
|
||||||
|
def epanechnikov(x):
|
||||||
|
return (np.abs(x) <= 1) * 0.75 * (1 - (x * x))
|
||||||
@@ -5,10 +5,11 @@ from .exceptions import *
|
|||||||
|
|
||||||
|
|
||||||
class KNNClassifier:
|
class KNNClassifier:
|
||||||
def __init__(self, features, trees_count, rpt_m):
|
def __init__(self, features, classes_count, trees_count, rpt_m):
|
||||||
self.features = features
|
self.features = features
|
||||||
self.forest = RPTForest(features, trees_count, rpt_m)
|
self.forest = RPTForest(features, trees_count, rpt_m)
|
||||||
self.classes = None
|
self.classes = None
|
||||||
|
self.classes_count = classes_count
|
||||||
|
|
||||||
def load(self, points, classes):
|
def load(self, points, classes):
|
||||||
if not isinstance(points, np.ndarray):
|
if not isinstance(points, np.ndarray):
|
||||||
@@ -29,5 +30,12 @@ class KNNClassifier:
|
|||||||
|
|
||||||
self.forest.load(points)
|
self.forest.load(points)
|
||||||
|
|
||||||
def predict(self, point: np.ndarray):
|
def predict(self, point: np.ndarray, distance, kernel, h):
|
||||||
pass
|
nearest_point_indexes = self.forest.get_neighbours(point)
|
||||||
|
|
||||||
|
votes = np.zeros(self.classes_count)
|
||||||
|
|
||||||
|
for point_ix in nearest_point_indexes:
|
||||||
|
votes[self.classes[point_ix]] += kernel(distance(point, self.forest.get_point(point_ix)) / h)
|
||||||
|
|
||||||
|
return np.argmax(votes)
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class RPTForest:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, features, trees_count, m):
|
def __init__(self, features, trees_count, m):
|
||||||
"""Creates new RPTForest
|
"""Creates new random projection tree forest
|
||||||
|
|
||||||
:param features: number of features in each sample
|
:param features: number of features in each sample
|
||||||
:param trees_count: number of trees in the forest
|
:param trees_count: number of trees in the forest
|
||||||
|
|||||||
Reference in New Issue
Block a user