# Source code for traveltimes_prediction.models.algorithms.knn_wrapper

from sklearn.neighbors import KNeighborsClassifier
import numpy as np

from ..base_model import BaseModel


class KNNWrapper(BaseModel, KNeighborsClassifier):
    """k-nearest-neighbours classifier with dict-based (de)serialisation.

    Mixes the project's ``BaseModel`` into scikit-learn's
    ``KNeighborsClassifier``. ``dump`` produces a dict made only of plain
    Python values (lists, numbers, strings), and ``load`` turns such a dict
    back into an instance.
    """

    # Identifier written to the ``model_type`` field of ``dump()``'s output.
    name = 'KNN'

    def __init__(self, **kwargs):
        # All keyword arguments are forwarded along the MRO (BaseModel first,
        # then KNeighborsClassifier); presumably estimator hyper-parameters --
        # confirm against BaseModel.
        super().__init__(**kwargs)

    @staticmethod
    def load(model):
        """Rebuild a ``KNNWrapper`` from the ``'model'`` dict made by :meth:`dump`.

        :param model: dict with keys ``'_fit_method'``, ``'_fit_X'`` and
            ``'_y'``. When it also has ``'classes_'`` (written by current
            ``dump``), the estimator is refitted so it is ready to predict.
            Without it (older dumps), only the raw attributes are restored,
            exactly as in previous versions.
        :return: a new ``KNNWrapper`` built with default constructor arguments.
        :raises KeyError: if one of the required keys is missing.
        """
        inst = KNNWrapper()
        fit_X = np.array(model['_fit_X'])
        y_encoded = np.array(model['_y'])
        classes = model.get('classes_')

        if classes is None:
            # Legacy dump: the label mapping was never stored, so the original
            # labels cannot be recovered. Restore the raw attributes as before.
            # NOTE: such an instance has no search tree / classes_, so it may
            # not be able to predict on its own.
            inst._fit_method = model['_fit_method']
            inst._y = y_encoded
            inst._fit_X = fit_X
            return inst

        # ``_y`` stores indices into ``classes_``; map them back to the labels.
        if y_encoded.ndim > 1:
            # Multi-output: one label array per output column.
            labels = np.column_stack([
                np.asarray(output_classes)[y_encoded[:, k]]
                for k, output_classes in enumerate(classes)
            ])
        else:
            labels = np.asarray(classes)[y_encoded]

        # Refit with the stored neighbour-search method so that all fitted
        # state (search tree, classes_, effective metric, ...) is rebuilt
        # consistently. sklearn's fit is called explicitly so that a ``fit``
        # defined on BaseModel (if any) is not used for this reconstruction.
        inst.algorithm = model['_fit_method']
        KNeighborsClassifier.fit(inst, fit_X, labels)
        return inst

    def dump(self):
        """Serialise the fitted model into a dict of plain Python values.

        :return: ``{'model': {...}, 'model_type': self.name}``. The inner dict
            holds:

            - ``'_fit_method'``: the neighbour-search method actually in use;
            - ``'_fit_X'``: the training samples;
            - ``'_y'``: their class indices;
            - ``'classes_'``: the labels those indices refer to, when the
              instance has them.
        """
        model = {
            # Record the real method rather than a fixed 'kd_tree' so that
            # load() rebuilds the same kind of index.
            '_fit_method': self._fit_method,
            '_fit_X': self._fit_X.tolist(),
            '_y': self._y.tolist(),
        }

        # Instances restored from a legacy dump have no classes_; keep
        # dumping them as before instead of raising AttributeError.
        classes = getattr(self, 'classes_', None)
        if classes is not None:
            if self._y.ndim > 1:
                # Multi-output: classes_ is a list of per-output label arrays.
                model['classes_'] = [np.asarray(c).tolist() for c in classes]
            else:
                model['classes_'] = np.asarray(classes).tolist()

        return {'model': model, 'model_type': self.name}