diff --git a/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py b/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py index acde3c754..2ad69699b 100644 --- a/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py +++ b/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py @@ -19,9 +19,8 @@ class GaussianIVPFilter(odesolver.ODESolver): further considerations to, e.g., BVPs. """ - def __init__(self, ivp, gaussfilt, steprl): + def __init__(self, ivp, gaussfilt): """ - steprule : stepsize rule gaussfilt : gaussianfilter.GaussianFilter object, e.g. the return value of ivp_to_ukf(), ivp_to_ekf1(). @@ -37,9 +36,8 @@ def __init__(self, ivp, gaussfilt, steprl): raise ValueError("Please initialise a Gaussian filter with an ODEPrior") self.ivp = ivp self.gfilt = gaussfilt - odesolver.ODESolver.__init__(self, steprl) - def solve(self, firststep, **kwargs): + def solve(self, firststep, steprule, **kwargs): """ Solve IVP and calibrates uncertainty according to Proposition 4 in Tronarp et al. @@ -48,6 +46,8 @@ def solve(self, firststep, **kwargs): ---------- firststep : float First step for adaptive step size rule. + steprule : StepRule + Step-size selection rule, e.g. constant steps or adaptive steps. """ current_rv = self.gfilt.initialrandomvariable t = self.ivp.t0 @@ -70,7 +70,7 @@ def solve(self, firststep, **kwargs): filt_rv.mean(), crosscov, meas_cov, meas_mean ) - if self.steprule.is_accepted(step, errorest): + if steprule.is_accepted(step, errorest): times.append(t_new) rvs.append(filt_rv) num_steps += 1 @@ -78,7 +78,7 @@ def solve(self, firststep, **kwargs): current_rv = filt_rv t = t_new - step = self._suggest_step(step, errorest) + step = self._suggest_step(step, errorest, steprule) step = min(step, self.ivp.tmax - t) rvs = [ @@ -153,7 +153,7 @@ def _rel_and_abs_error(self, abserrors, currmn): abs_error = abserrors @ weights / np.linalg.norm(weights) return np.maximum(rel_error, abs_error) - def _suggest_step(self, step, errorest): + def _suggest_step(self, step, errorest, steprule): """ Suggests step according to steprule and warns if step is extremely small. @@ -163,7 +163,7 @@ def _suggest_step(self, step, errorest): RuntimeWarning If suggested step is smaller than :math:`10^{-15}`. """ - step = self.steprule.suggest(step, errorest) + step = steprule.suggest(step, errorest) if step < 1e-15: warnmsg = "Stepsize is num. zero (%.1e)" % step warnings.warn(message=warnmsg, category=RuntimeWarning) diff --git a/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py b/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py index 3969ee54e..929b16c32 100644 --- a/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py +++ b/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py @@ -225,16 +225,17 @@ def probsolve_ivp( [0.97947631] [0.98614541]] """ - solver, firststep = _create_solver_object( + gfilt, firststep, stprl = _create_solver_inputs( ivp, method, which_prior, tol, step, firststep, precond_step, **kwargs ) - solution = solver.solve(firststep=firststep, **kwargs) + solver = GaussianIVPFilter(ivp, gfilt) + solution = solver.solve(firststep=firststep, steprule=stprl, **kwargs) if method in ["eks0", "eks1", "uks"]: solution = solver.odesmooth(solution, **kwargs) return solution -def _create_solver_object( +def _create_solver_inputs( ivp, method, which_prior, tol, step, firststep, precond_step, **kwargs ): """Create the solver object that is used.""" @@ -251,7 +252,7 @@ def _create_solver_object( stprl = steprule.ConstantSteps(step) firststep = step gfilt = _string2filter(ivp, _prior, method, **kwargs) - return GaussianIVPFilter(ivp, gfilt, stprl), firststep + return gfilt, firststep, stprl def _check_step_tol(step, tol): diff --git a/src/probnum/diffeq/odesolver.py b/src/probnum/diffeq/odesolver.py index cd1255d3d..63bc9029a 100644 --- a/src/probnum/diffeq/odesolver.py +++ b/src/probnum/diffeq/odesolver.py @@ -10,13 +10,6 @@ class ODESolver(ABC): Interface for ODESolver. """ - def __init__(self, steprule): - """ - An ODESolver is an object governed by a stepsize rule. - That is: constant steps or adaptive steps. - """ - self.steprule = steprule - @abstractmethod def solve(self, ivp, minstep, maxstep, **kwargs): """