Skip to content

Commit

Permalink
Merge pull request #169 from pnkraemer/pylint_filtsmooth
Browse files Browse the repository at this point in the history
Pylint filtsmooth [WIP?]
  • Loading branch information
pnkraemer authored Aug 26, 2020
2 parents 4527cec + 59ff59c commit 146037c
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 76 deletions.
2 changes: 1 addition & 1 deletion src/probnum/filtsmooth/bayesfiltsmooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def predict(self, start, stop, randvar, **kwargs):
)
raise NotImplementedError(errormsg)

def update(self, start, stop, randvar, data, **kwargs):
def update(self, time, randvar, data, **kwargs):
"""
Update step of the Bayesian filter.
Expand Down
22 changes: 22 additions & 0 deletions src/probnum/filtsmooth/gaussfiltsmooth/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Utility functions for Gaussian filtering and smoothing.
"""

from probnum.filtsmooth.statespace import (
ContinuousModel,
DiscreteModel,
)


def is_cont_disc(dynamod, measmod):
"""Checks whether the state space model is continuous-discrete."""
dyna_is_cont = isinstance(dynamod, ContinuousModel)
meas_is_disc = isinstance(measmod, DiscreteModel)
return dyna_is_cont and meas_is_disc


def is_disc_disc(dynamod, measmod):
"""Checks whether the state space model is discrete-discrete."""
dyna_is_disc = isinstance(dynamod, DiscreteModel)
meas_is_disc = isinstance(measmod, DiscreteModel)
return dyna_is_disc and meas_is_disc
38 changes: 11 additions & 27 deletions src/probnum/filtsmooth/gaussfiltsmooth/extendedkalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
"""
import numpy as np

from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import (
GaussFiltSmooth,
linear_discrete_update,
)
from probnum.prob import RandomVariable
from probnum.prob.distributions import Normal
from probnum.filtsmooth.statespace import (
ContinuousModel,
DiscreteModel,
LinearSDEModel,
DiscreteGaussianModel,
)

from probnum.filtsmooth.gaussfiltsmooth._utils import is_cont_disc, is_disc_disc


class ExtendedKalman(GaussFiltSmooth):
"""
Expand All @@ -23,9 +26,9 @@ class ExtendedKalman(GaussFiltSmooth):
def __new__(cls, dynamod, measmod, initrv, **kwargs):

if cls is ExtendedKalman:
if _cont_disc(dynamod, measmod):
if is_cont_disc(dynamod, measmod):
return _ContDiscExtendedKalman(dynamod, measmod, initrv, **kwargs)
if _disc_disc(dynamod, measmod):
if is_disc_disc(dynamod, measmod):
return _DiscDiscExtendedKalman(dynamod, measmod, initrv, **kwargs)
else:
errmsg = (
Expand All @@ -37,20 +40,6 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs):
return super().__new__(cls)


