Source code for traveltimes_prediction.models.algorithms.knn_wrapper
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
from ..base_model import BaseModel
class KNNWrapper(BaseModel, KNeighborsClassifier):
    """k-nearest-neighbours classifier behind the project's ``BaseModel`` interface.

    Combines ``BaseModel`` with scikit-learn's ``KNeighborsClassifier`` and adds
    :meth:`dump` / :meth:`load` to convert a fitted model to and from a dict
    made only of Python lists and scalars.
    """

    name = 'KNN'

    # Hyperparameters written by dump() and restored by load(); without them a
    # reloaded model silently falls back to KNeighborsClassifier's defaults.
    _PERSISTED_PARAMS = ('n_neighbors', 'weights', 'algorithm', 'leaf_size', 'p', 'metric')

    def __init__(self, **kwargs):
        # Keyword arguments are forwarded unchanged along the MRO.
        super().__init__(**kwargs)

    @staticmethod
    def load(model):
        """Rebuild a ``KNNWrapper`` from the ``'model'`` dict produced by :meth:`dump`.

        :param model: mapping with the keys ``'_fit_method'``, ``'_fit_X'`` and
            ``'_y'``. Dumps written by the current :meth:`dump` additionally hold
            ``'classes_'``, ``'outputs_2d_'`` and ``'params'``; these are
            optional so that older dumps keep loading.
        :return: a new ``KNNWrapper`` instance.
        """
        inst = KNNWrapper()

        params = model.get('params') or {}
        # Only whitelisted names are applied, so a dump cannot set arbitrary attributes.
        for key in KNNWrapper._PERSISTED_PARAMS:
            if key in params:
                setattr(inst, key, params[key])

        if 'classes_' not in model:
            # Legacy dump without class labels: restore the raw fitted attributes
            # exactly as before (no search tree and no classes_ are rebuilt).
            inst._fit_method = model['_fit_method']
            inst._y = np.array(model['_y'])
            inst._fit_X = np.array(model['_fit_X'])
            return inst

        fit_X = np.array(model['_fit_X'])
        labels = KNNWrapper._decode_labels(
            np.array(model['_y'], dtype=int),
            model['classes_'],
            model.get('outputs_2d_', False),
        )
        # Refit with scikit-learn's own fit (bypassing any BaseModel override):
        # _fit_X is already the matrix the estimator was trained on, and fitting
        # rebuilds every fitted attribute (classes_, search tree, metric state)
        # instead of only the three private fields the old format restored.
        KNeighborsClassifier.fit(inst, fit_X, labels)
        return inst

    @staticmethod
    def _decode_labels(encoded_y, classes, outputs_2d):
        """Map integer-encoded targets (indices into ``classes_``) back to the original labels."""
        if outputs_2d:
            # One label array per output column.
            return np.column_stack(
                [np.asarray(labels)[encoded_y[:, col]] for col, labels in enumerate(classes)]
            )
        return np.asarray(classes)[encoded_y]

    def dump(self):
        """Serialise the fitted model into a dict of plain Python values.

        :return: ``{'model': {...}, 'model_type': self.name}``. ``'model'`` keeps
            the ``'_fit_method'``, ``'_fit_X'`` and ``'_y'`` keys of the original
            format and adds ``'classes_'``, ``'outputs_2d_'`` and ``'params'`` so
            that :meth:`load` can rebuild a usable estimator.
        :raises AttributeError: if the fitted attributes (``classes_``,
            ``_fit_X``, ``_y``, ``_fit_method``) are missing, i.e. the model has
            not been fitted.
        """
        outputs_2d = bool(getattr(self, 'outputs_2d_', False))
        if outputs_2d:
            classes = [np.asarray(column).tolist() for column in self.classes_]
        else:
            classes = np.asarray(self.classes_).tolist()

        params = {}
        for key in self._PERSISTED_PARAMS:
            if not hasattr(self, key):
                continue
            value = getattr(self, key)
            if isinstance(value, np.generic):
                # NumPy scalars -> plain Python scalars.
                value = value.item()
            # Callables (custom metrics/weights) have no list/scalar form; skip them.
            if value is None or isinstance(value, (bool, int, float, str)):
                params[key] = value

        return {
            'model': {
                # Record the search structure actually used instead of assuming 'kd_tree'.
                '_fit_method': self._fit_method,
                '_fit_X': self._fit_X.tolist(),
                '_y': self._y.tolist(),
                'classes_': classes,
                'outputs_2d_': outputs_2d,
                'params': params,
            },
            'model_type': self.name,
        }