You've already forked neighbours
Add KNNClassifier, exceptions
This commit is contained in:
@@ -1 +1 @@
|
||||
from .rp_neighbours import *
|
||||
from .knn_classifier import *
|
||||
|
||||
6
neighbours/exceptions.py
Normal file
6
neighbours/exceptions.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class InvalidDimensionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidType(Exception):
|
||||
pass
|
||||
33
neighbours/knn_classifier.py
Normal file
33
neighbours/knn_classifier.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import numpy as np
|
||||
|
||||
from .rp_neighbours import *
|
||||
from .exceptions import *
|
||||
|
||||
|
||||
class KNNClassifier:
|
||||
def __init__(self, features, trees_count, rpt_m):
|
||||
self.features = features
|
||||
self.forest = RPTForest(features, trees_count, rpt_m)
|
||||
self.classes = None
|
||||
|
||||
def load(self, points, classes):
|
||||
if not isinstance(points, np.ndarray):
|
||||
raise InvalidType("points should be represented as np.ndarray")
|
||||
|
||||
if not isinstance(classes, np.ndarray) and not isinstance(classes, list):
|
||||
raise InvalidType("classes should be represented as np.ndarray or list")
|
||||
|
||||
self.classes = classes
|
||||
|
||||
if points.ndim != 2:
|
||||
raise InvalidDimensionError("points array should be two-dimensional")
|
||||
|
||||
if points.shape[1] != self.features:
|
||||
raise InvalidDimensionError(
|
||||
"invalid number of features in sample (expected {}, got {})".format(self.features, points.shape[1])
|
||||
)
|
||||
|
||||
self.forest.load(points)
|
||||
|
||||
def predict(self, point: np.ndarray):
|
||||
pass
|
||||
@@ -85,20 +85,20 @@ class RPTForest:
|
||||
def get_point(self, ix):
|
||||
"""Returns stored point by index
|
||||
|
||||
:param ix: point index
|
||||
:param ix: target point index
|
||||
:return: np.ndarray representing the point
|
||||
"""
|
||||
|
||||
return self.points[ix]
|
||||
|
||||
def load(self, points: list) -> None:
|
||||
def load(self, points: np.ndarray) -> None:
|
||||
"""Loads a list of points and builds the corresponding forest
|
||||
|
||||
:param points: list of points
|
||||
:param points: numpy.ndarray of points
|
||||
"""
|
||||
|
||||
self.trees.clear()
|
||||
self.points = np.array(points)
|
||||
self.points = points
|
||||
|
||||
ixs = list(range(len(self.points)))
|
||||
|
||||
@@ -110,7 +110,7 @@ class RPTForest:
|
||||
|
||||
:param root: root of a random projection tree to search
|
||||
:param point: target point
|
||||
:return: set of points located in the same region of space
|
||||
:return: set of indexes of points located in the same region of space
|
||||
"""
|
||||
|
||||
while isinstance(root, SplittingNode):
|
||||
|
||||
Reference in New Issue
Block a user