add some docstrings

This commit is contained in:
hashlag
2024-01-31 18:54:40 +03:00
parent 00dcff8623
commit 2e23aec1b4
2 changed files with 50 additions and 4 deletions

View File

@@ -2,12 +2,22 @@ import numpy as np
def euclidean(x: np.ndarray, y: np.ndarray):
"""Calculate the Euclidean distance between two vectors
:param x: vector as np.ndarray
:param y: vector as np.ndarray
:return: Euclidean distance between two vectors
"""
return np.linalg.norm(x - y)
def manhattan(x: np.ndarray, y: np.ndarray):
"""Calculate the Manhattan distance between two vectors
:param x: vector as np.ndarray
:param y: vector as np.ndarray
:return: Manhattan distance between two vectors
"""
return np.sum(np.abs(x - y))
def cosine(x: np.ndarray, y: np.ndarray):
return 1 - (x.dot(y) / (np.linalg.norm(x) * np.linalg.norm(y)))

View File

@@ -5,13 +5,40 @@ from .exceptions import *
class KNNClassifier:
"""K-nearest neighbors classifier
Weighted kNN classifier based on random projection forest.
Supports different (including custom) smoothing kernels and distance metrics.
Attributes:
features: number of features in each sample
forest: an instance of RPTForest
classes: an array of labels (integers from 0 to N) corresponding to loaded train points
classes_count: number of classes used
"""
def __init__(self, features, classes_count, trees_count, rpt_m):
"""Initializes new classifier
:param features: number of features in each sample
:param classes_count: number of classes used
:param trees_count: number of trees in the forest
:param rpt_m: maximum number of samples in one leaf of an RP tree
"""
self.features = features
self.forest = RPTForest(features, trees_count, rpt_m)
self.classes = None
self.classes_count = classes_count
def load(self, points, classes):
"""Loads train data, builds a corresponding forest
:param points: np.ndarray of train samples
:param classes: an array of labels (integers from 0 to N) corresponding to loaded train points
"""
if not isinstance(points, np.ndarray):
raise InvalidType("points should be represented as np.ndarray")
@@ -31,6 +58,15 @@ class KNNClassifier:
self.forest.load(points)
def predict(self, point: np.ndarray, distance, kernel, h):
"""Predict class of given sample
:param point: target point as np.ndarray
:param distance: distance metric function (from neighbours.distance)
:param kernel: smoothing kernel function (from neighbours.kernel)
:param h: bandwidth
:return: predicted class
"""
nearest_point_indexes = self.forest.get_neighbours(point)
votes = np.zeros(self.classes_count)