upd regressor

This commit is contained in:
hashlag
2024-02-04 17:51:50 +03:00
parent 3fd2d954f4
commit f9f724fa7d

View File

@@ -5,7 +5,26 @@ from .exceptions import *
class KNNRegressor: class KNNRegressor:
"""K-nearest neighbors regressor
Nadaraya-Watson kNN regressor based on random projection forest.
Supports different (including custom) smoothing kernels and distance metrics.
Attributes:
features: number of features in each sample
forest: an instance of RPTForest
targets: an array of target values corresponding to loaded train points
"""
def __init__(self, features, trees_count, rpt_m):
    """Initializes new regressor

    Builds the underlying random projection forest; no training data is
    attached at this point.

    :param features: number of features in each sample
    :param trees_count: number of trees in the forest
    :param rpt_m: maximum number of samples in one leaf of an RP tree
    """
    self.features = features
    self.forest = RPTForest(features, trees_count, rpt_m)
    # Target values are not known yet; this stays None until they are
    # assigned elsewhere in the class (not visible in this view).
    self.targets = None
@@ -36,16 +55,25 @@ class KNNRegressor:
self.forest.load(points) self.forest.load(points)
def predict(self, point: np.ndarray, distance, kernel, h): def predict(self, point: np.ndarray, distance, kernel, h):
"""Predict target value for given point
:param point: target point as np.ndarray
:param distance: distance metric function (e.g., from neighbours.distance)
:param kernel: smoothing kernel function (e.g., from neighbours.kernel)
:param h: bandwidth
:return: predicted value or numpy.nan if unable to obtain a prediction
"""
nearest_point_indexes = self.forest.get_neighbours(point) nearest_point_indexes = self.forest.get_neighbours(point)
# Nadaraya-Watson estimator # Nadaraya-Watson estimator
numerator = float(0) numerator = float(0)
denominator = float(0.0000001) denominator = float(0)
for point_ix in nearest_point_indexes: for point_ix in nearest_point_indexes:
weight = kernel(distance(point, self.forest.get_point(point_ix)) / h) weight = kernel(distance(point, self.forest.get_point(point_ix)) / h)
numerator += weight * self.targets[point_ix] numerator += weight * self.targets[point_ix]
denominator += weight denominator += weight
return numerator / denominator return np.nan if denominator == 0 else numerator / denominator