Skip to content

Commit

Permalink
Merge pull request #179 from pnkraemer/refactor_diffeq
Browse files Browse the repository at this point in the history
Refactor diffeq
  • Loading branch information
pnkraemer authored Aug 26, 2020
2 parents 146037c + fb526a9 commit 0b85017
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 19 deletions.
16 changes: 8 additions & 8 deletions src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ class GaussianIVPFilter(odesolver.ODESolver):
further considerations to, e.g., BVPs.
"""

def __init__(self, ivp, gaussfilt, steprl):
def __init__(self, ivp, gaussfilt):
"""
steprule : stepsize rule
gaussfilt : gaussianfilter.GaussianFilter object,
e.g. the return value of ivp_to_ukf(), ivp_to_ekf1().
Expand All @@ -37,9 +36,8 @@ def __init__(self, ivp, gaussfilt, steprl):
raise ValueError("Please initialise a Gaussian filter with an ODEPrior")
self.ivp = ivp
self.gfilt = gaussfilt
odesolver.ODESolver.__init__(self, steprl)

def solve(self, firststep, **kwargs):
def solve(self, firststep, steprule, **kwargs):
"""
Solve IVP and calibrates uncertainty according
to Proposition 4 in Tronarp et al.
Expand All @@ -48,6 +46,8 @@ def solve(self, firststep, **kwargs):
----------
firststep : float
First step for adaptive step size rule.
steprule : StepRule
Step-size selection rule, e.g. constant steps or adaptive steps.
"""
current_rv = self.gfilt.initialrandomvariable
t = self.ivp.t0
Expand All @@ -70,15 +70,15 @@ def solve(self, firststep, **kwargs):
filt_rv.mean(), crosscov, meas_cov, meas_mean
)

if self.steprule.is_accepted(step, errorest):
if steprule.is_accepted(step, errorest):
times.append(t_new)
rvs.append(filt_rv)
num_steps += 1
ssqest = ssqest + (ssq - ssqest) / num_steps
current_rv = filt_rv
t = t_new

step = self._suggest_step(step, errorest)
step = self._suggest_step(step, errorest, steprule)
step = min(step, self.ivp.tmax - t)

rvs = [
Expand Down Expand Up @@ -153,7 +153,7 @@ def _rel_and_abs_error(self, abserrors, currmn):
abs_error = abserrors @ weights / np.linalg.norm(weights)
return np.maximum(rel_error, abs_error)

def _suggest_step(self, step, errorest):
def _suggest_step(self, step, errorest, steprule):
"""
Suggests step according to steprule and warns if
step is extremely small.
Expand All @@ -163,7 +163,7 @@ def _suggest_step(self, step, errorest):
RuntimeWarning
If suggested step is smaller than :math:`10^{-15}`.
"""
step = self.steprule.suggest(step, errorest)
step = steprule.suggest(step, errorest)
if step < 1e-15:
warnmsg = "Stepsize is num. zero (%.1e)" % step
warnings.warn(message=warnmsg, category=RuntimeWarning)
Expand Down
9 changes: 5 additions & 4 deletions src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,16 +225,17 @@ def probsolve_ivp(
[0.97947631]
[0.98614541]]
"""
solver, firststep = _create_solver_object(
gfilt, firststep, stprl = _create_solver_inputs(
ivp, method, which_prior, tol, step, firststep, precond_step, **kwargs
)
solution = solver.solve(firststep=firststep, **kwargs)
solver = GaussianIVPFilter(ivp, gfilt)
solution = solver.solve(firststep=firststep, steprule=stprl, **kwargs)
if method in ["eks0", "eks1", "uks"]:
solution = solver.odesmooth(solution, **kwargs)
return solution


def _create_solver_object(
def _create_solver_inputs(
ivp, method, which_prior, tol, step, firststep, precond_step, **kwargs
):
"""Create the solver object that is used."""
Expand All @@ -251,7 +252,7 @@ def _create_solver_object(
stprl = steprule.ConstantSteps(step)
firststep = step
gfilt = _string2filter(ivp, _prior, method, **kwargs)
return GaussianIVPFilter(ivp, gfilt, stprl), firststep
return gfilt, firststep, stprl


def _check_step_tol(step, tol):
Expand Down
7 changes: 0 additions & 7 deletions src/probnum/diffeq/odesolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@ class ODESolver(ABC):
Interface for ODESolver.
"""

def __init__(self, steprule):
"""
An ODESolver is an object governed by a stepsize rule.
That is: constant steps or adaptive steps.
"""
self.steprule = steprule

@abstractmethod
def solve(self, ivp, minstep, maxstep, **kwargs):
"""
Expand Down

0 comments on commit 0b85017

Please sign in to comment.