Skip to content

Commit

Permalink
Merge pull request #183 from pnkraemer/refactor_odesolver
Browse files Browse the repository at this point in the history
Refactor odesolver
  • Loading branch information
pnkraemer authored Aug 28, 2020
2 parents 0830891 + 2c3a5a7 commit 19c93a1
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 79 deletions.
112 changes: 45 additions & 67 deletions src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
import numpy as np

from probnum.prob import RandomVariable
Expand All @@ -19,7 +18,7 @@ class GaussianIVPFilter(odesolver.ODESolver):
further considerations to, e.g., BVPs.
"""

def __init__(self, ivp, gaussfilt):
def __init__(self, ivp, gaussfilt, with_smoothing):
"""
gaussfilt : gaussianfilter.GaussianFilter object,
e.g. the return value of ivp_to_ukf(), ivp_to_ekf1().
Expand All @@ -34,61 +33,56 @@ def __init__(self, ivp, gaussfilt):
"""
if not issubclass(type(gaussfilt.dynamicmodel), ODEPrior):
raise ValueError("Please initialise a Gaussian filter with an ODEPrior")
self.ivp = ivp
self.gfilt = gaussfilt
self.sigma_squared_global = 0.0
self.sigma_squared_current = 0.0
self.with_smoothing = with_smoothing
super().__init__(ivp)

def initialise(self):
return self.ivp.t0, self.gfilt.initialrandomvariable

def step(self, t, t_new, current_rv, **kwargs):
"""Gaussian IVP filter step as nonlinear Kalman filtering with zero data."""
pred_rv, _ = self.gfilt.predict(t, t_new, current_rv, **kwargs)
zero_data = 0.0
filt_rv, meas_cov, crosscov, meas_mean = self.gfilt.update(
t_new, pred_rv, zero_data, **kwargs
)
errorest, self.sigma_squared_current = self._estimate_error(
filt_rv.mean(), crosscov, meas_cov, meas_mean
)
return filt_rv, errorest

def solve(self, firststep, steprule, **kwargs):
"""
Solve IVP and calibrates uncertainty according
to Proposition 4 in Tronarp et al.
def method_callback(self, time, current_guess, current_error):
"""Update the sigma-squared (ssq) estimate."""
self.sigma_squared_global = (
self.sigma_squared_global
+ (self.sigma_squared_current - self.sigma_squared_global) / self.num_steps
)

Parameters
----------
firststep : float
First step for adaptive step size rule.
steprule : StepRule
Step-size selection rule, e.g. constant steps or adaptive steps.
def postprocess(self, times, rvs):
"""
current_rv = self.gfilt.initialrandomvariable
t = self.ivp.t0
times = [t]
rvs = [current_rv]
step = firststep
ssqest, num_steps = 0.0, 0

while t < self.ivp.tmax:

t_new = t + step
pred_rv, _ = self.gfilt.predict(t, t_new, current_rv, **kwargs)

zero_data = 0.0
filt_rv, meas_cov, crosscov, meas_mean = self.gfilt.update(
t_new, pred_rv, zero_data, **kwargs
)

errorest, ssq = self._estimate_error(
filt_rv.mean(), crosscov, meas_cov, meas_mean
)

if steprule.is_accepted(step, errorest):
times.append(t_new)
rvs.append(filt_rv)
num_steps += 1
ssqest = ssqest + (ssq - ssqest) / num_steps
current_rv = filt_rv
t = t_new

step = self._suggest_step(step, errorest, steprule)
step = min(step, self.ivp.tmax - t)

Rescale covariances with sigma square estimate,
(if specified) smooth the estimate, return ODESolution.
"""
rvs = self._rescale(rvs)
odesol = super().postprocess(times, rvs)
if self.with_smoothing is True:
odesol = self._odesmooth(ode_solution=odesol)
return odesol

def _rescale(self, rvs):
"""Rescales covariances according to estimate sigma squared value."""
rvs = [
RandomVariable(distribution=Normal(rv.mean(), ssqest * rv.cov()))
RandomVariable(
distribution=Normal(rv.mean(), self.sigma_squared_global * rv.cov())
)
for rv in rvs
]
return rvs

