You've already forked neighbours
add some docstrings
This commit is contained in:
@@ -2,12 +2,22 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
def euclidean(x: np.ndarray, y: np.ndarray):
    """Calculate the Euclidean distance between two vectors.

    :param x: vector as np.ndarray
    :param y: vector as np.ndarray
    :return: Euclidean distance between two vectors
    """
    # With default arguments np.linalg.norm gives the 2-norm of a 1-D
    # array, so applying it to the difference yields the distance.
    return np.linalg.norm(x - y)
|
||||||
|
|
||||||
|
|
||||||
def manhattan(x: np.ndarray, y: np.ndarray):
    """Calculate the Manhattan distance between two vectors.

    :param x: vector as np.ndarray
    :param y: vector as np.ndarray
    :return: Manhattan distance between two vectors
    """
    # Sum of the absolute coordinate-wise differences.
    return np.sum(np.abs(x - y))
|
||||||
|
|
||||||
|
|
||||||
def cosine(x: np.ndarray, y: np.ndarray):
    """Calculate the cosine distance between two vectors.

    The result is one minus the cosine similarity, i.e. one minus the dot
    product of the vectors divided by the product of their norms.

    :param x: vector as np.ndarray
    :param y: vector as np.ndarray
    :return: cosine distance between two vectors
    """
    # NOTE: if either vector has zero norm the denominator is zero; no
    # guard is applied here, so callers should avoid all-zero inputs.
    return 1 - (x.dot(y) / (np.linalg.norm(x) * np.linalg.norm(y)))
|
|
||||||
|
|||||||
@@ -5,13 +5,40 @@ from .exceptions import *
|
|||||||
|
|
||||||
|
|
||||||
class KNNClassifier:
|
class KNNClassifier:
|
||||||
|
"""K-nearest neighbors classifier
|
||||||
|
|
||||||
|
Weighted kNN classifier based on random projection forest.
|
||||||
|
|
||||||
|
Supports different (including custom) smoothing kernels and distance metrics.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
features: number of features in each sample
|
||||||
|
forest: an instance of RPTForest
|
||||||
|
classes: an array of labels (integers from 0 to N) corresponding to loaded train points
|
||||||
|
classes_count: number of classes used
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, features, classes_count, trees_count, rpt_m):
    """Initialize a new classifier.

    :param features: number of features in each sample
    :param classes_count: number of classes used
    :param trees_count: number of trees in the forest
    :param rpt_m: maximum number of samples in one leaf of an RP tree
    """
    self.features = features
    self.forest = RPTForest(features, trees_count, rpt_m)
    # Starts unset; presumably filled in when training data is loaded
    # via load() -- confirm against that method.
    self.classes = None
    self.classes_count = classes_count
|
||||||
|
|
||||||
def load(self, points, classes):
|
def load(self, points, classes):
|
||||||
|
"""Loads train data, builds a corresponding forest
|
||||||
|
|
||||||
|
:param points: np.ndarray of train samples
|
||||||
|
:param classes: an array of labels (integers from 0 to N) corresponding to loaded train points
|
||||||
|
"""
|
||||||
|
|
||||||
if not isinstance(points, np.ndarray):
|
if not isinstance(points, np.ndarray):
|
||||||
raise InvalidType("points should be represented as np.ndarray")
|
raise InvalidType("points should be represented as np.ndarray")
|
||||||
|
|
||||||
@@ -31,6 +58,15 @@ class KNNClassifier:
|
|||||||
self.forest.load(points)
|
self.forest.load(points)
|
||||||
|
|
||||||
def predict(self, point: np.ndarray, distance, kernel, h):
|
def predict(self, point: np.ndarray, distance, kernel, h):
|
||||||
|
"""Predict class of given sample
|
||||||
|
|
||||||
|
:param point: target point as np.ndarray
|
||||||
|
:param distance: distance metric function (from neighbours.distance)
|
||||||
|
:param kernel: smoothing kernel function (from neighbours.kernel)
|
||||||
|
:param h: bandwidth
|
||||||
|
:return: predicted class
|
||||||
|
"""
|
||||||
|
|
||||||
nearest_point_indexes = self.forest.get_neighbours(point)
|
nearest_point_indexes = self.forest.get_neighbours(point)
|
||||||
|
|
||||||
votes = np.zeros(self.classes_count)
|
votes = np.zeros(self.classes_count)
|
||||||
|
|||||||
Reference in New Issue
Block a user