Skip to content

Commit

Permalink
Separate strategies from priors and corrections (#797)
Browse files Browse the repository at this point in the history
* Move name, is_suitable* from _Strategy to _ExtraImpl

* Move extrapolation from _Strategy to _ProbabilisticSolver

* Move is_suitable_for_offgrid_marginals and ssm

* Move the prior

* Update the smoother

* Move prior out of ExtraImpl

* Delete Strategy.initial_condition

* Move init() and offgrid_marginals()

* Move strategy.begin

* Move strategy.extract

* Migrate interpolate_at_t1

* Migrate case_interpolate

* Update the dynamic solver

* Remove the strategy variable

* Move 't' from StrategyState to SolverState

* Delete the aux_corr field

* Delete the strategy state

* Remove the _StrategyState

* Update the dynamic solver

* Update test_ivpsolve

* Update test_ivpsolvers

* Update test_stats

* Update the benchmarks

* Update the examples
  • Loading branch information
pnkraemer authored Oct 26, 2024
1 parent d4d836b commit 4e0af24
Show file tree
Hide file tree
Showing 31 changed files with 321 additions and 345 deletions.
4 changes: 2 additions & 2 deletions docs/benchmarks/hires/run_hires.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def param_to_solution(tol):
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense")
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ibm, ts1, ssm=ssm)
solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
control = ivpsolve.control_proportional_integral(clip=True)
adaptive_solver = ivpsolve.adaptive(
solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm
Expand Down
5 changes: 3 additions & 2 deletions docs/benchmarks/lotkavolterra/run_lotkavolterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ def param_to_solution(tol):
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives)

ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=implementation)
strategy = ivpsolvers.strategy_filter(ibm, correction(ssm=ssm), ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
corr = correction(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm)
control = ivpsolve.control_proportional_integral()
adaptive_solver = ivpsolve.adaptive(
solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm
Expand Down
6 changes: 4 additions & 2 deletions docs/benchmarks/pleiades/run_pleiades.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ def param_to_solution(tol):

ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic")
ts0_or_ts1 = correction_fun(ssm=ssm, ode_order=2)
strategy = ivpsolvers.strategy_filter(ibm, ts0_or_ts1, ssm=ssm)
solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_dynamic(
strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm
)
control = ivpsolve.control_proportional_integral()
adaptive_solver = ivpsolve.adaptive(
solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm
Expand Down
6 changes: 4 additions & 2 deletions docs/benchmarks/vanderpol/run_vanderpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ def param_to_solution(tol):

ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense")
ts0_or_ts1 = ivpsolvers.correction_ts1(ode_order=2, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ibm, ts0_or_ts1, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)

solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm)
solver = ivpsolvers.solver_dynamic(
strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm
)
control = ivpsolve.control_proportional_integral(clip=True)
adaptive_solver = ivpsolve.adaptive(
solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm
Expand Down
4 changes: 2 additions & 2 deletions docs/examples_misc/use_equinox_bounded_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def vf(y, *, t): # noqa: ARG001
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ode_order=1, ssm=ssm)

strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm)
solver = ivpsolvers.solver(strategy)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm)
init = solver.initial_condition()

Expand Down
16 changes: 8 additions & 8 deletions docs/examples_parameter_estimation/neural_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def loss_fn(parameters):
tcoeffs = (*initial_values, vf(*initial_values, t=t0, p=parameters))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()

sol = ivpsolve.solve_fixed_grid(
Expand Down Expand Up @@ -128,8 +128,8 @@ def vf(y, *, t, p):
tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()

# +
Expand Down Expand Up @@ -168,8 +168,8 @@ def vf(y, *, t, p):
tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()

sol = ivpsolve.solve_fixed_grid(
Expand All @@ -182,8 +182,8 @@ def vf(y, *, t, p):
tcoeffs = (u0, vf(u0, t=t0, p=f_args))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver_ts0.initial_condition()

sol = ivpsolve.solve_fixed_grid(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def solve(p):
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm)
solver = ivpsolvers.solver(strategy)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)

init = solver.initial_condition()
return ivpsolve.solve_fixed_grid(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def solve_fixed(theta, *, ts):
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm)
solver = ivpsolvers.solver(strategy)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver.initial_condition()
sol = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm)
return sol[-1]
Expand All @@ -208,8 +208,8 @@ def solve_adaptive(theta, *, save_at):
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm)
solver = ivpsolvers.solver(strategy)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm)

init = solver.initial_condition()
Expand Down
6 changes: 3 additions & 3 deletions docs/examples_quickstart/easy_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ def vf(y, *, t): # noqa: ARG001
# Set up a state-space model
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense")
ts0 = ivpsolvers.correction_ts1(ode_order=1, ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)

# Build a solver
ts0 = ivpsolvers.correction_ts1(ode_order=1, ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm)
# -

Expand Down
5 changes: 3 additions & 2 deletions docs/examples_solver_config/conditioning-on-zero-residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ def vector_field(y, t): # noqa: ARG001
# Compute the posterior

ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="dense")
slr1 = ivpsolvers.correction_ts1(ssm=ssm)
solver = ivpsolvers.solver(ivpsolvers.strategy_fixedpoint(ibm, slr1, ssm=ssm))
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts1, ssm=ssm)
adaptive_solver = ivpsolve.adaptive(solver, atol=1e-1, rtol=1e-2, ssm=ssm)

dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,))
Expand Down
6 changes: 3 additions & 3 deletions docs/examples_solver_config/dynamic_output_scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def vf(*ys, t): # noqa: ARG001
tcoeffs = (u0, vf(u0, t=t0))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="dense")
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ibm, ts1, ssm=ssm)
dynamic = ivpsolvers.solver_dynamic(strategy, ssm=ssm)
mle = ivpsolvers.solver_mle(strategy, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
dynamic = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
mle = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts1, ssm=ssm)

# +
t0, t1 = 0.0, 3.0
Expand Down
8 changes: 4 additions & 4 deletions docs/examples_solver_config/posterior_uncertainties.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def vf(*ys, t): # noqa: ARG001
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
solver = ivpsolvers.solver_mle(ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm), ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)

ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500)
Expand Down Expand Up @@ -115,9 +116,8 @@ def vf(*ys, t): # noqa: ARG001
# +
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
solver = ivpsolvers.solver_mle(
ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm), ssm=ssm
)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)

ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500)
Expand Down
8 changes: 4 additions & 4 deletions docs/examples_solver_config/second_order_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def vf_1(y, t): # noqa: ARG001
tcoeffs = taylor.odejet_padded_scan(lambda y: vf_1(y, t=t0), (u0,), num=4)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm)
solver_1st = ivpsolvers.solver_mle(strategy, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver_1st = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver_1st = ivpsolve.adaptive(solver_1st, atol=1e-5, rtol=1e-5, ssm=ssm)


Expand Down Expand Up @@ -86,8 +86,8 @@ def vf_2(y, dy, t): # noqa: ARG001
tcoeffs = taylor.odejet_padded_scan(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ode_order=2, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm)
solver_2nd = ivpsolvers.solver_mle(strategy, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver_2nd = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver_2nd = ivpsolve.adaptive(solver_2nd, atol=1e-5, rtol=1e-5, ssm=ssm)


Expand Down
4 changes: 2 additions & 2 deletions docs/examples_solver_config/taylor_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def solve(tc):
"""Solve the ODE."""
prior, ssm = ivpsolvers.prior_ibm(tc, ssm_fact="dense")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_fixedpoint(prior, ts0, ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, ssm=ssm)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=prior, correction=ts0, ssm=ssm)
init = solver.initial_condition()

ts = jnp.linspace(t0, t1, endpoint=True, num=10)
Expand Down
4 changes: 2 additions & 2 deletions probdiffeq/ivpsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def body_fn(state: _RejectionState) -> _RejectionState:
dt=self.control.extract(state_control),
)
# Normalise the error
u_proposed = self.ssm.stats.qoi(state_proposed.strategy.hidden)[0]
u_step_from = self.ssm.stats.qoi(state_proposed.strategy.hidden)[0]
u_proposed = self.ssm.stats.qoi(state_proposed.hidden)[0]
u_step_from = self.ssm.stats.qoi(state_proposed.hidden)[0]
u = np.maximum(np.abs(u_proposed), np.abs(u_step_from))
error_power = _error_scale_and_normalize(error_estimate, u=u)

Expand Down
Loading

0 comments on commit 4e0af24

Please sign in to comment.