return ODESolution(times, rvs, self)

def odesmooth(self, filter_solution, **kwargs):
def _odesmooth(self, ode_solution, **kwargs):
"""
Smooth out the ODE-Filter output.
Expand All @@ -103,13 +97,13 @@ def odesmooth(self, filter_solution, **kwargs):
-------
smoothed_solution: ODESolution
"""
ivp_filter_posterior = filter_solution._kalman_posterior
ivp_filter_posterior = ode_solution._kalman_posterior
ivp_smoother_posterior = self.gfilt.smooth(ivp_filter_posterior, **kwargs)

smoothed_solution = ODESolution(
times=ivp_smoother_posterior.locations,
rvs=ivp_smoother_posterior.state_rvs,
solver=filter_solution._solver,
solver=ode_solution._solver,
)

return smoothed_solution
Expand Down Expand Up @@ -153,22 +147,6 @@ def _rel_and_abs_error(self, abserrors, currmn):
abs_error = abserrors @ weights / np.linalg.norm(weights)
return np.maximum(rel_error, abs_error)

def _suggest_step(self, step, errorest, steprule):
"""
Suggests step according to steprule and warns if
step is extremely small.
Raises
------
RuntimeWarning
If suggested step is smaller than :math:`10^{-15}`.
"""
step = steprule.suggest(step, errorest)
if step < 1e-15:
warnmsg = "Stepsize is num. zero (%.1e)" % step
warnings.warn(message=warnmsg, category=RuntimeWarning)
return step

