Skip to content

Commit

Permalink
Fix pylint errors for _odesolution.py and _odesolver.py (#609)
Browse files Browse the repository at this point in the history
* renamed callback module import

* made time_stopper a function

* disable no self use for postprocess8)

* long lines with break now in odesolution

* removed superfluous tests

* deleted redundant test suite

* changed time_stopper arguments

* another attempt at the time stopper

* deleted dead code
  • Loading branch information
pnkraemer authored Jan 19, 2022
1 parent c5db35a commit 5023c41
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 118 deletions.
15 changes: 10 additions & 5 deletions src/probnum/diffeq/_odesolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
This object is returned by ODESolver.solve().
Provides dense output (by being callable), is sliceable, and collects the time-grid as well as the discrete-time solution.
Provides dense output (by being callable), is sliceable,
and collects the time-grid as well as the discrete-time solution.
"""

from typing import Optional
Expand Down Expand Up @@ -70,10 +71,14 @@ def sample(
Random number generator.
t
Location / time at which to sample.
If nothing is specified, samples at the ODE-solver grid points are computed.
If it is a float, a sample of the ODE-solution at this time point is computed.
Similarly, if it is a list of floats (or an array), samples at the specified grid-points are returned.
This is not the same as computing i.i.d samples at the respective locations.
If nothing is specified, samples at the ODE-solver
grid points are computed.
If it is a float, a sample of the ODE-solution
at this time point is computed.
Similarly, if it is a list of floats (or an array),
samples at the specified grid-points are returned.
This is not the same as computing i.i.d samples at the respective
locations.
size
Number of samples.
"""
Expand Down
63 changes: 37 additions & 26 deletions src/probnum/diffeq/_odesolver.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
"""ODE solver interface."""

import dataclasses
from abc import ABC, abstractmethod
from collections import abc
from typing import Iterable, Optional, Union
from typing import Iterable, Iterator, Optional, Union

import numpy as np

from probnum import problems
from probnum.diffeq import callbacks
from probnum.diffeq import callbacks as callback_module # see below
from probnum.typing import FloatLike

CallbackType = Union[callbacks.ODESolverCallback, Iterable[callbacks.ODESolverCallback]]
# From above:
# One of the argument to solve() is called 'callback',
# and we do not want to redefine variable namespaces.

CallbackType = Union[
callback_module.ODESolverCallback, Iterable[callback_module.ODESolverCallback]
]
"""Callback interface type."""


Expand Down Expand Up @@ -59,7 +66,7 @@ def solution_generator(
):
"""Generate ODE solver steps."""

callbacks, time_stopper = self._process_event_inputs(callbacks, stop_at)
callbacks, stopper_state = self._process_event_inputs(callbacks, stop_at)

state = self.initialize(ivp)
yield state
Expand All @@ -68,8 +75,8 @@ def solution_generator(

# Use state.ivp in case a callback modifies the IVP
while state.t < state.ivp.tmax:
if time_stopper is not None:
dt = time_stopper.adjust_dt_to_time_stops(state.t, dt)
if stopper_state is not None:
dt, stopper_state = _adjust_time_step(stopper_state, t=state.t, dt=dt)

state, dt = self.perform_full_step(state, dt)

Expand All @@ -90,7 +97,10 @@ def promote_callback_type(cbs):
if callbacks is not None:
callbacks = promote_callback_type(callbacks)
if stop_at_locations is not None:
time_stopper = _TimeStopper(stop_at_locations)
loc_iter = iter(stop_at_locations)
time_stopper = _TimeStopperState(
locations=loc_iter, next_location=next(loc_iter)
)
else:
time_stopper = None
return callbacks, time_stopper
Expand Down Expand Up @@ -123,7 +133,6 @@ def perform_full_step(self, state, initial_dt):
else:
dt = min(suggested_dt, state.ivp.tmax - state.t)

# This line of code is unnecessary?!
self.method_callback(state)
return proposed_state, dt

Expand All @@ -147,7 +156,9 @@ def rvlist_to_odesol(self, times, rvs):
"""Create an ODESolution object."""
raise NotImplementedError

def postprocess(self, odesol):
# We disable the pylint warning, because subclasses _do_ use 'self'
# but pylint does not seem to realize this.
def postprocess(self, odesol): # pylint: disable="no-self-use"
"""Process the ODESolution object before returning."""
return odesol

Expand All @@ -158,25 +169,25 @@ def method_callback(self, state):
current guess is accepted, but before storing it. No return. For
example: tune hyperparameters (sigma).
"""
pass


class _TimeStopper:
"""Make the ODE solver stop at specified time-points."""

def __init__(self, locations: Iterable):
self._locations = iter(locations)
self._next_location = next(self._locations)
@dataclasses.dataclass
class _TimeStopperState:
locations: Iterator
next_location: FloatLike

def adjust_dt_to_time_stops(self, t, dt):
"""Check whether the next time-point is supposed to be stopped at."""

if t >= self._next_location:
try:
self._next_location = next(self._locations)
except StopIteration:
self._next_location = np.inf
def _adjust_time_step(stopper_state, t, dt):
if t >= stopper_state.next_location:
try:
next_location = next(stopper_state.locations)
except StopIteration:
next_location = np.inf
else:
next_location = stopper_state.next_location

if t + dt > self._next_location:
dt = self._next_location - t
return dt
if t + dt > next_location:
dt = next_location - t
return dt, _TimeStopperState(
locations=stopper_state.locations, next_location=next_location
)
87 changes: 0 additions & 87 deletions tests/test_diffeq/test_odesolver.py

This file was deleted.

0 comments on commit 5023c41

Please sign in to comment.