You've already forked neighbours
upd regressor
This commit is contained in:
@@ -5,7 +5,26 @@ from .exceptions import *
|
|||||||
|
|
||||||
|
|
||||||
class KNNRegressor:
|
class KNNRegressor:
|
||||||
|
"""K-nearest neighbors regressor
|
||||||
|
|
||||||
|
Nadaraya-Watson kNN regressor based on random projection forest.
|
||||||
|
|
||||||
|
Supports different (including custom) smoothing kernels and distance metrics.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
features: number of features in each sample
|
||||||
|
forest: an instance of RPTForest
|
||||||
|
targets: an array of target values corresponding to loaded train points
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, features, trees_count, rpt_m):
|
def __init__(self, features, trees_count, rpt_m):
|
||||||
|
"""Initializes new regressor
|
||||||
|
|
||||||
|
:param features: number of features in each sample
|
||||||
|
:param trees_count: number of trees in the forest
|
||||||
|
:param rpt_m: maximum number of samples in one leaf of an RP tree
|
||||||
|
"""
|
||||||
|
|
||||||
self.features = features
|
self.features = features
|
||||||
self.forest = RPTForest(features, trees_count, rpt_m)
|
self.forest = RPTForest(features, trees_count, rpt_m)
|
||||||
self.targets = None
|
self.targets = None
|
||||||
@@ -36,16 +55,25 @@ class KNNRegressor:
|
|||||||
self.forest.load(points)
|
self.forest.load(points)
|
||||||
|
|
||||||
def predict(self, point: np.ndarray, distance, kernel, h):
|
def predict(self, point: np.ndarray, distance, kernel, h):
|
||||||
|
"""Predict target value for given point
|
||||||
|
|
||||||
|
:param point: target point as np.ndarray
|
||||||
|
:param distance: distance metric function (e.g., from neighbours.distance)
|
||||||
|
:param kernel: smoothing kernel function (e.g., from neighbours.kernel)
|
||||||
|
:param h: bandwidth
|
||||||
|
:return: predicted value or numpy.nan if unable to obtain a prediction
|
||||||
|
"""
|
||||||
|
|
||||||
nearest_point_indexes = self.forest.get_neighbours(point)
|
nearest_point_indexes = self.forest.get_neighbours(point)
|
||||||
|
|
||||||
# Nadaraya-Watson estimator
|
# Nadaraya-Watson estimator
|
||||||
|
|
||||||
numerator = float(0)
|
numerator = float(0)
|
||||||
denominator = float(0.0000001)
|
denominator = float(0)
|
||||||
|
|
||||||
for point_ix in nearest_point_indexes:
|
for point_ix in nearest_point_indexes:
|
||||||
weight = kernel(distance(point, self.forest.get_point(point_ix)) / h)
|
weight = kernel(distance(point, self.forest.get_point(point_ix)) / h)
|
||||||
numerator += weight * self.targets[point_ix]
|
numerator += weight * self.targets[point_ix]
|
||||||
denominator += weight
|
denominator += weight
|
||||||
|
|
||||||
return numerator / denominator
|
return np.nan if denominator == 0 else numerator / denominator
|
||||||
|
|||||||
Reference in New Issue
Block a user