From b6960faccab6c77c38488524fed342e8eb9ca568 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Wed, 26 Aug 2020 12:08:11 +0200 Subject: [PATCH 1/4] break tests first --- src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py | 13 ++++++------- src/probnum/diffeq/odesolver.py | 7 ------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py b/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py index acde3c754..0f9ae7f9f 100644 --- a/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py +++ b/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py @@ -19,7 +19,7 @@ class GaussianIVPFilter(odesolver.ODESolver): further considerations to, e.g., BVPs. """ - def __init__(self, ivp, gaussfilt, steprl): + def __init__(self, ivp, gaussfilt): """ steprule : stepsize rule gaussfilt : gaussianfilter.GaussianFilter object, @@ -37,9 +37,8 @@ def __init__(self, ivp, gaussfilt, steprl): raise ValueError("Please initialise a Gaussian filter with an ODEPrior") self.ivp = ivp self.gfilt = gaussfilt - odesolver.ODESolver.__init__(self, steprl) - def solve(self, firststep, **kwargs): + def solve(self, firststep, steprule, **kwargs): """ Solve IVP and calibrates uncertainty according to Proposition 4 in Tronarp et al. @@ -70,7 +69,7 @@ def solve(self, firststep, **kwargs): filt_rv.mean(), crosscov, meas_cov, meas_mean ) - if self.steprule.is_accepted(step, errorest): + if steprule.is_accepted(step, errorest): times.append(t_new) rvs.append(filt_rv) num_steps += 1 @@ -78,7 +77,7 @@ def solve(self, firststep, **kwargs): current_rv = filt_rv t = t_new - step = self._suggest_step(step, errorest) + step = self._suggest_step(step, errorest, steprule) step = min(step, self.ivp.tmax - t) rvs = [ @@ -153,7 +152,7 @@ def _rel_and_abs_error(self, abserrors, currmn): abs_error = abserrors @ weights / np.linalg.norm(weights) return np.maximum(rel_error, abs_error) - def _suggest_step(self, step, errorest): + def _suggest_step(self, step, errorest, steprule): """ Suggests step according to steprule and warns if step is extremely small. @@ -163,7 +162,7 @@ def _suggest_step(self, step, errorest): RuntimeWarning If suggested step is smaller than :math:`10^{-15}`. """ - step = self.steprule.suggest(step, errorest) + step = steprule.suggest(step, errorest) if step < 1e-15: warnmsg = "Stepsize is num. zero (%.1e)" % step warnings.warn(message=warnmsg, category=RuntimeWarning) diff --git a/src/probnum/diffeq/odesolver.py b/src/probnum/diffeq/odesolver.py index cd1255d3d..63bc9029a 100644 --- a/src/probnum/diffeq/odesolver.py +++ b/src/probnum/diffeq/odesolver.py @@ -10,13 +10,6 @@ class ODESolver(ABC): Interface for ODESolver. """ - def __init__(self, steprule): - """ - An ODESolver is an object governed by a stepsize rule. - That is: constant steps or adaptive steps. - """ - self.steprule = steprule - @abstractmethod def solve(self, ivp, minstep, maxstep, **kwargs): """ From 14042231077f88ea6c926e55b4b21251e01fe1aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Wed, 26 Aug 2020 12:09:27 +0200 Subject: [PATCH 2/4] repair things second --- src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py b/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py index 3969ee54e..b5cafcfdc 100644 --- a/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py +++ b/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py @@ -225,10 +225,10 @@ def probsolve_ivp( [0.97947631] [0.98614541]] """ - solver, firststep = _create_solver_object( + solver, firststep, stprl = _create_solver_object( ivp, method, which_prior, tol, step, firststep, precond_step, **kwargs ) - solution = solver.solve(firststep=firststep, **kwargs) + solution = solver.solve(firststep=firststep, steprule=stprl, **kwargs) if method in ["eks0", "eks1", "uks"]: solution = solver.odesmooth(solution, **kwargs) return solution @@ -251,7 +251,7 @@ def _create_solver_object( stprl = steprule.ConstantSteps(step) firststep = step gfilt = _string2filter(ivp, _prior, method, **kwargs) - return GaussianIVPFilter(ivp, gfilt, stprl), firststep + return GaussianIVPFilter(ivp, gfilt), firststep, stprl def _check_step_tol(step, tol): From aa7deb3ac07226006bb3bdfb280a29b39636bb7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Wed, 26 Aug 2020 12:14:00 +0200 Subject: [PATCH 3/4] refactored create_solver_function slightly --- src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py b/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py index b5cafcfdc..929b16c32 100644 --- a/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py +++ b/src/probnum/diffeq/odefiltsmooth/odefiltsmooth.py @@ -225,16 +225,17 @@ def probsolve_ivp( [0.97947631] [0.98614541]] """ - solver, firststep, stprl = _create_solver_object( + gfilt, firststep, stprl = _create_solver_inputs( ivp, method, which_prior, tol, step, firststep, precond_step, **kwargs ) + solver = GaussianIVPFilter(ivp, gfilt) solution = solver.solve(firststep=firststep, steprule=stprl, **kwargs) if method in ["eks0", "eks1", "uks"]: solution = solver.odesmooth(solution, **kwargs) return solution -def _create_solver_object( +def _create_solver_inputs( ivp, method, which_prior, tol, step, firststep, precond_step, **kwargs ): """Create the solver object that is used.""" @@ -251,7 +252,7 @@ def _create_solver_object( stprl = steprule.ConstantSteps(step) firststep = step gfilt = _string2filter(ivp, _prior, method, **kwargs) - return GaussianIVPFilter(ivp, gfilt), firststep, stprl + return gfilt, firststep, stprl def _check_step_tol(step, tol): From 11606dc254023d8e63762b1095e4d22da9a6fd65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Wed, 26 Aug 2020 14:02:01 +0200 Subject: [PATCH 4/4] adapted docstrings --- src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py b/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py index 0f9ae7f9f..2ad69699b 100644 --- a/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py +++ b/src/probnum/diffeq/odefiltsmooth/ivpfiltsmooth.py @@ -21,7 +21,6 @@ class GaussianIVPFilter(odesolver.ODESolver): def __init__(self, ivp, gaussfilt): """ - steprule : stepsize rule gaussfilt : gaussianfilter.GaussianFilter object, e.g. the return value of ivp_to_ukf(), ivp_to_ekf1(). @@ -47,6 +46,8 @@ def solve(self, firststep, steprule, **kwargs): ---------- firststep : float First step for adaptive step size rule. + steprule : StepRule + Step-size selection rule, e.g. constant steps or adaptive steps. """ current_rv = self.gfilt.initialrandomvariable t = self.ivp.t0