diff --git a/docs/benchmarks/hires/run_hires.py b/docs/benchmarks/hires/run_hires.py index 32b53631..bef46b78 100644 --- a/docs/benchmarks/hires/run_hires.py +++ b/docs/benchmarks/hires/run_hires.py @@ -88,8 +88,8 @@ def param_to_solution(tol): tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense") ts1 = ivpsolvers.correction_ts1(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts1, ssm=ssm) - solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm) control = ivpsolve.control_proportional_integral(clip=True) adaptive_solver = ivpsolve.adaptive( solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm diff --git a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py index b381ba4b..029d94e1 100644 --- a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py +++ b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py @@ -81,8 +81,9 @@ def param_to_solution(tol): tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=implementation) - strategy = ivpsolvers.strategy_filter(ibm, correction(ssm=ssm), ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + corr = correction(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm) control = ivpsolve.control_proportional_integral() adaptive_solver = ivpsolve.adaptive( solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm diff --git a/docs/benchmarks/pleiades/run_pleiades.py b/docs/benchmarks/pleiades/run_pleiades.py index 8d61fd8b..b1b49e25 100644 --- a/docs/benchmarks/pleiades/run_pleiades.py +++ b/docs/benchmarks/pleiades/run_pleiades.py @@ -101,8 +101,10 @@ def param_to_solution(tol): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic") ts0_or_ts1 = correction_fun(ssm=ssm, ode_order=2) - strategy = ivpsolvers.strategy_filter(ibm, ts0_or_ts1, ssm=ssm) - solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_dynamic( + strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm + ) control = ivpsolve.control_proportional_integral() adaptive_solver = ivpsolve.adaptive( solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm diff --git a/docs/benchmarks/vanderpol/run_vanderpol.py b/docs/benchmarks/vanderpol/run_vanderpol.py index fa5010b2..5390cae0 100644 --- a/docs/benchmarks/vanderpol/run_vanderpol.py +++ b/docs/benchmarks/vanderpol/run_vanderpol.py @@ -80,9 +80,11 @@ def param_to_solution(tol): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense") ts0_or_ts1 = ivpsolvers.correction_ts1(ode_order=2, ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0_or_ts1, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) - solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm) + solver = ivpsolvers.solver_dynamic( + strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm + ) control = ivpsolve.control_proportional_integral(clip=True) adaptive_solver = ivpsolve.adaptive( solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm diff --git a/docs/examples_misc/use_equinox_bounded_while_loop.py b/docs/examples_misc/use_equinox_bounded_while_loop.py index 2beedf9c..c9f543b8 100644 --- a/docs/examples_misc/use_equinox_bounded_while_loop.py +++ b/docs/examples_misc/use_equinox_bounded_while_loop.py @@ -64,8 +64,8 @@ def vf(y, *, t): # noqa: ARG001 ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ode_order=1, ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm) init = solver.initial_condition() diff --git a/docs/examples_parameter_estimation/neural_ode.py b/docs/examples_parameter_estimation/neural_ode.py index fa7d425c..9de1cc24 100644 --- a/docs/examples_parameter_estimation/neural_ode.py +++ b/docs/examples_parameter_estimation/neural_ode.py @@ -70,8 +70,8 @@ def loss_fn(parameters): tcoeffs = (*initial_values, vf(*initial_values, t=t0, p=parameters)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver_ts0 = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver_ts0.initial_condition() sol = ivpsolve.solve_fixed_grid( @@ -128,8 +128,8 @@ def vf(y, *, t, p): tcoeffs = (u0, vf(u0, t=t0, p=f_args)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy) +strategy = ivpsolvers.strategy_smoother(ssm=ssm) +solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver_ts0.initial_condition() # + @@ -168,8 +168,8 @@ def vf(y, *, t, p): tcoeffs = (u0, vf(u0, t=t0, p=f_args)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy) +strategy = ivpsolvers.strategy_smoother(ssm=ssm) +solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver_ts0.initial_condition() sol = ivpsolve.solve_fixed_grid( @@ -182,8 +182,8 @@ def vf(y, *, t, p): tcoeffs = (u0, vf(u0, t=t0, p=f_args)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy) +strategy = ivpsolvers.strategy_smoother(ssm=ssm) +solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver_ts0.initial_condition() sol = ivpsolve.solve_fixed_grid( diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_1.py b/docs/examples_parameter_estimation/physics_enhanced_regression_1.py index f835ce46..90668dfc 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_1.py +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_1.py @@ -72,8 +72,8 @@ def solve(p): tcoeffs, output_scale=output_scale, ssm_fact="isotropic" ) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid( diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py index e585d249..f78541ae 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py @@ -191,8 +191,8 @@ def solve_fixed(theta, *, ts): tcoeffs, output_scale=output_scale, ssm_fact="isotropic" ) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() sol = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm) return sol[-1] @@ -208,8 +208,8 @@ def solve_adaptive(theta, *, save_at): tcoeffs, output_scale=output_scale, ssm_fact="isotropic" ) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm) init = solver.initial_condition() diff --git a/docs/examples_quickstart/easy_example.py b/docs/examples_quickstart/easy_example.py index 658fb797..65dfeb0a 100644 --- a/docs/examples_quickstart/easy_example.py +++ b/docs/examples_quickstart/easy_example.py @@ -55,11 +55,11 @@ def vf(y, *, t): # noqa: ARG001 # Set up a state-space model tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense") +ts0 = ivpsolvers.correction_ts1(ode_order=1, ssm=ssm) +strategy = ivpsolvers.strategy_smoother(ssm=ssm) # Build a solver -ts0 = ivpsolvers.correction_ts1(ode_order=1, ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver = ivpsolvers.solver_mle(strategy, ssm=ssm) +solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm) # - diff --git a/docs/examples_solver_config/conditioning-on-zero-residual.py b/docs/examples_solver_config/conditioning-on-zero-residual.py index ecc9b67d..86299d29 100644 --- a/docs/examples_solver_config/conditioning-on-zero-residual.py +++ b/docs/examples_solver_config/conditioning-on-zero-residual.py @@ -78,8 +78,9 @@ def vector_field(y, t): # noqa: ARG001 # Compute the posterior ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="dense") -slr1 = ivpsolvers.correction_ts1(ssm=ssm) -solver = ivpsolvers.solver(ivpsolvers.strategy_fixedpoint(ibm, slr1, ssm=ssm)) +ts1 = ivpsolvers.correction_ts1(ssm=ssm) +strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) +solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts1, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-1, rtol=1e-2, ssm=ssm) dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,)) diff --git a/docs/examples_solver_config/dynamic_output_scales.py b/docs/examples_solver_config/dynamic_output_scales.py index 2cce9d52..a709bab6 100644 --- a/docs/examples_solver_config/dynamic_output_scales.py +++ b/docs/examples_solver_config/dynamic_output_scales.py @@ -72,9 +72,9 @@ def vf(*ys, t): # noqa: ARG001 tcoeffs = (u0, vf(u0, t=t0)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="dense") ts1 = ivpsolvers.correction_ts1(ssm=ssm) -strategy = ivpsolvers.strategy_filter(ibm, ts1, ssm=ssm) -dynamic = ivpsolvers.solver_dynamic(strategy, ssm=ssm) -mle = ivpsolvers.solver_mle(strategy, ssm=ssm) +strategy = ivpsolvers.strategy_filter(ssm=ssm) +dynamic = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm) +mle = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts1, ssm=ssm) # + t0, t1 = 0.0, 3.0 diff --git a/docs/examples_solver_config/posterior_uncertainties.py b/docs/examples_solver_config/posterior_uncertainties.py index 18706fce..5b02742e 100644 --- a/docs/examples_solver_config/posterior_uncertainties.py +++ b/docs/examples_solver_config/posterior_uncertainties.py @@ -61,7 +61,8 @@ def vf(*ys, t): # noqa: ARG001 tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -solver = ivpsolvers.solver_mle(ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm), ssm=ssm) +strategy = ivpsolvers.strategy_filter(ssm=ssm) +solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500) @@ -115,9 +116,8 @@ def vf(*ys, t): # noqa: ARG001 # + ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -solver = ivpsolvers.solver_mle( - ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm), ssm=ssm -) +strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) +solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500) diff --git a/docs/examples_solver_config/second_order_problems.py b/docs/examples_solver_config/second_order_problems.py index 447d76c8..c67c8e89 100644 --- a/docs/examples_solver_config/second_order_problems.py +++ b/docs/examples_solver_config/second_order_problems.py @@ -52,8 +52,8 @@ def vf_1(y, t): # noqa: ARG001 tcoeffs = taylor.odejet_padded_scan(lambda y: vf_1(y, t=t0), (u0,), num=4) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) -solver_1st = ivpsolvers.solver_mle(strategy, ssm=ssm) +strategy = ivpsolvers.strategy_filter(ssm=ssm) +solver_1st = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver_1st = ivpsolve.adaptive(solver_1st, atol=1e-5, rtol=1e-5, ssm=ssm) @@ -86,8 +86,8 @@ def vf_2(y, dy, t): # noqa: ARG001 tcoeffs = taylor.odejet_padded_scan(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ode_order=2, ssm=ssm) -strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) -solver_2nd = ivpsolvers.solver_mle(strategy, ssm=ssm) +strategy = ivpsolvers.strategy_filter(ssm=ssm) +solver_2nd = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver_2nd = ivpsolve.adaptive(solver_2nd, atol=1e-5, rtol=1e-5, ssm=ssm) diff --git a/docs/examples_solver_config/taylor_coefficients.py b/docs/examples_solver_config/taylor_coefficients.py index d38befcd..1a323be7 100644 --- a/docs/examples_solver_config/taylor_coefficients.py +++ b/docs/examples_solver_config/taylor_coefficients.py @@ -64,8 +64,8 @@ def solve(tc): """Solve the ODE.""" prior, ssm = ivpsolvers.prior_ibm(tc, ssm_fact="dense") ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(prior, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=prior, correction=ts0, ssm=ssm) init = solver.initial_condition() ts = jnp.linspace(t0, t1, endpoint=True, num=10) diff --git a/probdiffeq/ivpsolve.py b/probdiffeq/ivpsolve.py index 727d46c8..a401cdca 100644 --- a/probdiffeq/ivpsolve.py +++ b/probdiffeq/ivpsolve.py @@ -205,8 +205,8 @@ def body_fn(state: _RejectionState) -> _RejectionState: dt=self.control.extract(state_control), ) # Normalise the error - u_proposed = self.ssm.stats.qoi(state_proposed.strategy.hidden)[0] - u_step_from = self.ssm.stats.qoi(state_proposed.strategy.hidden)[0] + u_proposed = self.ssm.stats.qoi(state_proposed.hidden)[0] + u_step_from = self.ssm.stats.qoi(state_proposed.hidden)[0] u = np.maximum(np.abs(u_proposed), np.abs(u_step_from)) error_power = _error_scale_and_normalize(error_estimate, u=u) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 064d3884..a198d9ea 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -173,10 +173,14 @@ def _tensor_points(x, /, *, d): class _ExtraImpl: """Extrapolation model interface.""" - prior: _MarkovProcess + name: str ssm: Any - def initial_condition(self): + is_suitable_for_save_at: int + is_suitable_for_save_every_step: int + is_suitable_for_offgrid_marginals: int + + def initial_condition(self, *, prior): """Compute an initial condition from a set of Taylor coefficients.""" raise NotImplementedError @@ -184,7 +188,7 @@ def init(self, sol: stats.MarkovSeq, /): """Initialise a state from a solution.""" raise NotImplementedError - def begin(self, rv, _extra, /, dt): + def begin(self, rv, _extra, /, *, prior_discretized): """Begin the extrapolation.""" raise NotImplementedError @@ -196,27 +200,27 @@ def extract(self, hidden_state, extra, /): """Extract a solution from a state.""" raise NotImplementedError - def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): + def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale, prior): """Interpolate.""" raise NotImplementedError - def interpolate_at_t1(self, rv, extra, /): + def interpolate_at_t1(self, rv, extra, /, *, prior): """Process the state at checkpoint t=t_n.""" raise NotImplementedError @containers.dataclass class _ExtraImplSmoother(_ExtraImpl): - def initial_condition(self): - rv = self.ssm.normal.from_tcoeffs(self.prior.tcoeffs) - cond = self.ssm.conditional.identity(len(self.prior.tcoeffs)) + def initial_condition(self, *, prior): + rv = self.ssm.normal.from_tcoeffs(prior.tcoeffs) + cond = self.ssm.conditional.identity(len(prior.tcoeffs)) return stats.MarkovSeq(init=rv, conditional=cond) def init(self, sol: stats.MarkovSeq, /): return sol.init, sol.conditional - def begin(self, rv, _extra, /, dt): - cond, (p, p_inv) = self.prior.discretize(dt) + def begin(self, rv, _extra, /, *, prior_discretized): + cond, (p, p_inv) = prior_discretized rv_p = self.ssm.normal.preconditioner_apply(rv, p_inv) @@ -243,7 +247,7 @@ def complete(self, _ssv, extra, /, output_scale): def extract(self, hidden_state, extra, /): return stats.MarkovSeq(init=hidden_state, conditional=extra) - def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): + def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale, prior): """Interpolate. A smoother interpolates by_ @@ -258,9 +262,14 @@ def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): Subsequent IVP solver steps continue from the value at 't1'. """ # Extrapolate from t0 to t, and from t to t1. This yields all building blocks. - extrapolated_t = self._extrapolate(*state_t0, dt=dt0, output_scale=output_scale) + prior0 = prior.discretize(dt0) + extrapolated_t = self._extrapolate( + *state_t0, output_scale=output_scale, prior_discretized=prior0 + ) + + prior1 = prior.discretize(dt1) extrapolated_t1 = self._extrapolate( - *extrapolated_t, dt=dt1, output_scale=output_scale + *extrapolated_t, output_scale=output_scale, prior_discretized=prior1 ) # Marginalise from t1 to t to obtain the interpolated solution. @@ -278,11 +287,12 @@ def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): interp_from=solution_at_t, ) - def _extrapolate(self, state, extra, /, *, dt, output_scale): - state, cache = self.begin(state, extra, dt=dt) + def _extrapolate(self, state, extra, /, *, output_scale, prior_discretized): + state, cache = self.begin(state, extra, prior_discretized=prior_discretized) return self.complete(state, cache, output_scale=output_scale) - def interpolate_at_t1(self, rv, extra, /): + def interpolate_at_t1(self, rv, extra, /, *, prior): + del prior return _InterpRes((rv, extra), (rv, extra), (rv, extra)) @@ -291,11 +301,11 @@ class _ExtraImplFilter(_ExtraImpl): def init(self, sol, /): return sol, None - def initial_condition(self): - return self.ssm.normal.from_tcoeffs(self.prior.tcoeffs) + def initial_condition(self, *, prior): + return self.ssm.normal.from_tcoeffs(prior.tcoeffs) - def begin(self, rv, _extra, /, dt): - cond, (p, p_inv) = self.prior.discretize(dt) + def begin(self, rv, _extra, /, prior_discretized): + cond, (p, p_inv) = prior_discretized rv_p = self.ssm.normal.preconditioner_apply(rv, p_inv) @@ -321,13 +331,14 @@ def complete(self, _ssv, extra, /, output_scale): # Gather and return return extrapolated, None - def interpolate(self, state_t0, marginal_t1, dt0, dt1, output_scale): + def interpolate(self, state_t0, marginal_t1, dt0, dt1, output_scale, *, prior): # todo: by ditching marginal_t1 and dt1, this function _extrapolates # (no *inter*polation happening) del dt1 hidden, extra = state_t0 - hidden, extra = self.begin(hidden, extra, dt=dt0) + prior0 = prior.discretize(dt0) + hidden, extra = self.begin(hidden, extra, prior_discretized=prior0) hidden, extra = self.complete(hidden, extra, output_scale=output_scale) # Consistent state-types in interpolation result. @@ -335,22 +346,23 @@ def interpolate(self, state_t0, marginal_t1, dt0, dt1, output_scale): step_from = (marginal_t1, None) return _InterpRes(step_from=step_from, interpolated=interp, interp_from=interp) - def interpolate_at_t1(self, rv, extra, /): + def interpolate_at_t1(self, rv, extra, /, *, prior): + del prior return _InterpRes((rv, extra), (rv, extra), (rv, extra)) @containers.dataclass class _ExtraImplFixedPoint(_ExtraImpl): - def initial_condition(self): - rv = self.ssm.normal.from_tcoeffs(self.prior.tcoeffs) - cond = self.ssm.conditional.identity(len(self.prior.tcoeffs)) + def initial_condition(self, prior): + rv = self.ssm.normal.from_tcoeffs(prior.tcoeffs) + cond = self.ssm.conditional.identity(len(prior.tcoeffs)) return stats.MarkovSeq(init=rv, conditional=cond) def init(self, sol: stats.MarkovSeq, /): return sol.init, sol.conditional - def begin(self, rv, extra, /, dt): - cond, (p, p_inv) = self.prior.discretize(dt) + def begin(self, rv, extra, /, prior_discretized): + cond, (p, p_inv) = prior_discretized rv_p = self.ssm.normal.preconditioner_apply(rv, p_inv) @@ -380,7 +392,7 @@ def complete(self, _rv, extra, /, output_scale): # Gather and return return extrapolated, cond - def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): + def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale, prior): """Interpolate. A fixed-point smoother interpolates by @@ -419,11 +431,16 @@ def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): then don't understand why tests fail.) """ # Extrapolate from t0 to t, and from t to t1. This yields all building blocks. - extrapolated_t = self._extrapolate(*state_t0, dt=dt0, output_scale=output_scale) - conditional_id = self.ssm.conditional.identity(self.prior.num_derivatives + 1) + prior0 = prior.discretize(dt0) + extrapolated_t = self._extrapolate( + *state_t0, output_scale=output_scale, prior_discretized=prior0 + ) + conditional_id = self.ssm.conditional.identity(prior.num_derivatives + 1) previous_new = (extrapolated_t[0], conditional_id) + + prior1 = prior.discretize(dt1) extrapolated_t1 = self._extrapolate( - *previous_new, dt=dt1, output_scale=output_scale + *previous_new, output_scale=output_scale, prior_discretized=prior1 ) # Marginalise from t1 to t to obtain the interpolated solution. @@ -437,12 +454,12 @@ def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): interp_from=previous_new, ) - def _extrapolate(self, state, extra, /, *, dt, output_scale): - x, cache = self.begin(state, extra, dt=dt) + def _extrapolate(self, state, extra, /, *, output_scale, prior_discretized): + x, cache = self.begin(state, extra, prior_discretized=prior_discretized) return self.complete(x, cache, output_scale=output_scale) - def interpolate_at_t1(self, rv, extra, /): - cond_identity = self.ssm.conditional.identity(self.prior.num_derivatives + 1) + def interpolate_at_t1(self, rv, extra, /, *, prior): + cond_identity = self.ssm.conditional.identity(prior.num_derivatives + 1) return _InterpRes((rv, cond_identity), (rv, extra), (rv, cond_identity)) @@ -558,174 +575,36 @@ def _estimate_error(observed, /, *, ssm): return output_scale * error_estimate_unscaled -class _StrategyState(containers.NamedTuple): - t: Any - hidden: Any - aux_extra: Any - aux_corr: Any - - -@containers.dataclass -class _Strategy: - """Estimation strategy.""" - - name: str - extrapolation: _ExtraImpl - correction: _Correction - ssm: Any - - is_suitable_for_save_at: int - is_suitable_for_offgrid_marginals: int - is_suitable_for_save_every_step: int - prior: _MarkovProcess - - @property - def num_derivatives(self): - return self.prior.num_derivatives - - def init(self, t, posterior, /) -> _StrategyState: - """Initialise a state from a posterior.""" - rv, extra = self.extrapolation.init(posterior) - rv, corr = self.correction.init(rv) - return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) - - def initial_condition(self): - """Construct an initial condition from a set of Taylor coefficients.""" - return self.extrapolation.initial_condition() - - def begin(self, state: _StrategyState, /, *, dt, vector_field): - """Predict the error of an upcoming step.""" - hidden, extra = self.extrapolation.begin(state.hidden, state.aux_extra, dt=dt) - t = state.t + dt - error, observed, corr = self.correction.estimate_error( - hidden, vector_field=vector_field, t=t - ) - state = _StrategyState(t=t, hidden=hidden, aux_extra=extra, aux_corr=corr) - return error, observed, state - - def complete(self, state, /, *, output_scale): - """Complete the step after the error has been predicted.""" - hidden, extra = self.extrapolation.complete( - state.hidden, state.aux_extra, output_scale=output_scale - ) - hidden, corr = self.correction.complete(hidden, state.aux_corr) - return _StrategyState(t=state.t, hidden=hidden, aux_extra=extra, aux_corr=corr) - - def extract(self, state: _StrategyState, /): - """Extract the solution from a state.""" - hidden = self.correction.extract(state.hidden) - sol = self.extrapolation.extract(hidden, state.aux_extra) - return state.t, sol - - def case_interpolate_at_t1(self, state_t1: _StrategyState) -> _InterpRes: - """Process the solution in case t=t_n.""" - _tmp = self.extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra) - step_from, solution, interp_from = ( - _tmp.step_from, - _tmp.interpolated, - _tmp.interp_from, - ) - - def _state(x): - t = state_t1.t - corr_like = tree_util.tree_map(np.empty_like, state_t1.aux_corr) - return _StrategyState(t=t, hidden=x[0], aux_extra=x[1], aux_corr=corr_like) - - step_from = _state(step_from) - solution = _state(solution) - interp_from = _state(interp_from) - return _InterpRes(step_from, solution, interp_from) - - def case_interpolate( - self, t, *, s0: _StrategyState, s1: _StrategyState, output_scale - ) -> _InterpRes: - """Process the solution in case t>t_n.""" - # Interpolate - interp = self.extrapolation.interpolate( - state_t0=(s0.hidden, s0.aux_extra), - marginal_t1=s1.hidden, - dt0=t - s0.t, - dt1=s1.t - t, - output_scale=output_scale, - ) - - # Turn outputs into valid states - - def _state(t_, x): - corr_like = tree_util.tree_map(np.empty_like, s0.aux_corr) - return _StrategyState(t=t_, hidden=x[0], aux_extra=x[1], aux_corr=corr_like) - - step_from = _state(s1.t, interp.step_from) - interpolated = _state(t, interp.interpolated) - interp_from = _state(t, interp.interp_from) - return _InterpRes( - step_from=step_from, interpolated=interpolated, interp_from=interp_from - ) - - def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale): - """Compute offgrid_marginals.""" - if not self.is_suitable_for_offgrid_marginals: - raise NotImplementedError - - dt0 = t - t0 - dt1 = t1 - t - state_t0 = self.init(t0, posterior_t0) - - interp = self.extrapolation.interpolate( - state_t0=(state_t0.hidden, state_t0.aux_extra), - marginal_t1=marginals_t1, - dt0=dt0, - dt1=dt1, - output_scale=output_scale, - ) - - (marginals, _aux) = interp.interpolated - u = self.ssm.stats.qoi(marginals) - return u, marginals - - -def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: +def strategy_smoother(*, ssm): """Construct a smoother.""" - extrapolation = _ExtraImplSmoother(prior, ssm=ssm) - return _Strategy( - extrapolation=extrapolation, - correction=correction, - prior=prior, + return _ExtraImplSmoother( + name="Smoother", ssm=ssm, is_suitable_for_save_at=False, is_suitable_for_save_every_step=True, is_suitable_for_offgrid_marginals=True, - name="Smoother", ) -def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy: +def strategy_fixedpoint(*, ssm): """Construct a fixedpoint-smoother.""" - extrapolation = _ExtraImplFixedPoint(prior, ssm=ssm) - return _Strategy( - extrapolation=extrapolation, - correction=correction, + return _ExtraImplFixedPoint( + name="Fixed-point smoother", ssm=ssm, - prior=prior, is_suitable_for_save_at=True, is_suitable_for_save_every_step=False, is_suitable_for_offgrid_marginals=False, - name="Fixed-point smoother", ) -def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: +def strategy_filter(*, ssm): """Construct a filter.""" - extrapolation = _ExtraImplFilter(prior, ssm=ssm) - return _Strategy( + return _ExtraImplFilter( name="Filter", - prior=prior, - extrapolation=extrapolation, - correction=correction, + ssm=ssm, is_suitable_for_save_at=True, - is_suitable_for_offgrid_marginals=True, is_suitable_for_save_every_step=True, - ssm=ssm, + is_suitable_for_offgrid_marginals=True, ) @@ -738,84 +617,148 @@ class _Calibration: extract: Callable -class _SolverState(containers.NamedTuple): +class _State(containers.NamedTuple): """Solver state.""" - strategy: Any + t: Any + hidden: Any + aux_extra: Any output_scale: Any - @property - def t(self): - return self.strategy.t - @containers.dataclass class _ProbabilisticSolver: name: str - calibration: _Calibration - step_implementation: Callable - strategy: _Strategy requires_rescaling: bool - @property - def offgrid_marginals(self): - return self.strategy.offgrid_marginals + step_implementation: Callable + + prior: _MarkovProcess + ssm: Any + extrapolation: _ExtraImpl + calibration: _Calibration + correction: _Correction + + def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale): + """Compute offgrid_marginals.""" + if not self.is_suitable_for_offgrid_marginals: + raise NotImplementedError + + dt0 = t - t0 + dt1 = t1 - t + + rv, extra = self.extrapolation.init(posterior_t0) + rv, corr = self.correction.init(rv) + + # TODO: Replace dt0, dt1, and prior with prior_dt0, and prior_dt1 + interp = self.extrapolation.interpolate( + state_t0=(rv, extra), + marginal_t1=marginals_t1, + dt0=dt0, + dt1=dt1, + output_scale=output_scale, + prior=self.prior, + ) + + (marginals, _aux) = interp.interpolated + u = self.ssm.stats.qoi(marginals) + return u, marginals @property def error_contraction_rate(self): - return self.strategy.num_derivatives + 1 + return self.prior.num_derivatives + 1 + + @property + def is_suitable_for_offgrid_marginals(self): + return self.extrapolation.is_suitable_for_offgrid_marginals @property def is_suitable_for_save_at(self): - return self.strategy.is_suitable_for_save_at + return self.extrapolation.is_suitable_for_save_at @property def is_suitable_for_save_every_step(self): - return self.strategy.is_suitable_for_save_every_step + return self.extrapolation.is_suitable_for_save_every_step - def init(self, t, initial_condition) -> _SolverState: + def init(self, t, initial_condition) -> _State: posterior, output_scale = initial_condition - state_strategy = self.strategy.init(t, posterior) + + rv, extra = self.extrapolation.init(posterior) + rv, corr = self.correction.init(rv) + calib_state = self.calibration.init(output_scale) - return _SolverState(strategy=state_strategy, output_scale=calib_state) + return _State(t=t, hidden=rv, aux_extra=extra, output_scale=calib_state) - def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState: + def step(self, state: _State, *, vector_field, dt) -> _State: return self.step_implementation( state, vector_field=vector_field, dt=dt, calibration=self.calibration ) - def extract(self, state: _SolverState, /): - t, posterior = self.strategy.extract(state.strategy) + def extract(self, state: _State, /): + hidden = self.correction.extract(state.hidden) + posterior = self.extrapolation.extract(hidden, state.aux_extra) + t = state.t + _output_scale_prior, output_scale = self.calibration.extract(state.output_scale) return t, (posterior, output_scale) - def interpolate( - self, t, *, interp_from: _SolverState, interp_to: _SolverState - ) -> _InterpRes: + def interpolate(self, t, *, interp_from: _State, interp_to: _State) -> _InterpRes: output_scale, _ = self.calibration.extract(interp_to.output_scale) - interp = self.strategy.case_interpolate( - t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale + return self._case_interpolate( + t, s0=interp_from, s1=interp_to, output_scale=output_scale + ) + + def _case_interpolate(self, t, *, s0, s1, output_scale) -> _InterpRes: + """Process the solution in case t>t_n.""" + # Interpolate + interp = self.extrapolation.interpolate( + state_t0=(s0.hidden, s0.aux_extra), + marginal_t1=s1.hidden, + dt0=t - s0.t, + dt1=s1.t - t, + output_scale=output_scale, + prior=self.prior, + ) + + # Turn outputs into valid states + + def _state(t_, x, scale): + return _State(t=t_, hidden=x[0], aux_extra=x[1], output_scale=scale) + + step_from = _state(s1.t, interp.step_from, s1.output_scale) + interpolated = _state(t, interp.interpolated, s1.output_scale) + interp_from = _state(t, interp.interp_from, s0.output_scale) + return _InterpRes( + step_from=step_from, interpolated=interpolated, interp_from=interp_from ) - prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale) - sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale) - acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale) - return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: - x = self.strategy.case_interpolate_at_t1(interp_to.strategy) + """Process the solution in case t=t_n.""" + tmp = self.extrapolation.interpolate_at_t1( + interp_to.hidden, interp_to.aux_extra, prior=self.prior + ) + step_from_, solution_, interp_from_ = ( + tmp.step_from, + tmp.interpolated, + tmp.interp_from, + ) - prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale) - sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale) - acc = _SolverState(x.step_from, output_scale=interp_to.output_scale) + def _state(t_, s, scale): + return _State(t=t_, hidden=s[0], aux_extra=s[1], output_scale=scale) + + t = interp_to.t + prev = _state(t, interp_from_, interp_from.output_scale) + sol = _state(t, solution_, interp_to.output_scale) + acc = _state(t, step_from_, interp_to.output_scale) return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) def initial_condition(self): """Construct an initial condition.""" - posterior = self.strategy.initial_condition() - return posterior, self.strategy.prior.output_scale + posterior = self.extrapolation.initial_condition(prior=self.prior) + return posterior, self.prior.output_scale -def solver_mle(strategy, *, ssm): +def solver_mle(extrapolation, /, *, correction, prior, ssm): """Create a solver that calibrates the output scale via maximum-likelihood. Warning: needs to be combined with a call to stats.calibrate() @@ -824,27 +767,33 @@ def solver_mle(strategy, *, ssm): def step_mle(state, /, *, dt, vector_field, calibration): output_scale_prior, _calibrated = calibration.extract(state.output_scale) - error, _, state_strategy = strategy.begin( - state.strategy, dt=dt, vector_field=vector_field + + prior_discretized = prior.discretize(dt) + hidden, extra = extrapolation.begin( + state.hidden, state.aux_extra, prior_discretized=prior_discretized + ) + t = state.t + dt + error, _, corr = correction.estimate_error( + hidden, vector_field=vector_field, t=t ) - state_strategy = strategy.complete( - state_strategy, output_scale=output_scale_prior + hidden, extra = extrapolation.complete( + hidden, extra, output_scale=output_scale_prior ) - observed = state_strategy.aux_corr + hidden, observed = correction.complete(hidden, corr) - # Calibrate output_scale = calibration.update(state.output_scale, observed=observed) - - # Return - state = _SolverState(strategy=state_strategy, output_scale=output_scale) + state = _State(t=t, hidden=hidden, aux_extra=extra, output_scale=output_scale) return dt * error, state return _ProbabilisticSolver( + ssm=ssm, name="Probabilistic solver with MLE calibration", + prior=prior, calibration=_calibration_running_mean(ssm=ssm), step_implementation=step_mle, - strategy=strategy, + extrapolation=extrapolation, + correction=correction, requires_rescaling=True, ) @@ -872,25 +821,34 @@ def extract(state, /): return _Calibration(init=init, update=update, extract=extract) -def solver_dynamic(strategy, *, ssm): +def solver_dynamic(extrapolation, *, correction, prior, ssm): """Create a solver that calibrates the output scale dynamically.""" def step_dynamic(state, /, *, dt, vector_field, calibration): - error, observed, state_strategy = strategy.begin( - state.strategy, dt=dt, vector_field=vector_field + prior_discretized = prior.discretize(dt) + hidden, extra = extrapolation.begin( + state.hidden, state.aux_extra, prior_discretized=prior_discretized + ) + t = state.t + dt + error, observed, corr = correction.estimate_error( + hidden, vector_field=vector_field, t=t ) output_scale = calibration.update(state.output_scale, observed=observed) - prior, _calibrated = calibration.extract(output_scale) - state_strategy = strategy.complete(state_strategy, output_scale=prior) + prior_, _calibrated = calibration.extract(output_scale) + hidden, extra = extrapolation.complete(hidden, extra, output_scale=prior_) + hidden, corr = correction.complete(hidden, corr) # Return solution - state = _SolverState(strategy=state_strategy, output_scale=output_scale) + state = _State(t=t, hidden=hidden, aux_extra=extra, output_scale=output_scale) return dt * error, state return _ProbabilisticSolver( - strategy=strategy, + prior=prior, + ssm=ssm, + extrapolation=extrapolation, + correction=correction, calibration=_calibration_most_recent(ssm=ssm), name="Dynamic probabilistic solver", step_implementation=step_dynamic, @@ -911,24 +869,37 @@ def extract(state, /): return _Calibration(init=init, update=update, extract=extract) -def solver(strategy, /): +def solver(extrapolation, /, *, correction, prior, ssm): """Create a solver that does not calibrate the output scale automatically.""" - def step(state: _SolverState, *, vector_field, dt, calibration): + def step(state: _State, *, vector_field, dt, calibration): del calibration # unused - error, _observed, state_strategy = strategy.begin( - state.strategy, dt=dt, vector_field=vector_field + prior_discretized = prior.discretize(dt) + hidden, extra = extrapolation.begin( + state.hidden, state.aux_extra, prior_discretized=prior_discretized + ) + t = state.t + dt + error, _, corr = correction.estimate_error( + hidden, vector_field=vector_field, t=t ) - state_strategy = strategy.complete( - state_strategy, output_scale=state.output_scale + + hidden, extra = extrapolation.complete( + hidden, extra, output_scale=state.output_scale ) + hidden, corr = correction.complete(hidden, corr) + # Extract and return solution - state = _SolverState(strategy=state_strategy, output_scale=state.output_scale) + state = _State( + t=t, hidden=hidden, aux_extra=extra, output_scale=state.output_scale + ) return dt * error, state return _ProbabilisticSolver( - strategy=strategy, + ssm=ssm, + prior=prior, + extrapolation=extrapolation, + correction=correction, calibration=_calibration_none(), step_implementation=step, name="Probabilistic solver", diff --git a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py index 49863791..cfb4f390 100644 --- a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py @@ -18,8 +18,8 @@ class Taylor(containers.NamedTuple): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) control = ivpsolve.control_integral(clip=True) # Any clipped controller will do. asolver = ivpsolve.adaptive(solver, ssm=ssm, atol=1e-2, rtol=1e-2, control=control) diff --git a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py index edb82a18..4bf35b35 100644 --- a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py @@ -13,10 +13,9 @@ def test_save_at_result_matches_interpolated_adaptive_result(fact): # Generate a solver tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_ivpsolve/test_save_every_step.py b/tests/test_ivpsolve/test_save_every_step.py index 664a2b95..24d4e7e8 100644 --- a/tests/test_ivpsolve/test_save_every_step.py +++ b/tests/test_ivpsolve/test_save_every_step.py @@ -24,8 +24,8 @@ def python_loop_solution(ivp, *, fact, strategy_fun): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = strategy_fun(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) # clip=False because we need to test adaptive-step-interpolation # for smoothers diff --git a/tests/test_ivpsolve/test_solution_object.py b/tests/test_ivpsolve/test_solution_object.py index a7831a5e..b9924349 100644 --- a/tests/test_ivpsolve/test_solution_object.py +++ b/tests/test_ivpsolve/test_solution_object.py @@ -24,8 +24,8 @@ def fixture_approximate_solution(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) asolver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() @@ -88,8 +88,8 @@ def solve(init): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) initcond = solver.initial_condition() diff --git a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py index f3e274b5..98e36064 100644 --- a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py @@ -14,8 +14,8 @@ def test_terminal_values_identical(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) asolver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py b/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py index 32167aa8..f03b727f 100644 --- a/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py +++ b/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py @@ -16,8 +16,8 @@ def test_exponential_approximated_well(fact): ibm, ssm = ivpsolvers.prior_ibm((*u0, vf(*u0, t=t0)), ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_ivpsolvers/test_calibration_mle_vs_none.py b/tests/test_ivpsolvers/test_calibration_mle_vs_none.py index 78bbdfc6..36526039 100644 --- a/tests/test_ivpsolvers/test_calibration_mle_vs_none.py +++ b/tests/test_ivpsolvers/test_calibration_mle_vs_none.py @@ -6,8 +6,8 @@ """ from probdiffeq import ivpsolve, ivpsolvers, stats, taylor -from probdiffeq.backend import functools, ode, testing from probdiffeq.backend import numpy as np +from probdiffeq.backend import ode, testing @testing.case() @@ -21,8 +21,8 @@ def case_solve_fixed_grid(fact): kwargs = {"grid": np.linspace(t0, t1, endpoint=True, num=5), "ssm": ssm} def solver_to_solution(solver_fun, strategy_fun): - strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy) + strategy = strategy_fun(ssm=ssm) + solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid(vf, init, solver=solver, **kwargs) @@ -44,8 +44,8 @@ def case_solve_adaptive_save_at(fact): kwargs = {"save_at": save_at, "dt0": dt0, "ssm": ssm} def solver_to_solution(solver_fun, strategy_fun): - strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy) + strategy = strategy_fun(ssm=ssm) + solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) @@ -69,8 +69,8 @@ def case_solve_adaptive_save_every_step(fact): kwargs = {"t0": t0, "t1": t1, "dt0": dt0, "ssm": ssm} def solver_to_solution(solver_fun, strategy_fun): - strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy) + strategy = strategy_fun(ssm=ssm) + solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) @@ -94,8 +94,8 @@ def case_simulate_terminal_values(fact): kwargs = {"t0": t0, "t1": t1, "dt0": dt0, "ssm": ssm} def solver_to_solution(solver_fun, strategy_fun): - strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy) + strategy = strategy_fun(ssm=ssm) + solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm, atol=1e-2, rtol=1e-2) @@ -114,7 +114,7 @@ def solver_to_solution(solver_fun, strategy_fun): def fixture_uncalibrated_and_mle_solution(solver_to_solution, strategy_fun): solve, ssm = solver_to_solution uncalib = solve(ivpsolvers.solver, strategy_fun) - mle = solve(functools.partial(ivpsolvers.solver_mle, ssm=ssm), strategy_fun) + mle = solve(ivpsolvers.solver_mle, strategy_fun) return (uncalib, mle), ssm diff --git a/tests/test_ivpsolvers/test_corrections.py b/tests/test_ivpsolvers/test_corrections.py index ada4ae58..482a3ff7 100644 --- a/tests/test_ivpsolvers/test_corrections.py +++ b/tests/test_ivpsolvers/test_corrections.py @@ -46,8 +46,8 @@ def fixture_solution(correction_impl, fact): except NotImplementedError: testing.skip(reason="This type of linearisation has not been implemented.") - strategy = ivpsolvers.strategy_filter(ibm, corr, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) adaptive_kwargs = {"adaptive_solver": adaptive_solver, "dt0": 0.1, "ssm": ssm} diff --git a/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py b/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py index 3ab46254..b2d4dee5 100644 --- a/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py +++ b/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py @@ -20,8 +20,8 @@ def fixture_filter_solution(solver_setup): tcoeffs = solver_setup["tcoeffs"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=solver_setup["fact"]) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid( @@ -34,8 +34,8 @@ def fixture_smoother_solution(solver_setup): tcoeffs = solver_setup["tcoeffs"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=solver_setup["fact"]) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid( diff --git a/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py b/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py index 2852edd0..8f7e0274 100644 --- a/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py +++ b/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py @@ -22,8 +22,8 @@ def fixture_solution_smoother(solver_setup): tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) init = solver.initial_condition() @@ -43,8 +43,8 @@ def test_fixedpoint_smoother_equivalent_same_grid(solver_setup, solution_smoothe tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) save_at = solution_smoother.t @@ -69,8 +69,8 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver_smoother = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver_smoother = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) # Compute the offgrid-marginals ts = np.linspace(save_at[0], save_at[-1], num=7, endpoint=True) @@ -82,8 +82,8 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py b/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py index bbbcf1a8..6a776906 100644 --- a/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py +++ b/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py @@ -13,8 +13,8 @@ def test_warning_for_fixedpoint_in_save_every_step_mode(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() @@ -32,8 +32,8 @@ def test_warning_for_smoother_in_save_at_mode(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_stats/test_log_marginal_likelihood.py b/tests/test_stats/test_log_marginal_likelihood.py index e1f54bd9..79601a3e 100644 --- a/tests/test_stats/test_log_marginal_likelihood.py +++ b/tests/test_stats/test_log_marginal_likelihood.py @@ -14,8 +14,8 @@ def fixture_solution(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() @@ -96,8 +96,8 @@ def test_raises_error_for_filter(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) grid = np.linspace(t0, t1, num=3) init = solver.initial_condition() diff --git a/tests/test_stats/test_log_marginal_likelihood_terminal_values.py b/tests/test_stats/test_log_marginal_likelihood_terminal_values.py index 172a391a..c5cd651d 100644 --- a/tests/test_stats/test_log_marginal_likelihood_terminal_values.py +++ b/tests/test_stats/test_log_marginal_likelihood_terminal_values.py @@ -29,8 +29,8 @@ def fixture_solution(strategy_func, fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = strategy_func(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = strategy_func(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_stats/test_offgrid_marginals.py b/tests/test_stats/test_offgrid_marginals.py index 5ae6c420..baf8a8ad 100644 --- a/tests/test_stats/test_offgrid_marginals.py +++ b/tests/test_stats/test_offgrid_marginals.py @@ -13,8 +13,8 @@ def test_filter_marginals_close_only_to_left_boundary(fact): tcoeffs = (u0, vf(u0, t=t0)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() grid = np.linspace(t0, t1, endpoint=True, num=5) @@ -37,8 +37,8 @@ def test_smoother_marginals_close_to_both_boundaries(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() grid = np.linspace(t0, t1, endpoint=True, num=5) diff --git a/tests/test_stats/test_sample.py b/tests/test_stats/test_sample.py index 9beddef9..bc4cab95 100644 --- a/tests/test_stats/test_sample.py +++ b/tests/test_stats/test_sample.py @@ -12,8 +12,8 @@ def fixture_approximation(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition()