diff --git a/src/probnum/diffeq/_odesolution.py b/src/probnum/diffeq/_odesolution.py index 8e4db2e0c..38f8a69e8 100644 --- a/src/probnum/diffeq/_odesolution.py +++ b/src/probnum/diffeq/_odesolution.py @@ -2,7 +2,8 @@ This object is returned by ODESolver.solve(). -Provides dense output (by being callable), is sliceable, and collects the time-grid as well as the discrete-time solution. +Provides dense output (by being callable), is sliceable, +and collects the time-grid as well as the discrete-time solution. """ from typing import Optional @@ -70,10 +71,14 @@ def sample( Random number generator. t Location / time at which to sample. - If nothing is specified, samples at the ODE-solver grid points are computed. - If it is a float, a sample of the ODE-solution at this time point is computed. - Similarly, if it is a list of floats (or an array), samples at the specified grid-points are returned. - This is not the same as computing i.i.d samples at the respective locations. + If nothing is specified, samples at the ODE-solver + grid points are computed. + If it is a float, a sample of the ODE-solution + at this time point is computed. + Similarly, if it is a list of floats (or an array), + samples at the specified grid-points are returned. + This is not the same as computing i.i.d samples at the respective + locations. size Number of samples. """ diff --git a/src/probnum/diffeq/_odesolver.py b/src/probnum/diffeq/_odesolver.py index 777430dbd..3eb1e9e9b 100644 --- a/src/probnum/diffeq/_odesolver.py +++ b/src/probnum/diffeq/_odesolver.py @@ -1,16 +1,23 @@ """ODE solver interface.""" +import dataclasses from abc import ABC, abstractmethod from collections import abc -from typing import Iterable, Optional, Union +from typing import Iterable, Iterator, Optional, Union import numpy as np from probnum import problems -from probnum.diffeq import callbacks +from probnum.diffeq import callbacks as callback_module # see below from probnum.typing import FloatLike -CallbackType = Union[callbacks.ODESolverCallback, Iterable[callbacks.ODESolverCallback]] +# From above: +# One of the argument to solve() is called 'callback', +# and we do not want to redefine variable namespaces. + +CallbackType = Union[ + callback_module.ODESolverCallback, Iterable[callback_module.ODESolverCallback] +] """Callback interface type.""" @@ -59,7 +66,7 @@ def solution_generator( ): """Generate ODE solver steps.""" - callbacks, time_stopper = self._process_event_inputs(callbacks, stop_at) + callbacks, stopper_state = self._process_event_inputs(callbacks, stop_at) state = self.initialize(ivp) yield state @@ -68,8 +75,8 @@ def solution_generator( # Use state.ivp in case a callback modifies the IVP while state.t < state.ivp.tmax: - if time_stopper is not None: - dt = time_stopper.adjust_dt_to_time_stops(state.t, dt) + if stopper_state is not None: + dt, stopper_state = _adjust_time_step(stopper_state, t=state.t, dt=dt) state, dt = self.perform_full_step(state, dt) @@ -90,7 +97,10 @@ def promote_callback_type(cbs): if callbacks is not None: callbacks = promote_callback_type(callbacks) if stop_at_locations is not None: - time_stopper = _TimeStopper(stop_at_locations) + loc_iter = iter(stop_at_locations) + time_stopper = _TimeStopperState( + locations=loc_iter, next_location=next(loc_iter) + ) else: time_stopper = None return callbacks, time_stopper @@ -123,7 +133,6 @@ def perform_full_step(self, state, initial_dt): else: dt = min(suggested_dt, state.ivp.tmax - state.t) - # This line of code is unnecessary?! self.method_callback(state) return proposed_state, dt @@ -147,7 +156,9 @@ def rvlist_to_odesol(self, times, rvs): """Create an ODESolution object.""" raise NotImplementedError - def postprocess(self, odesol): + # We disable the pylint warning, because subclasses _do_ use 'self' + # but pylint does not seem to realize this. + def postprocess(self, odesol): # pylint: disable="no-self-use" """Process the ODESolution object before returning.""" return odesol @@ -158,25 +169,25 @@ def method_callback(self, state): current guess is accepted, but before storing it. No return. For example: tune hyperparameters (sigma). """ - pass - -class _TimeStopper: - """Make the ODE solver stop at specified time-points.""" - def __init__(self, locations: Iterable): - self._locations = iter(locations) - self._next_location = next(self._locations) +@dataclasses.dataclass +class _TimeStopperState: + locations: Iterator + next_location: FloatLike - def adjust_dt_to_time_stops(self, t, dt): - """Check whether the next time-point is supposed to be stopped at.""" - if t >= self._next_location: - try: - self._next_location = next(self._locations) - except StopIteration: - self._next_location = np.inf +def _adjust_time_step(stopper_state, t, dt): + if t >= stopper_state.next_location: + try: + next_location = next(stopper_state.locations) + except StopIteration: + next_location = np.inf + else: + next_location = stopper_state.next_location - if t + dt > self._next_location: - dt = self._next_location - t - return dt + if t + dt > next_location: + dt = next_location - t + return dt, _TimeStopperState( + locations=stopper_state.locations, next_location=next_location + ) diff --git a/tests/test_diffeq/test_odesolver.py b/tests/test_diffeq/test_odesolver.py deleted file mode 100644 index d1640b77e..000000000 --- a/tests/test_diffeq/test_odesolver.py +++ /dev/null @@ -1,87 +0,0 @@ -import unittest - -import numpy as np -import pytest - -from probnum import diffeq -from probnum.problems.zoo.diffeq import logistic -from probnum.randvars import Constant - - -class MockODESolver(diffeq.ODESolver): - """Euler method as an ODE solver.""" - - def initialize(self, ivp): - return diffeq.ODESolverState( - ivp=ivp, - rv=Constant(ivp.y0), - t=ivp.t0, - error_estimate=np.nan, - reference_state=None, - ) - - def attempt_step(self, state, dt): - t, x = state.t, state.rv.mean - xnew = x + dt * state.ivp.f(t, x) - - # return nan as error estimate to ensure that it is not used - new_state = diffeq.ODESolverState( - ivp=state.ivp, - rv=Constant(xnew), - t=t + dt, - error_estimate=np.nan, - reference_state=xnew, - ) - return new_state - - def rvlist_to_odesol(self, times, rvs): - return diffeq.ODESolution(locations=times, states=rvs) - - -class ODESolverTestCase(unittest.TestCase): - """An ODE Solver has to work with just step() and initialise() provided. - - We implement Euler in MockODESolver to assure this. - """ - - def setUp(self): - step = 0.2 - steprule = diffeq.stepsize.ConstantSteps(step) - euler_order = 1 - self.solver = MockODESolver(steprule=steprule, order=euler_order) - - def test_solve(self): - y0 = np.array([0.3]) - ivp = logistic(t0=0, tmax=4, y0=y0) - odesol = self.solver.solve( - ivp=ivp, - ) # this is the actual part of the test - - # quick check that the result is sensible - self.assertAlmostEqual(odesol.locations[-1], ivp.tmax) - self.assertAlmostEqual(odesol.states[-1].mean[0], 1.0, places=2) - - -class TestTimeStopper: - @pytest.fixture(autouse=True) - def _setup(self): - self.time_stops = [2.0, 3.0, 4.0, 5.0] - self.discrete_events = diffeq._odesolver._TimeStopper(locations=self.time_stops) - - def dummy_perform_step(state, dt, steprule): - return state, dt - - self.dummy_perform_step = dummy_perform_step - - def test_adjust_dt_to_time_stops(self): - # Should interfere dt to 0.1 instead of 5.0, because 2 is in self.time_stops - dt = self.discrete_events.adjust_dt_to_time_stops(t=1.9, dt=5.0) - assert dt == pytest.approx(0.1) - - # Should not interfere dt if there is no proximity to an event - dt = self.discrete_events.adjust_dt_to_time_stops(t=1.0, dt=0.00001) - assert dt == pytest.approx(0.00001) - - -if __name__ == "__main__": - unittest.main()