You've already forked neighbours
Add KNNClassifier, exceptions
This commit is contained in:
@@ -1 +1 @@
|
|||||||
from .rp_neighbours import *
|
from .knn_classifier import *
|
||||||
|
|||||||
6
neighbours/exceptions.py
Normal file
6
neighbours/exceptions.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
class InvalidDimensionError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidType(Exception):
|
||||||
|
pass
|
||||||
33
neighbours/knn_classifier.py
Normal file
33
neighbours/knn_classifier.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .rp_neighbours import *
|
||||||
|
from .exceptions import *
|
||||||
|
|
||||||
|
|
||||||
|
class KNNClassifier:
|
||||||
|
def __init__(self, features, trees_count, rpt_m):
|
||||||
|
self.features = features
|
||||||
|
self.forest = RPTForest(features, trees_count, rpt_m)
|
||||||
|
self.classes = None
|
||||||
|
|
||||||
|
def load(self, points, classes):
|
||||||
|
if not isinstance(points, np.ndarray):
|
||||||
|
raise InvalidType("points should be represented as np.ndarray")
|
||||||
|
|
||||||
|
if not isinstance(classes, np.ndarray) and not isinstance(classes, list):
|
||||||
|
raise InvalidType("classes should be represented as np.ndarray or list")
|
||||||
|
|
||||||
|
self.classes = classes
|
||||||
|
|
||||||
|
if points.ndim != 2:
|
||||||
|
raise InvalidDimensionError("points array should be two-dimensional")
|
||||||
|
|
||||||
|
if points.shape[1] != self.features:
|
||||||
|
raise InvalidDimensionError(
|
||||||
|
"invalid number of features in sample (expected {}, got {})".format(self.features, points.shape[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
self.forest.load(points)
|
||||||
|
|
||||||
|
def predict(self, point: np.ndarray):
|
||||||
|
pass
|
||||||
@@ -85,20 +85,20 @@ class RPTForest:
|
|||||||
def get_point(self, ix):
|
def get_point(self, ix):
|
||||||
"""Returns stored point by index
|
"""Returns stored point by index
|
||||||
|
|
||||||
:param ix: point index
|
:param ix: target point index
|
||||||
:return: np.ndarray representing the point
|
:return: np.ndarray representing the point
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.points[ix]
|
return self.points[ix]
|
||||||
|
|
||||||
def load(self, points: list) -> None:
|
def load(self, points: np.ndarray) -> None:
|
||||||
"""Loads a list of points and builds the corresponding forest
|
"""Loads a list of points and builds the corresponding forest
|
||||||
|
|
||||||
:param points: list of points
|
:param points: numpy.ndarray of points
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.trees.clear()
|
self.trees.clear()
|
||||||
self.points = np.array(points)
|
self.points = points
|
||||||
|
|
||||||
ixs = list(range(len(self.points)))
|
ixs = list(range(len(self.points)))
|
||||||
|
|
||||||
@@ -110,7 +110,7 @@ class RPTForest:
|
|||||||
|
|
||||||
:param root: root of a random projection tree to search
|
:param root: root of a random projection tree to search
|
||||||
:param point: target point
|
:param point: target point
|
||||||
:return: set of points located in the same region of space
|
:return: set of indexes of points located in the same region of space
|
||||||
"""
|
"""
|
||||||
|
|
||||||
while isinstance(root, SplittingNode):
|
while isinstance(root, SplittingNode):
|
||||||
|
|||||||
Reference in New Issue
Block a user