Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pylint filtsmooth [WIP?] #169

Merged
merged 17 commits into from
Aug 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/probnum/filtsmooth/bayesfiltsmooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def predict(self, start, stop, randvar, **kwargs):
)
raise NotImplementedError(errormsg)

def update(self, start, stop, randvar, data, **kwargs):
def update(self, time, randvar, data, **kwargs):
"""
Update step of the Bayesian filter.

Expand Down
22 changes: 22 additions & 0 deletions src/probnum/filtsmooth/gaussfiltsmooth/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Utility functions for Gaussian filtering and smoothing.
"""

from probnum.filtsmooth.statespace import (
ContinuousModel,
DiscreteModel,
)


def is_cont_disc(dynamod, measmod):
"""Checks whether the state space model is continuous-discrete."""
dyna_is_cont = isinstance(dynamod, ContinuousModel)
meas_is_disc = isinstance(measmod, DiscreteModel)
return dyna_is_cont and meas_is_disc


def is_disc_disc(dynamod, measmod):
"""Checks whether the state space model is discrete-discrete."""
dyna_is_disc = isinstance(dynamod, DiscreteModel)
meas_is_disc = isinstance(measmod, DiscreteModel)
return dyna_is_disc and meas_is_disc
38 changes: 11 additions & 27 deletions src/probnum/filtsmooth/gaussfiltsmooth/extendedkalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
"""
import numpy as np

from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import (
GaussFiltSmooth,
linear_discrete_update,
)
from probnum.prob import RandomVariable
from probnum.prob.distributions import Normal
from probnum.filtsmooth.statespace import (
ContinuousModel,
DiscreteModel,
LinearSDEModel,
DiscreteGaussianModel,
)

from probnum.filtsmooth.gaussfiltsmooth._utils import is_cont_disc, is_disc_disc


class ExtendedKalman(GaussFiltSmooth):
"""
Expand All @@ -23,9 +26,9 @@ class ExtendedKalman(GaussFiltSmooth):
def __new__(cls, dynamod, measmod, initrv, **kwargs):

if cls is ExtendedKalman:
if _cont_disc(dynamod, measmod):
if is_cont_disc(dynamod, measmod):
return _ContDiscExtendedKalman(dynamod, measmod, initrv, **kwargs)
if _disc_disc(dynamod, measmod):
if is_disc_disc(dynamod, measmod):
return _DiscDiscExtendedKalman(dynamod, measmod, initrv, **kwargs)
else:
errmsg = (
Expand All @@ -37,20 +40,6 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs):
return super().__new__(cls)


