From 19d96c9176775dae0b23bb3b861f0d20b3a01543 Mon Sep 17 00:00:00 2001 From: hashlag <90853356+hashlag@users.noreply.github.com> Date: Mon, 29 Jan 2024 00:12:26 +0300 Subject: [PATCH] Add KNNClassifier, exceptions --- neighbours/__init__.py | 2 +- neighbours/exceptions.py | 6 ++++++ neighbours/knn_classifier.py | 33 +++++++++++++++++++++++++++++++++ neighbours/rp_neighbours.py | 10 +++++----- 4 files changed, 45 insertions(+), 6 deletions(-) create mode 100644 neighbours/exceptions.py create mode 100644 neighbours/knn_classifier.py diff --git a/neighbours/__init__.py b/neighbours/__init__.py index d8c2701..4a2c27f 100644 --- a/neighbours/__init__.py +++ b/neighbours/__init__.py @@ -1 +1 @@ -from .rp_neighbours import * +from .knn_classifier import * diff --git a/neighbours/exceptions.py b/neighbours/exceptions.py new file mode 100644 index 0000000..e1e1d73 --- /dev/null +++ b/neighbours/exceptions.py @@ -0,0 +1,6 @@ +class InvalidDimensionError(Exception): + pass + + +class InvalidType(Exception): + pass diff --git a/neighbours/knn_classifier.py b/neighbours/knn_classifier.py new file mode 100644 index 0000000..79eec24 --- /dev/null +++ b/neighbours/knn_classifier.py @@ -0,0 +1,33 @@ +import numpy as np + +from .rp_neighbours import * +from .exceptions import * + + +class KNNClassifier: + def __init__(self, features, trees_count, rpt_m): + self.features = features + self.forest = RPTForest(features, trees_count, rpt_m) + self.classes = None + + def load(self, points, classes): + if not isinstance(points, np.ndarray): + raise InvalidType("points should be represented as np.ndarray") + + if not isinstance(classes, np.ndarray) and not isinstance(classes, list): + raise InvalidType("classes should be represented as np.ndarray or list") + + self.classes = classes + + if points.ndim != 2: + raise InvalidDimensionError("points array should be two-dimensional") + + if points.shape[1] != self.features: + raise InvalidDimensionError( + "invalid number of features in sample (expected {}, got {})".format(self.features, points.shape[1]) + ) + + self.forest.load(points) + + def predict(self, point: np.ndarray): + pass diff --git a/neighbours/rp_neighbours.py b/neighbours/rp_neighbours.py index 75717dc..e92a945 100644 --- a/neighbours/rp_neighbours.py +++ b/neighbours/rp_neighbours.py @@ -85,20 +85,20 @@ class RPTForest: def get_point(self, ix): """Returns stored point by index - :param ix: point index + :param ix: target point index :return: np.ndarray representing the point """ return self.points[ix] - def load(self, points: list) -> None: + def load(self, points: np.ndarray) -> None: """Loads a list of points and builds the corresponding forest - :param points: list of points + :param points: numpy.ndarray of points """ self.trees.clear() - self.points = np.array(points) + self.points = points ixs = list(range(len(self.points))) @@ -110,7 +110,7 @@ class RPTForest: :param root: root of a random projection tree to search :param point: target point - :return: set of points located in the same region of space + :return: set of indexes of points located in the same region of space """ while isinstance(root, SplittingNode):