@property
def prior(self):
return self.gfilt.dynamicmodel
5 changes: 2 additions & 3 deletions src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,9 @@ def probsolve_ivp(
gfilt, firststep, stprl = _create_solver_inputs(
ivp, method, which_prior, tol, step, firststep, precond_step, **kwargs
)
solver = GaussianIVPFilter(ivp, gfilt)
with_smoothing = method[-2] == "s"
solver = GaussianIVPFilter(ivp, gfilt, with_smoothing=with_smoothing)
solution = solver.solve(firststep=firststep, steprule=stprl, **kwargs)
if method in ["eks0", "eks1", "uks"]:
solution = solver.odesmooth(solution, **kwargs)
return solution


Expand Down
24 changes: 20 additions & 4 deletions src/probnum/diffeq/odesolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,17 @@ class ODESolution(FiltSmoothPosterior):
"""

def __init__(self, times, rvs, solver):
self._kalman_posterior = KalmanPosterior(times, rvs, solver.gfilt)

# try-except is a hotfix for now:
# future PR is to move KalmanPosterior-info out of here, into GaussianIVPFilter
try:
self._kalman_posterior = KalmanPosterior(times, rvs, solver.gfilt)
self._t = None
self._y = None
except AttributeError:
self._kalman_posterior = None
self._t = times
self._y = _RandomVariableList(rvs)
self._solver = solver

def _proj_normal_rv(self, rv, coord):
Expand All @@ -84,7 +94,10 @@ def _proj_normal_rv(self, rv, coord):
@property
def t(self):
""":obj:`np.ndarray`: Times of the discrete-time solution"""
return self._kalman_posterior.locations
if self._t: # hotfix
return self._t
else:
return self._kalman_posterior.locations

@property
def y(self):
Expand All @@ -95,8 +108,11 @@ def y(self):
as a list of random variables.
To return means and covariances use ``y.mean()`` and ``y.cov()``.
"""
function_rvs = [self._proj_normal_rv(rv, 0) for rv in self._state_rvs]
return _RandomVariableList(function_rvs)
if self._y: # hotfix
return self._y
else:
function_rvs = [self._proj_normal_rv(rv, 0) for rv in self._state_rvs]
return _RandomVariableList(function_rvs)

@property
def dy(self):
Expand Down
90 changes: 85 additions & 5 deletions src/probnum/diffeq/odesolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,98 @@
"""

from abc import ABC, abstractmethod
import warnings

from probnum.diffeq.odesolution import ODESolution


class ODESolver(ABC):
"""
Interface for ODESolver.
"""

@abstractmethod
def solve(self, ivp, minstep, maxstep, **kwargs):
def __init__(self, ivp):
self.ivp = ivp
self.num_steps = 0

def solve(self, firststep, steprule, **kwargs):
"""
Every ODE solver has a solve() method.
Optional: callback function. Allows e.g. printing variables
at runtime.
Solve an IVP.
Parameters
----------
firststep : float
First step for adaptive step-size rule.
steprule : :class:`StepRule`
Step-size selection rule, e.g. constant steps or adaptive steps.
"""
t, current_rv = self.initialise()
times, rvs = [t], [current_rv]
stepsize = firststep

while t < self.ivp.tmax:

t_new = t + stepsize
proposed_rv, errorest = self.step(t, t_new, current_rv, **kwargs)

if steprule.is_accepted(stepsize, errorest):
self.num_steps += 1
t = t_new
current_rv = proposed_rv
times.append(t)
rvs.append(current_rv)
self.method_callback(
time=t_new, current_guess=proposed_rv, current_error=errorest
)

suggested_stepsize = self._suggest_step(stepsize, errorest, steprule)
stepsize = min(suggested_stepsize, self.ivp.tmax - t)

odesol = self.postprocess(times=times, rvs=rvs)
return odesol

def _suggest_step(self, step, errorest, steprule):
"""
Suggests step according to steprule and warns if step is extremely small.
Raises
------
RuntimeWarning
If suggested step is smaller than :math:`10^{-15}`.
"""
step = steprule.suggest(step, errorest)
if step < 1e-15:
warnmsg = "Stepsize is num. zero (%.1e)" % step
warnings.warn(message=warnmsg, category=RuntimeWarning)
return step

@abstractmethod
def initialise(self):
"""Returns t0 and y0 (for the solver, which might be different to ivp.y0)"""
raise NotImplementedError

@abstractmethod
def step(self, start, stop, current, **kwargs):
"""Every ODE solver needs a step() method that returns a new random variable and an error estimate"""
raise NotImplementedError

def postprocess(self, times, rvs):
"""
Turn list of random variables into an ODE solution object and potentially do more.
Overwrite for instance via
>>> def postprocess(self, times, rvs):
>>> # do something with times and rvs
>>> odesol = super().postprocess(times, rvs)
>>> # do something with odesol
>>> return odesol
"""
return ODESolution(times, rvs, self)

def method_callback(self, time, current_guess, current_error):
"""
Optional callback. Can be overwritten.
Do this as soon as it is clear that the current guess is accepted, but before storing it.
No return. For example: tune hyperparameters (sigma).
"""
pass
49 changes: 49 additions & 0 deletions tests/test_diffeq/test_odesolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest

import numpy as np

from probnum.diffeq import logistic, ODESolver, ConstantSteps
from probnum.prob import RandomVariable, Dirac


class MockODESolver(ODESolver):
"""Euler method as an ODE solver"""

def initialise(self):
return self.ivp.t0, self.ivp.initrv

def step(self, start, stop, current):
h = stop - start
x = current.mean()
xnew = x + h * self.ivp(start, x)
return (
RandomVariable(Dirac(xnew)),
np.nan,
) # return nan as error estimate to ensure that it is not used


class ODESolverTestCase(unittest.TestCase):
"""
An ODE Solver has to work with just step() and initialise() provided.
We implement Euler in MockODESolver to assure this.
"""

def setUp(self):
y0 = RandomVariable(distribution=Dirac(0.3))
ivp = logistic([0, 4], initrv=y0)
self.solver = MockODESolver(ivp)
self.step = 0.2

def test_solve(self):
steprule = ConstantSteps(self.step)
odesol = self.solver.solve(
firststep=self.step, steprule=steprule
) # this is the actual part of the test

# quick check that the result is sensible
self.assertAlmostEqual(odesol.t[-1], self.solver.ivp.tmax)
self.assertAlmostEqual(odesol.y[-1].mean(), 1.0, places=2)


if __name__ == "__main__":
unittest.main()

0 comments on commit 19c93a1

Please sign in to comment.