You've already forked neighbours
add some docstrings
This commit is contained in:
@@ -2,12 +2,22 @@ import numpy as np
|
||||
|
||||
|
||||
def euclidean(x: np.ndarray, y: np.ndarray) -> float:
    """Calculate the Euclidean distance between two vectors

    :param x: vector as np.ndarray
    :param y: vector as np.ndarray
    :return: Euclidean distance between two vectors
    """
    # L2 norm of the element-wise difference
    return np.linalg.norm(x - y)
|
||||
|
||||
|
||||
def manhattan(x: np.ndarray, y: np.ndarray) -> float:
    """Calculate the Manhattan distance between two vectors

    :param x: vector as np.ndarray
    :param y: vector as np.ndarray
    :return: Manhattan distance between two vectors
    """
    # Sum of absolute element-wise differences (L1 norm)
    return np.sum(np.abs(x - y))
|
||||
|
||||
|
||||
def cosine(x: np.ndarray, y: np.ndarray) -> float:
    """Calculate the cosine distance between two vectors

    The distance is ``1 - cos(angle between x and y)``, so it lies in [0, 2].
    If either vector has zero norm the ratio is undefined and NumPy
    propagates ``nan`` (same as before this change).

    :param x: vector as np.ndarray
    :param y: vector as np.ndarray
    :return: cosine distance between two vectors
    """
    similarity = x.dot(y) / (np.linalg.norm(x) * np.linalg.norm(y))
    # Floating-point rounding can leave the ratio marginally outside [-1, 1]
    # (e.g. for parallel vectors); clamp it so the distance is never negative.
    return 1 - np.clip(similarity, -1.0, 1.0)
|
||||
|
||||
@@ -5,13 +5,40 @@ from .exceptions import *
|
||||
|
||||
|
||||
class KNNClassifier:
|
||||
"""K-nearest neighbors classifier
|
||||
|
||||
Weighted kNN classifier based on random projection forest.
|
||||
|
||||
Supports different (including custom) smoothing kernels and distance metrics.
|
||||
|
||||
Attributes:
|
||||
features: number of features in each sample
|
||||
forest: an instance of RPTForest
|
||||
classes: an array of labels (integers from 0 to N) corresponding to loaded train points
|
||||
classes_count: number of classes used
|
||||
"""
|
||||
|
||||
def __init__(self, features: int, classes_count: int, trees_count: int, rpt_m: int) -> None:
    """Initializes new classifier

    :param features: number of features in each sample
    :param classes_count: number of classes used
    :param trees_count: number of trees in the forest
    :param rpt_m: maximum number of samples in one leaf of an RP tree
    """
    self.features = features
    # The forest is only constructed here; no training points are passed in yet
    self.forest = RPTForest(features, trees_count, rpt_m)
    # No training labels are known at construction time
    self.classes = None
    self.classes_count = classes_count
|
||||
|
||||
def load(self, points, classes):
|
||||
"""Loads train data, builds a corresponding forest
|
||||
|
||||
:param points: np.ndarray of train samples
|
||||
:param classes: an array of labels (integers from 0 to N) corresponding to loaded train points
|
||||
"""
|
||||
|
||||
if not isinstance(points, np.ndarray):
|
||||
raise InvalidType("points should be represented as np.ndarray")
|
||||
|
||||
@@ -31,6 +58,15 @@ class KNNClassifier:
|
||||
self.forest.load(points)
|
||||
|
||||
def predict(self, point: np.ndarray, distance, kernel, h):
|
||||
"""Predict class of given sample
|
||||
|
||||
:param point: target point as np.ndarray
|
||||
:param distance: distance metric function (from neighbours.distance)
|
||||
:param kernel: smoothing kernel function (from neighbours.kernel)
|
||||
:param h: bandwidth
|
||||
:return: predicted class
|
||||
"""
|
||||
|
||||
nearest_point_indexes = self.forest.get_neighbours(point)
|
||||
|
||||
votes = np.zeros(self.classes_count)
|
||||
|
||||
Reference in New Issue
Block a user