Source code for traveltimes_prediction.models.algorithms.kalman_wrapper
import copy
import numpy as np
import pandas as pd
from filterpy.common import Q_discrete_white_noise
from filterpy.kalman import KalmanFilter
from ...support_files.helpers import tt_column_search
class KalmanWrapper:
    """Run a two-state Kalman filter over a travel-time series.

    The wrapped ``filterpy.kalman.KalmanFilter`` has state ``[value, rate]``
    (``F = [[1, 1], [0, 1]]``) and observes only ``value`` (``H = [[1, 0]]``).
    For every measurement the filter is first advanced with ``predict()``;
    the predicted value is recorded *before* the measurement is folded in
    with ``update()``, so each output is a one-step-ahead estimate.
    """

    name = 'Kalman'

    def __init__(self, train_data):
        """Store a private copy of ``train_data`` and configure the filter.

        :param train_data: either a ``pandas.DataFrame`` with a
            ``calculation_time_local`` column and a travel-time column
            (located via ``tt_column_search``), or an iterable of scalar
            measurements.
        """
        # Deep copy: make_predictions sorts the DataFrame in place, which
        # must not reorder the caller's object.
        self.train_data = copy.deepcopy(train_data)
        self.f = KalmanFilter(dim_x=2, dim_z=1)
        # NOTE(review): F encodes a time step of 1 while Q is built with
        # dt=0.2 -- confirm the mismatch is an intentional tuning choice.
        self.set_filter_coefficients(
            x=np.array([[100.], [.0]]),
            F=np.array([[1, 1], [.0, 1.]]),
            H=np.array([[1.0, 0]]),
            P=np.array([[1000, 0], [0, 1000]]),
            R=5,
            Q=Q_discrete_white_noise(dim=2, dt=0.2, var=15),
        )

    def set_filter_coefficients(self, x, F, H, P, R, Q):
        """Assign the Kalman filter matrices on ``self.f``.

        :param x: initial state vector (column vector).
        :param F: state transition matrix.
        :param H: measurement function.
        :param P: initial state covariance.
        :param R: measurement noise (scalar or matrix).
        :param Q: process noise covariance.
        """
        self.f.x = x
        self.f.F = F
        self.f.H = H
        self.f.P = P
        self.f.R = R
        self.f.Q = Q

    def _filter(self, measurements):
        """Run predict/update over ``measurements``; return the prior estimates."""
        predictions = []
        for z in measurements:
            self.f.predict()
            # Record the prior (pre-update) estimate of the observed state.
            predictions.append(self.f.x[0][0])
            # NOTE(review): a NaN measurement makes the residual NaN, which
            # then propagates into the state for every later step; filterpy's
            # update() appears to accept None to skip a step -- consider
            # mapping missing values to None if the data can contain gaps.
            self.f.update(z)
        return np.array(predictions)

    def make_predictions(self):
        """Filter the stored measurements and return one-step-ahead predictions.

        The filter state is not reset here: a second call continues from the
        state left by the previous run rather than from the initial
        coefficients.

        :return: for DataFrame input, a ``pandas.DataFrame`` with columns
            ``'Kalman tt prediction'`` and ``'calculation_time_local'``, rows
            ordered by ``calculation_time_local``; otherwise a 1-D
            ``numpy.ndarray`` of predictions.
        """
        if not isinstance(self.train_data, pd.DataFrame):
            return self._filter(self.train_data)

        self.train_data.sort_values(by=['calculation_time_local'], ascending=[True], inplace=True)
        tt_column = tt_column_search(list(self.train_data.columns))[0]
        predictions = self._filter(self.train_data[tt_column].values)
        # Build the result column-wise so each column keeps its own dtype.
        # Stacking predictions and timestamps into one NumPy array (as done
        # previously with np.vstack) forces a single common dtype, which
        # fails for datetime64 timestamps and can stringify the predictions.
        return pd.DataFrame({
            'Kalman tt prediction': predictions,
            'calculation_time_local': self.train_data.calculation_time_local.values,
        })