diff --git a/src/probnum/filtsmooth/bayesfiltsmooth.py b/src/probnum/filtsmooth/bayesfiltsmooth.py index 61bb8f17f..81dbebc89 100644 --- a/src/probnum/filtsmooth/bayesfiltsmooth.py +++ b/src/probnum/filtsmooth/bayesfiltsmooth.py @@ -51,7 +51,7 @@ def predict(self, start, stop, randvar, **kwargs): ) raise NotImplementedError(errormsg) - def update(self, start, stop, randvar, data, **kwargs): + def update(self, time, randvar, data, **kwargs): """ Update step of the Bayesian filter. diff --git a/src/probnum/filtsmooth/gaussfiltsmooth/_utils.py b/src/probnum/filtsmooth/gaussfiltsmooth/_utils.py new file mode 100644 index 000000000..ec56a4130 --- /dev/null +++ b/src/probnum/filtsmooth/gaussfiltsmooth/_utils.py @@ -0,0 +1,22 @@ +""" +Utility functions for Gaussian filtering and smoothing. +""" + +from probnum.filtsmooth.statespace import ( + ContinuousModel, + DiscreteModel, +) + + +def is_cont_disc(dynamod, measmod): + """Checks whether the state space model is continuous-discrete.""" + dyna_is_cont = isinstance(dynamod, ContinuousModel) + meas_is_disc = isinstance(measmod, DiscreteModel) + return dyna_is_cont and meas_is_disc + + +def is_disc_disc(dynamod, measmod): + """Checks whether the state space model is discrete-discrete.""" + dyna_is_disc = isinstance(dynamod, DiscreteModel) + meas_is_disc = isinstance(measmod, DiscreteModel) + return dyna_is_disc and meas_is_disc diff --git a/src/probnum/filtsmooth/gaussfiltsmooth/extendedkalman.py b/src/probnum/filtsmooth/gaussfiltsmooth/extendedkalman.py index e552e8246..d8daae6ec 100644 --- a/src/probnum/filtsmooth/gaussfiltsmooth/extendedkalman.py +++ b/src/probnum/filtsmooth/gaussfiltsmooth/extendedkalman.py @@ -4,16 +4,19 @@ """ import numpy as np -from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth +from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import ( + GaussFiltSmooth, + linear_discrete_update, +) from probnum.prob import RandomVariable from probnum.prob.distributions import Normal from probnum.filtsmooth.statespace import ( - ContinuousModel, - DiscreteModel, LinearSDEModel, DiscreteGaussianModel, ) +from probnum.filtsmooth.gaussfiltsmooth._utils import is_cont_disc, is_disc_disc + class ExtendedKalman(GaussFiltSmooth): """ @@ -23,9 +26,9 @@ class ExtendedKalman(GaussFiltSmooth): def __new__(cls, dynamod, measmod, initrv, **kwargs): if cls is ExtendedKalman: - if _cont_disc(dynamod, measmod): + if is_cont_disc(dynamod, measmod): return _ContDiscExtendedKalman(dynamod, measmod, initrv, **kwargs) - if _disc_disc(dynamod, measmod): + if is_disc_disc(dynamod, measmod): return _DiscDiscExtendedKalman(dynamod, measmod, initrv, **kwargs) else: errmsg = ( @@ -37,20 +40,6 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs): return super().__new__(cls) -def _cont_disc(dynamod, measmod): - """Check whether the state space model is continuous-discrete.""" - dyna_is_cont = issubclass(type(dynamod), ContinuousModel) - meas_is_disc = issubclass(type(measmod), DiscreteModel) - return dyna_is_cont and meas_is_disc - - -def _disc_disc(dynamod, measmod): - """Check whether the state space model is discrete-discrete.""" - dyna_is_disc = issubclass(type(dynamod), DiscreteModel) - meas_is_disc = issubclass(type(measmod), DiscreteModel) - return dyna_is_disc and meas_is_disc - - class _ContDiscExtendedKalman(ExtendedKalman): """ Continuous-discrete extended Kalman filtering and smoothing. @@ -59,12 +48,12 @@ class _ContDiscExtendedKalman(ExtendedKalman): def __init__(self, dynamod, measmod, initrv, **kwargs): if not issubclass(type(dynamod), LinearSDEModel): raise ValueError( - "This implementation of ContDiscExtendedKalmanFilter " + "This implementation of ContDiscExtendedKalman " "requires a linear dynamic model." ) if not issubclass(type(measmod), DiscreteGaussianModel): raise ValueError( - "ContDiscExtendedKalmanFilter requires a Gaussian measurement model." + "ContDiscExtendedKalman requires a Gaussian measurement model." ) if "cke_nsteps" in kwargs.keys(): self.cke_nsteps = kwargs["cke_nsteps"] @@ -121,9 +110,4 @@ def _discrete_extkalman_update(time, randvar, data, measmod, **kwargs): jacob = measmod.jacobian(time, mpred, **kwargs) meascov = measmod.diffusionmatrix(time, **kwargs) meanest = measmod.dynamics(time, mpred, **kwargs) - covest = jacob @ cpred @ jacob.T + meascov - ccest = cpred @ jacob.T - mean = mpred + ccest @ np.linalg.solve(covest, data - meanest) - cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T) - updated_rv = RandomVariable(distribution=Normal(mean, cov)) - return updated_rv, covest, ccest, meanest + return linear_discrete_update(meanest, cpred, data, meascov, jacob, mpred) diff --git a/src/probnum/filtsmooth/gaussfiltsmooth/gaussfiltsmooth.py b/src/probnum/filtsmooth/gaussfiltsmooth/gaussfiltsmooth.py index beeaea3e1..0579a948d 100644 --- a/src/probnum/filtsmooth/gaussfiltsmooth/gaussfiltsmooth.py +++ b/src/probnum/filtsmooth/gaussfiltsmooth/gaussfiltsmooth.py @@ -163,8 +163,7 @@ def smooth_step(self, unsmoothed_rv, pred_rv, smoothed_rv, crosscov): if np.isscalar(predmean) and np.isscalar(predcov): predmean = predmean * np.ones(1) predcov = predcov * np.eye(1) - res = currmean - predmean - newmean = initmean + crosscov @ np.linalg.solve(predcov, res) + newmean = initmean + crosscov @ np.linalg.solve(predcov, currmean - predmean) firstsolve = crosscov @ np.linalg.solve(predcov, currcov - predcov) secondsolve = crosscov @ np.linalg.solve(predcov, firstsolve.T) newcov = initcov + secondsolve.T @@ -177,3 +176,12 @@ def predict(self, start, stop, randvar, **kwargs): @abstractmethod def update(self, time, randvar, data, **kwargs): raise NotImplementedError + + +def linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred): + """Kalman update, potentially after linearization.""" + covest = measmat @ cpred @ measmat.T + meascov + ccest = cpred @ measmat.T + mean = mpred + ccest @ np.linalg.solve(covest, data - meanest) + cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T) + return RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest diff --git a/src/probnum/filtsmooth/gaussfiltsmooth/kalman.py b/src/probnum/filtsmooth/gaussfiltsmooth/kalman.py index fc7eade19..465b88dd4 100644 --- a/src/probnum/filtsmooth/gaussfiltsmooth/kalman.py +++ b/src/probnum/filtsmooth/gaussfiltsmooth/kalman.py @@ -4,15 +4,18 @@ """ import numpy as np -from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth +from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import ( + GaussFiltSmooth, + linear_discrete_update, +) from probnum.prob import RandomVariable, Normal from probnum.filtsmooth.statespace import ( - ContinuousModel, - DiscreteModel, LinearSDEModel, DiscreteGaussianLinearModel, ) +from probnum.filtsmooth.gaussfiltsmooth._utils import is_cont_disc, is_disc_disc + class Kalman(GaussFiltSmooth): """ @@ -29,9 +32,9 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs): discrete-discrete Kalman object is created. """ if cls is Kalman: - if _cont_disc(dynamod, measmod): + if is_cont_disc(dynamod, measmod): return _ContDiscKalman(dynamod, measmod, initrv, **kwargs) - if _disc_disc(dynamod, measmod): + if is_disc_disc(dynamod, measmod): return _DiscDiscKalman(dynamod, measmod, initrv) else: errmsg = ( @@ -43,20 +46,6 @@ def __new__(cls, dynamod, measmod, initrv, **kwargs): return super().__new__(cls) -def _cont_disc(dynamod, measmod): - """Checks whether the state space model is continuous-discrete.""" - dyna_is_cont = issubclass(type(dynamod), ContinuousModel) - meas_is_disc = issubclass(type(measmod), DiscreteModel) - return dyna_is_cont and meas_is_disc - - -def _disc_disc(dynamod, measmod): - """Checks whether the state space model is discrete-discrete.""" - dyna_is_disc = issubclass(type(dynamod), DiscreteModel) - meas_is_disc = issubclass(type(measmod), DiscreteModel) - return dyna_is_disc and meas_is_disc - - class _ContDiscKalman(Kalman): """ Provides predict() and update() methods for Kalman filtering and @@ -69,11 +58,11 @@ def __init__(self, dynamod, measmod, initrv, **kwargs): """ if not issubclass(type(dynamod), LinearSDEModel): raise ValueError( - "ContinuosDiscreteKalman requires " "a linear dynamic model." + "ContinuousDiscreteKalman requires a linear dynamic model." ) if not issubclass(type(measmod), DiscreteGaussianLinearModel): raise ValueError( - "DiscreteDiscreteKalman requires " "a linear measurement model." + "ContinuousDiscreteKalman requires a linear measurement model." ) if "cke_nsteps" in kwargs.keys(): self.cke_nsteps = kwargs["cke_nsteps"] @@ -137,8 +126,4 @@ def _discrete_kalman_update(time, randvar, data, measurementmodel, **kwargs): measmat = measurementmodel.dynamicsmatrix(time, **kwargs) meascov = measurementmodel.diffusionmatrix(time, **kwargs) meanest = measmat @ mpred - covest = measmat @ cpred @ measmat.T + meascov - ccest = cpred @ measmat.T - mean = mpred + ccest @ np.linalg.solve(covest, data - meanest) - cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T) - return (RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest) + return linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred) diff --git a/src/probnum/filtsmooth/gaussfiltsmooth/unscentedkalman.py b/src/probnum/filtsmooth/gaussfiltsmooth/unscentedkalman.py index 13574f003..c7364d8bc 100644 --- a/src/probnum/filtsmooth/gaussfiltsmooth/unscentedkalman.py +++ b/src/probnum/filtsmooth/gaussfiltsmooth/unscentedkalman.py @@ -7,7 +7,10 @@ import numpy as np -from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import GaussFiltSmooth +from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import ( + GaussFiltSmooth, + linear_discrete_update, +) from probnum.prob import RandomVariable, Normal from probnum.filtsmooth.gaussfiltsmooth.unscentedtransform import UnscentedTransform from probnum.filtsmooth.statespace import ( @@ -158,11 +161,7 @@ def _update_discrete_linear(time, randvar, data, measmod, **kwargs): measmat = measmod.dynamicsmatrix(time, **kwargs) meascov = measmod.diffusionmatrix(time, **kwargs) meanest = measmat @ mpred - covest = measmat @ cpred @ measmat.T + meascov - ccest = cpred @ measmat.T - mean = mpred + ccest @ np.linalg.solve(covest, data - meanest) - cov = cpred - ccest @ np.linalg.solve(covest.T, ccest.T) - return RandomVariable(distribution=Normal(mean, cov)), covest, ccest, meanest + return linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred) def _update_discrete_nonlinear(time, randvar, data, measmod, ut, **kwargs): diff --git a/src/probnum/filtsmooth/statespace/__init__.py b/src/probnum/filtsmooth/statespace/__init__.py index 38a653c0a..883ee3df0 100644 --- a/src/probnum/filtsmooth/statespace/__init__.py +++ b/src/probnum/filtsmooth/statespace/__init__.py @@ -1,16 +1,3 @@ from .continuous import * from .discrete import * from .statespace import * - -# Public classes and functions. Order is reflected in documentation. -__all__ = [ - "ContinuousModel", - "LinearSDEModel", - "LTISDEModel", - "DiscreteModel", - "DiscreteGaussianModel", - "DiscreteGaussianLinearModel", - "DiscreteGaussianLTIModel", - "generate_cd", - "generate_dd", -] diff --git a/tox.ini b/tox.ini index 8a1c294d0..f47275c62 100644 --- a/tox.ini +++ b/tox.ini @@ -57,7 +57,7 @@ ignore_errors = true commands = # pylint src pylint src/probnum/diffeq --disable="protected-access,abstract-class-instantiated,too-many-locals,too-few-public-methods,too-many-arguments,unused-argument,missing-module-docstring,missing-function-docstring" - pylint src/probnum/filtsmooth --disable="duplicate-code,protected-access,no-self-use,too-many-locals,arguments-differ,too-many-arguments,unused-argument,missing-module-docstring,missing-function-docstring" + pylint src/probnum/filtsmooth --disable="duplicate-code,no-self-use,too-many-locals,too-many-arguments,unused-argument,missing-module-docstring,missing-function-docstring" pylint src/probnum/linalg --disable="attribute-defined-outside-init,too-many-statements,too-many-instance-attributes,too-complex,protected-access,too-many-lines,no-self-use,too-many-locals,redefined-builtin,arguments-differ,abstract-method,too-many-arguments,too-many-branches,duplicate-code,unused-argument,fixme,missing-module-docstring" pylint src/probnum/prob --disable="too-many-instance-attributes,broad-except,arguments-differ,abstract-method,too-many-arguments,protected-access,duplicate-code,unused-argument,fixme,missing-module-docstring,missing-function-docstring,raise-missing-from" pylint src/probnum/quad --disable="attribute-defined-outside-init,too-few-public-methods,redefined-builtin,arguments-differ,unused-argument,missing-module-docstring"