minimum viable classifier version

This commit is contained in:
hashlag
2024-01-30 23:30:01 +03:00
parent 19d96c9176
commit 00dcff8623
5 changed files with 40 additions and 4 deletions

View File

@@ -1 +1,3 @@
from .knn_classifier import * from .knn_classifier import *
from . import distance
from . import kernel

13
neighbours/distance.py Normal file
View File

@@ -0,0 +1,13 @@
import numpy as np
def euclidean(x: np.ndarray, y: np.ndarray):
return np.linalg.norm(x - y)
def manhattan(x: np.ndarray, y: np.ndarray):
return np.sum(np.abs(x - y))
def cosine(x: np.ndarray, y: np.ndarray):
return 1 - (x.dot(y) / (np.linalg.norm(x) * np.linalg.norm(y)))

13
neighbours/kernel.py Normal file
View File

@@ -0,0 +1,13 @@
import numpy as np
def rectangular(x):
return (np.abs(x) <= 1) / 2
def gaussian(x):
return 0.3989422804014326779 * np.exp(-2 * x * x)
def epanechnikov(x):
return (np.abs(x) <= 1) * 0.75 * (1 - (x * x))

View File

@@ -5,10 +5,11 @@ from .exceptions import *
class KNNClassifier: class KNNClassifier:
def __init__(self, features, trees_count, rpt_m): def __init__(self, features, classes_count, trees_count, rpt_m):
self.features = features self.features = features
self.forest = RPTForest(features, trees_count, rpt_m) self.forest = RPTForest(features, trees_count, rpt_m)
self.classes = None self.classes = None
self.classes_count = classes_count
def load(self, points, classes): def load(self, points, classes):
if not isinstance(points, np.ndarray): if not isinstance(points, np.ndarray):
@@ -29,5 +30,12 @@ class KNNClassifier:
self.forest.load(points) self.forest.load(points)
def predict(self, point: np.ndarray): def predict(self, point: np.ndarray, distance, kernel, h):
pass nearest_point_indexes = self.forest.get_neighbours(point)
votes = np.zeros(self.classes_count)
for point_ix in nearest_point_indexes:
votes[self.classes[point_ix]] += kernel(distance(point, self.forest.get_point(point_ix)) / h)
return np.argmax(votes)

View File

@@ -68,7 +68,7 @@ class RPTForest:
""" """
def __init__(self, features, trees_count, m): def __init__(self, features, trees_count, m):
"""Creates new RPTForest """Creates new random projection tree forest
:param features: number of features in each sample :param features: number of features in each sample
:param trees_count: number of trees in the forest :param trees_count: number of trees in the forest