Add KNNClassifier, exceptions

This commit is contained in:
hashlag
2024-01-29 00:12:26 +03:00
parent 4bd1db9191
commit 19d96c9176
4 changed files with 45 additions and 6 deletions

View File

@@ -1 +1 @@
from .rp_neighbours import *
from .knn_classifier import *

6
neighbours/exceptions.py Normal file
View File

@@ -0,0 +1,6 @@
class InvalidDimensionError(Exception):
pass
class InvalidType(Exception):
pass

View File

@@ -0,0 +1,33 @@
import numpy as np
from .rp_neighbours import *
from .exceptions import *
class KNNClassifier:
def __init__(self, features, trees_count, rpt_m):
self.features = features
self.forest = RPTForest(features, trees_count, rpt_m)
self.classes = None
def load(self, points, classes):
if not isinstance(points, np.ndarray):
raise InvalidType("points should be represented as np.ndarray")
if not isinstance(classes, np.ndarray) and not isinstance(classes, list):
raise InvalidType("classes should be represented as np.ndarray or list")
self.classes = classes
if points.ndim != 2:
raise InvalidDimensionError("points array should be two-dimensional")
if points.shape[1] != self.features:
raise InvalidDimensionError(
"invalid number of features in sample (expected {}, got {})".format(self.features, points.shape[1])
)
self.forest.load(points)
def predict(self, point: np.ndarray):
pass

View File

@@ -85,20 +85,20 @@ class RPTForest:
def get_point(self, ix):
"""Returns stored point by index
:param ix: point index
:param ix: target point index
:return: np.ndarray representing the point
"""
return self.points[ix]
def load(self, points: list) -> None:
def load(self, points: np.ndarray) -> None:
"""Loads a list of points and builds the corresponding forest
:param points: list of points
:param points: numpy.ndarray of points
"""
self.trees.clear()
self.points = np.array(points)
self.points = points
ixs = list(range(len(self.points)))
@@ -110,7 +110,7 @@ class RPTForest:
:param root: root of a random projection tree to search
:param point: target point
:return: set of points located in the same region of space
:return: set of indexes of points located in the same region of space
"""
while isinstance(root, SplittingNode):