def _cont_disc(dynamod, measmod):
"""Check whether the state space model is continuous-discrete."""
dyna_is_cont = issubclass(type(dynamod), ContinuousModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_cont and meas_is_disc


def _disc_disc(dynamod, measmod):
"""Check whether the state space model is discrete-discrete."""
dyna_is_disc = issubclass(type(dynamod), DiscreteModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_disc and meas_is_disc


class _ContDiscExtendedKalman(ExtendedKalman):
"""
Continuous-discrete extended Kalman filtering and smoothing.
Expand All @@ -59,12 +48,12 @@ class _ContDiscExtendedKalman(ExtendedKalman):
def __init__(self, dynamod, measmod, initrv, **kwargs):
if not issubclass(type(dynamod), LinearSDEModel):
raise ValueError(
"This implementation of ContDiscExtendedKalmanFilter "
"This implementation of ContDiscExtendedKalman "
"requires a linear dynamic model."
)
if not issubclass(type(measmod), DiscreteGaussianModel):
raise ValueError(
"ContDiscExtendedKalmanFilter requires a Gaussian measurement model."
"ContDiscExtendedKalman requires a Gaussian measurement model."
)
if "cke_nsteps" in kwargs.keys():
self.cke_nsteps = kwargs["cke_nsteps"]
Expand Down Expand Up @@ -121,9 +110,4 @@ def _discrete_extkalman_update(time, randvar, data, measmod, **kwargs):
jacob = measmod.jacobian(time, mpred, **kwargs)
meascov = measmod.diffusionmatrix(time, **kwargs)
meanest = measmod.dynamics(time, mpred, **kwargs)
covest = jacob @ cpred @ jacob.T + meascov
ccest = cpred @ jacob.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
updated_rv = RandomVariable(distribution=Normal(mean, cov))
return updated_rv, covest, ccest, meanest
return linear_discrete_update(meanest, cpred, data, meascov, jacob, mpred)
12 changes: 10 additions & 2 deletions src/probnum/filtsmooth/gaussfiltsmooth/gaussfiltsmooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ def smooth_step(self, unsmoothed_rv, pred_rv, smoothed_rv, crosscov):
if np.isscalar(predmean) and np.isscalar(predcov):
predmean = predmean * np.ones(1)
predcov = predcov * np.eye(1)
res = currmean - predmean
newmean = initmean + crosscov @ np.linalg.solve(predcov, res)
newmean = initmean + crosscov @ np.linalg.solve(predcov, currmean - predmean)
firstsolve = crosscov @ np.linalg.solve(predcov, currcov - predcov)
secondsolve = crosscov @ np.linalg.solve(predcov, firstsolve.T)
newcov = initcov + secondsolve.T
Expand All @@ -177,3 +176,12 @@ def predict(self, start, stop, randvar, **kwargs):
@abstractmethod
def update(self, time, randvar, data, **kwargs):
raise NotImplementedError


def linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred):
"""Kalman update, potentially after linearization."""
covest = measmat @ cpred @ measmat.T + meascov
ccest = cpred @ measmat.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
return RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest
37 changes: 11 additions & 26 deletions src/probnum/filtsmooth/gaussfiltsmooth/kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
"""

import numpy as np
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import (
GaussFiltSmooth,
linear_discrete_update,
)
from probnum.prob import RandomVariable, Normal
from probnum.filtsmooth.statespace import (
ContinuousModel,
DiscreteModel,
LinearSDEModel,
DiscreteGaussianLinearModel,
)

from probnum.filtsmooth.gaussfiltsmooth._utils import is_cont_disc, is_disc_disc


class Kalman(GaussFiltSmooth):
"""
Expand All @@ -29,9 +32,9 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs):
discrete-discrete Kalman object is created.
"""
if cls is Kalman:
if _cont_disc(dynamod, measmod):
if is_cont_disc(dynamod, measmod):
return _ContDiscKalman(dynamod, measmod, initrv, **kwargs)
if _disc_disc(dynamod, measmod):
if is_disc_disc(dynamod, measmod):
return _DiscDiscKalman(dynamod, measmod, initrv)
else:
errmsg = (
Expand All @@ -43,20 +46,6 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs):
return super().__new__(cls)


def _cont_disc(dynamod, measmod):
"""Checks whether the state space model is continuous-discrete."""
dyna_is_cont = issubclass(type(dynamod), ContinuousModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_cont and meas_is_disc


def _disc_disc(dynamod, measmod):
"""Checks whether the state space model is discrete-discrete."""
dyna_is_disc = issubclass(type(dynamod), DiscreteModel)
meas_is_disc = issubclass(type(measmod), DiscreteModel)
return dyna_is_disc and meas_is_disc


class _ContDiscKalman(Kalman):
"""
Provides predict() and update() methods for Kalman filtering and
Expand All @@ -69,11 +58,11 @@ def __init__(self, dynamod, measmod, initrv, **kwargs):
"""
if not issubclass(type(dynamod), LinearSDEModel):
raise ValueError(
"ContinuosDiscreteKalman requires " "a linear dynamic model."
"ContinuousDiscreteKalman requires a linear dynamic model."
)
if not issubclass(type(measmod), DiscreteGaussianLinearModel):
raise ValueError(
"DiscreteDiscreteKalman requires " "a linear measurement model."
"ContinuousDiscreteKalman requires a linear measurement model."
)
if "cke_nsteps" in kwargs.keys():
self.cke_nsteps = kwargs["cke_nsteps"]
Expand Down Expand Up @@ -137,8 +126,4 @@ def _discrete_kalman_update(time, randvar, data, measurementmodel, **kwargs):
measmat = measurementmodel.dynamicsmatrix(time, **kwargs)
meascov = measurementmodel.diffusionmatrix(time, **kwargs)
meanest = measmat @ mpred
covest = measmat @ cpred @ measmat.T + meascov
ccest = cpred @ measmat.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
return (RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest)
return linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred)
11 changes: 5 additions & 6 deletions src/probnum/filtsmooth/gaussfiltsmooth/unscentedkalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import numpy as np

from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import (
GaussFiltSmooth,
linear_discrete_update,
)
from probnum.prob import RandomVariable, Normal
from probnum.filtsmooth.gaussfiltsmooth.unscentedtransform import UnscentedTransform
from probnum.filtsmooth.statespace import (
Expand Down Expand Up @@ -158,11 +161,7 @@ def _update_discrete_linear(time, randvar, data, measmod, **kwargs):
measmat = measmod.dynamicsmatrix(time, **kwargs)
meascov = measmod.diffusionmatrix(time, **kwargs)
meanest = measmat @ mpred
covest = measmat @ cpred @ measmat.T + meascov
ccest = cpred @ measmat.T
mean = mpred + ccest @ np.linalg.solve(covest, data - meanest)
cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T)
return RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest
return linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred)


def _update_discrete_nonlinear(time, randvar, data, measmod, ut, **kwargs):
Expand Down
13 changes: 0 additions & 13 deletions src/probnum/filtsmooth/statespace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
from .continuous import *
from .discrete import *
from .statespace import *

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"ContinuousModel",
"LinearSDEModel",
"LTISDEModel",
"DiscreteModel",
"DiscreteGaussianModel",
"DiscreteGaussianLinearModel",
"DiscreteGaussianLTIModel",
"generate_cd",
"generate_dd",
]
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ ignore_errors = true
commands =
# pylint src
pylint src/probnum/diffeq --disable="protected-access,abstract-class-instantiated,too-many-locals,too-few-public-methods,too-many-arguments,unused-argument,missing-module-docstring,missing-function-docstring"
pylint src/probnum/filtsmooth --disable="duplicate-code,protected-access,no-self-use,too-many-locals,arguments-differ,too-many-arguments,unused-argument,missing-module-docstring,missing-function-docstring"
pylint src/probnum/filtsmooth --disable="duplicate-code,no-self-use,too-many-locals,too-many-arguments,unused-argument,missing-module-docstring,missing-function-docstring"
pylint src/probnum/linalg --disable="attribute-defined-outside-init,too-many-statements,too-many-instance-attributes,too-complex,protected-access,too-many-lines,no-self-use,too-many-locals,redefined-builtin,arguments-differ,abstract-method,too-many-arguments,too-many-branches,duplicate-code,unused-argument,fixme,missing-module-docstring"
pylint src/probnum/prob --disable="too-many-instance-attributes,broad-except,arguments-differ,abstract-method,too-many-arguments,protected-access,duplicate-code,unused-argument,fixme,missing-module-docstring,missing-function-docstring,raise-missing-from"
pylint src/probnum/quad --disable="attribute-defined-outside-init,too-few-public-methods,redefined-builtin,arguments-differ,unused-argument,missing-module-docstring"
Expand Down

0 comments on commit 146037c

Please sign in to comment.