Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pylint filtsmooth [WIP?] #169

Merged
merged 17 commits into from
Aug 26, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/probnum/filtsmooth/bayesfiltsmooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def predict(self, start, stop, randvar, **kwargs):
)
raise NotImplementedError(errormsg)

def update(self, start, stop, randvar, data, **kwargs):
def update(self, time, randvar, data, **kwargs):
"""
Update step of the Bayesian filter.

Expand Down
24 changes: 24 additions & 0 deletions src/probnum/filtsmooth/gaussfiltsmooth/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Utility functions for Gaussian filtering and smoothing.
"""

from probnum.filtsmooth.statespace import (
ContinuousModel,
DiscreteModel,
LinearSDEModel,
DiscreteGaussianLinearModel,
)


def is_cont_disc(dynamod, measmod):
"""Checks whether the state space model is continuous-discrete."""
dyna_is_cont = issubclass(type(dynamod), ContinuousModel)
nathanaelbosch marked this conversation as resolved.
Show resolved Hide resolved
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_cont and meas_is_disc


def is_disc_disc(dynamod, measmod):
"""Checks whether the state space model is discrete-discrete."""
dyna_is_disc = issubclass(type(dynamod), DiscreteModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_disc and meas_is_disc
36 changes: 11 additions & 25 deletions src/probnum/filtsmooth/gaussfiltsmooth/extendedkalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
"""
import numpy as np

from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import (
GaussFiltSmooth,
linear_discrete_update,
)
from probnum.prob import RandomVariable
from probnum.prob.distributions import Normal
from probnum.filtsmooth.statespace import (
Expand All @@ -14,6 +17,8 @@
DiscreteGaussianModel,
)

from probnum.filtsmooth.gaussfiltsmooth._utils import is_cont_disc, is_disc_disc


class ExtendedKalman(GaussFiltSmooth):
"""
Expand All @@ -23,9 +28,9 @@ class ExtendedKalman(GaussFiltSmooth):
def __new__(cls, dynamod, measmod, initrv, **kwargs):

if cls is ExtendedKalman:
if _cont_disc(dynamod, measmod):
if is_cont_disc(dynamod, measmod):
return _ContDiscExtendedKalman(dynamod, measmod, initrv, **kwargs)
if _disc_disc(dynamod, measmod):
if is_disc_disc(dynamod, measmod):
return _DiscDiscExtendedKalman(dynamod, measmod, initrv, **kwargs)
else:
errmsg = (
Expand All @@ -37,20 +42,6 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs):
return super().__new__(cls)


def _cont_disc(dynamod, measmod):
"""Check whether the state space model is continuous-discrete."""
dyna_is_cont = issubclass(type(dynamod), ContinuousModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_cont and meas_is_disc


def _disc_disc(dynamod, measmod):
"""Check whether the state space model is discrete-discrete."""
dyna_is_disc = issubclass(type(dynamod), DiscreteModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_disc and meas_is_disc


class _ContDiscExtendedKalman(ExtendedKalman):
"""
Continuous-discrete extended Kalman filtering and smoothing.
Expand All @@ -59,12 +50,12 @@ class _ContDiscExtendedKalman(ExtendedKalman):
def __init__(self, dynamod, measmod, initrv, **kwargs):
if not issubclass(type(dynamod), LinearSDEModel):
raise ValueError(
"This implementation of ContDiscExtendedKalmanFilter "
"This implementation of ContDiscExtendedKalman "
"requires a linear dynamic model."
)
if not issubclass(type(measmod), DiscreteGaussianModel):
raise ValueError(
"ContDiscExtendedKalmanFilter requires a Gaussian measurement model."
"ContDiscExtendedKalman requires a Gaussian measurement model."
)
if "cke_nsteps" in kwargs.keys():
self.cke_nsteps = kwargs["cke_nsteps"]
Expand Down Expand Up @@ -121,9 +112,4 @@ def _discrete_extkalman_update(time, randvar, data, measmod, **kwargs):
jacob = measmod.jacobian(time, mpred, **kwargs)
meascov = measmod.diffusionmatrix(time, **kwargs)
meanest = measmod.dynamics(time, mpred, **kwargs)
covest = jacob @ cpred @ jacob.T + meascov
ccest = cpred @ jacob.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
updated_rv = RandomVariable(distribution=Normal(mean, cov))
return updated_rv, covest, ccest, meanest
return linear_discrete_update(meanest, cpred, data, meascov, jacob, mpred)
12 changes: 10 additions & 2 deletions src/probnum/filtsmooth/gaussfiltsmooth/gaussfiltsmooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ def smooth_step(self, unsmoothed_rv, pred_rv, smoothed_rv, crosscov):
if np.isscalar(predmean) and np.isscalar(predcov):
predmean = predmean * np.ones(1)
predcov = predcov * np.eye(1)
res = currmean - predmean
newmean = initmean + crosscov @ np.linalg.solve(predcov, res)
newmean = initmean + crosscov @ np.linalg.solve(predcov, currmean - predmean)
firstsolve = crosscov @ np.linalg.solve(predcov, currcov - predcov)
secondsolve = crosscov @ np.linalg.solve(predcov, firstsolve.T)
newcov = initcov + secondsolve.T
Expand All @@ -177,3 +176,12 @@ def predict(self, start, stop, randvar, **kwargs):
@abstractmethod
def update(self, time, randvar, data, **kwargs):
raise NotImplementedError


def linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred):
"""Kalman update, potentially after linearization."""
covest = measmat @ cpred @ measmat.T + meascov
ccest = cpred @ measmat.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
return RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest
35 changes: 11 additions & 24 deletions src/probnum/filtsmooth/gaussfiltsmooth/kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
"""

import numpy as np
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import (
GaussFiltSmooth,
linear_discrete_update,
)
from probnum.prob import RandomVariable, Normal
from probnum.filtsmooth.statespace import (
ContinuousModel,
Expand All @@ -13,6 +16,8 @@
DiscreteGaussianLinearModel,
)

from probnum.filtsmooth.gaussfiltsmooth._utils import is_cont_disc, is_disc_disc


class Kalman(GaussFiltSmooth):
"""
Expand All @@ -29,9 +34,9 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs):
discrete-discrete Kalman object is created.
"""
if cls is Kalman:
if _cont_disc(dynamod, measmod):
if is_cont_disc(dynamod, measmod):
return _ContDiscKalman(dynamod, measmod, initrv, **kwargs)
if _disc_disc(dynamod, measmod):
if is_disc_disc(dynamod, measmod):
return _DiscDiscKalman(dynamod, measmod, initrv)
else:
errmsg = (
Expand All @@ -43,20 +48,6 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs):
return super().__new__(cls)


def _cont_disc(dynamod, measmod):
"""Checks whether the state space model is continuous-discrete."""
dyna_is_cont = issubclass(type(dynamod), ContinuousModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_cont and meas_is_disc


def _disc_disc(dynamod, measmod):
"""Checks whether the state space model is discrete-discrete."""
dyna_is_disc = issubclass(type(dynamod), DiscreteModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_disc and meas_is_disc


class _ContDiscKalman(Kalman):
"""
Provides predict() and update() methods for Kalman filtering and
Expand All @@ -69,11 +60,11 @@ def __init__(self, dynamod, measmod, initrv, **kwargs):
"""
if not issubclass(type(dynamod), LinearSDEModel):
raise ValueError(
"ContinuosDiscreteKalman requires " "a linear dynamic model."
"ContinuousDiscreteKalman requires a linear dynamic model."
)
if not issubclass(type(measmod), DiscreteGaussianLinearModel):
raise ValueError(
"DiscreteDiscreteKalman requires " "a linear measurement model."
"ContinuousDiscreteKalman requires a linear measurement model."
)
if "cke_nsteps" in kwargs.keys():
self.cke_nsteps = kwargs["cke_nsteps"]
Expand Down Expand Up @@ -137,8 +128,4 @@ def _discrete_kalman_update(time, randvar, data, measurementmodel, **kwargs):
measmat = measurementmodel.dynamicsmatrix(time, **kwargs)
meascov = measurementmodel.diffusionmatrix(time, **kwargs)
meanest = measmat @ mpred
covest = measmat @ cpred @ measmat.T + meascov
ccest = cpred @ measmat.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
return (RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest)
return linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred)
11 changes: 5 additions & 6 deletions src/probnum/filtsmooth/gaussfiltsmooth/unscentedkalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import numpy as np

from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import (
GaussFiltSmooth,
linear_discrete_update,
)
from probnum.prob import RandomVariable, Normal
from probnum.filtsmooth.gaussfiltsmooth.unscentedtransform import UnscentedTransform
from probnum.filtsmooth.statespace import (
Expand Down Expand Up @@ -158,11 +161,7 @@ def _update_discrete_linear(time, randvar, data, measmod, **kwargs):
measmat = measmod.dynamicsmatrix(time, **kwargs)
meascov = measmod.diffusionmatrix(time, **kwargs)
meanest = measmat @ mpred
covest = measmat @ cpred @ measmat.T + meascov
ccest = cpred @ measmat.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
return RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest
return linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred)


def _update_discrete_nonlinear(time, randvar, data, measmod, ut, **kwargs):
Expand Down
13 changes: 0 additions & 13 deletions src/probnum/filtsmooth/statespace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
from .continuous import *
from .discrete import *
from .statespace import *

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"ContinuousModel",
"LinearSDEModel",
"LTISDEModel",
"DiscreteModel",
"DiscreteGaussianModel",
"DiscreteGaussianLinearModel",
"DiscreteGaussianLTIModel",
"generate_cd",
"generate_dd",
]