def _cont_disc(dynamod, measmod):
"""Check whether the state space model is continuous-discrete."""
dyna_is_cont = issubclass(type(dynamod), ContinuousModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_cont and meas_is_disc


def _disc_disc(dynamod, measmod):
"""Check whether the state space model is discrete-discrete."""
dyna_is_disc = issubclass(type(dynamod), DiscreteModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_disc and meas_is_disc


class _ContDiscExtendedKalman(ExtendedKalman):
"""
Continuous-discrete extended Kalman filtering and smoothing.
Expand All @@ -59,12 +48,12 @@ class _ContDiscExtendedKalman(ExtendedKalman):
def __init__(self, dynamod, measmod, initrv, **kwargs):
if not issubclass(type(dynamod), LinearSDEModel):
raise ValueError(
"This implementation of ContDiscExtendedKalmanFilter "
"This implementation of ContDiscExtendedKalman "
"requires a linear dynamic model."
)
if not issubclass(type(measmod), DiscreteGaussianModel):
raise ValueError(
"ContDiscExtendedKalmanFilter requires a Gaussian measurement model."
"ContDiscExtendedKalman requires a Gaussian measurement model."
)
if "cke_nsteps" in kwargs.keys():
self.cke_nsteps = kwargs["cke_nsteps"]
Expand Down Expand Up @@ -121,9 +110,4 @@ def _discrete_extkalman_update(time, randvar, data, measmod, **kwargs):
jacob = measmod.jacobian(time, mpred, **kwargs)
meascov = measmod.diffusionmatrix(time, **kwargs)
meanest = measmod.dynamics(time, mpred, **kwargs)
covest = jacob @ cpred @ jacob.T + meascov
ccest = cpred @ jacob.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
updated_rv = RandomVariable(distribution=Normal(mean, cov))
return updated_rv, covest, ccest, meanest
return linear_discrete_update(meanest, cpred, data, meascov, jacob, mpred)
12 changes: 10 additions & 2 deletions src/probnum/filtsmooth/gaussfiltsmooth/gaussfiltsmooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ def smooth_step(self, unsmoothed_rv, pred_rv, smoothed_rv, crosscov):
if np.isscalar(predmean) and np.isscalar(predcov):
predmean = predmean * np.ones(1)
predcov = predcov * np.eye(1)
res = currmean - predmean
newmean = initmean + crosscov @ np.linalg.solve(predcov, res)
newmean = initmean + crosscov @ np.linalg.solve(predcov, currmean - predmean)
firstsolve = crosscov @ np.linalg.solve(predcov, currcov - predcov)
secondsolve = crosscov @ np.linalg.solve(predcov, firstsolve.T)
newcov = initcov + secondsolve.T
Expand All @@ -177,3 +176,12 @@ def predict(self, start, stop, randvar, **kwargs):
@abstractmethod
def update(self, time, randvar, data, **kwargs):
raise NotImplementedError


def linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred):
"""Kalman update, potentially after linearization."""
covest = measmat @ cpred @ measmat.T + meascov
ccest = cpred @ measmat.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
return RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest
37 changes: 11 additions & 26 deletions src/probnum/filtsmooth/gaussfiltsmooth/kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
"""

import numpy as np
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import (
GaussFiltSmooth,
linear_discrete_update,
)
from probnum.prob import RandomVariable, Normal
from probnum.filtsmooth.statespace import (
ContinuousModel,
DiscreteModel,
LinearSDEModel,
DiscreteGaussianLinearModel,
)

from probnum.filtsmooth.gaussfiltsmooth._utils import is_cont_disc, is_disc_disc


class Kalman(GaussFiltSmooth):
"""
Expand All @@ -29,9 +32,9 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs):
discrete-discrete Kalman object is created.
"""
if cls is Kalman:
if _cont_disc(dynamod, measmod):
if is_cont_disc(dynamod, measmod):
return _ContDiscKalman(dynamod, measmod, initrv, **kwargs)
if _disc_disc(dynamod, measmod):
if is_disc_disc(dynamod, measmod):
return _DiscDiscKalman(dynamod, measmod, initrv)
else:
errmsg = (
Expand All @@ -43,20 +46,6 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs):
return super().__new__(cls)


def _cont_disc(dynamod, measmod):
"""Checks whether the state space model is continuous-discrete."""
dyna_is_cont = issubclass(type(dynamod), ContinuousModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_cont and meas_is_disc


def _disc_disc(dynamod, measmod):
"""Checks whether the state space model is discrete-discrete."""
dyna_is_disc = issubclass(type(dynamod), DiscreteModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_disc and meas_is_disc


class _ContDiscKalman(Kalman):
"""
Provides predict() and update() methods for Kalman filtering and
Expand All @@ -69,11 +58,11 @@ def __init__(self, dynamod, measmod, initrv, **kwargs):
"""
if not issubclass(type(dynamod), LinearSDEModel):
raise ValueError(
"ContinuosDiscreteKalman requires " "a linear dynamic model."
"ContinuousDiscreteKalman requires a linear dynamic model."
)
if not issubclass(type(measmod), DiscreteGaussianLinearModel):
raise ValueError(
"DiscreteDiscreteKalman requires " "a linear measurement model."
"ContinuousDiscreteKalman requires a linear measurement model."
)
if "cke_nsteps" in kwargs.keys():
self.cke_nsteps = kwargs["cke_nsteps"]
Expand Down Expand Up @@ -137,8 +126,4 @@ def _discrete_kalman_update(time, randvar, data, measurementmodel, **kwargs):
measmat = measurementmodel.dynamicsmatrix(time, **kwargs)
meascov = measurementmodel.diffusionmatrix(time, **kwargs)
meanest = measmat @ mpred
covest = measmat @ cpred @ measmat.T + meascov
ccest = cpred @ measmat.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
return (RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest)
return linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred)
11 changes: 5 additions & 6 deletions src/probnum/filtsmooth/gaussfiltsmooth/unscentedkalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import numpy as np

from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import (
GaussFiltSmooth,
linear_discrete_update,
)
from probnum.prob import RandomVariable, Normal
from probnum.filtsmooth.gaussfiltsmooth.unscentedtransform import UnscentedTransform
from probnum.filtsmooth.statespace import (
Expand Down Expand Up @@ -158,11 +161,7 @@ def _update_discrete_linear(time, randvar, data, measmod, **kwargs):
measmat = measmod.dynamicsmatrix(time, **kwargs)
meascov = measmod.diffusionmatrix(time, **kwargs)
meanest = measmat @ mpred
covest = measmat @ cpred @ measmat.T + meascov
ccest = cpred @ measmat.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
return RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest
return linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred)


def _update_discrete_nonlinear(time, randvar, data, measmod, ut, **kwargs):
Expand Down
13 changes: 0 additions & 13 deletions src/probnum/filtsmooth/statespace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
from .continuous import *
from .discrete import *
from .statespace import *

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"ContinuousModel",
"LinearSDEModel",
"LTISDEModel",
"DiscreteModel",
"DiscreteGaussianModel",
"DiscreteGaussianLinearModel",
"DiscreteGaussianLTIModel",
"generate_cd",
"generate_dd",
]
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ ignore_errors = true
commands =
# pylint src
pylint src/probnum/diffeq --disable="protected-access,abstract-class-instantiated,too-many-locals,too-few-public-methods,too-many-arguments,unused-argument,missing-module-docstring,missing-function-docstring"
pylint src/probnum/filtsmooth --disable="duplicate-code,protected-access,no-self-use,too-many-locals,arguments-differ,too-many-arguments,unused-argument,missing-module-docstring,missing-function-docstring"
pylint src/probnum/filtsmooth --disable="duplicate-code,no-self-use,too-many-locals,too-many-arguments,unused-argument,missing-module-docstring,missing-function-docstring"
pylint src/probnum/linalg --disable="attribute-defined-outside-init,too-many-statements,too-many-instance-attributes,too-complex,protected-access,too-many-lines,no-self-use,too-many-locals,redefined-builtin,arguments-differ,abstract-method,too-many-arguments,too-many-branches,duplicate-code,unused-argument,fixme,missing-module-docstring"
pylint src/probnum/prob --disable="too-many-instance-attributes,broad-except,arguments-differ,abstract-method,too-many-arguments,protected-access,duplicate-code,unused-argument,fixme,missing-module-docstring,missing-function-docstring,raise-missing-from"
pylint src/probnum/quad --disable="attribute-defined-outside-init,too-few-public-methods,redefined-builtin,arguments-differ,unused-argument,missing-module-docstring"
Expand Down