From 7581747fc36ab5182dbf67b6925a8047991fffe6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Thu, 13 Jun 2024 10:11:49 +0200 Subject: [PATCH] Move the content of adaptive to ivpsolve because they are always used together (#757) * Rename simulate_terminal... to solve_for_terminal... so all solution routines share a prefix * Hide the implementation of the Solution * Rename solution routines to express adaptivity vs fixed steps * Move content of adaptive.py to ivpsolve.py --- docs/api_docs/adaptive.md | 1 - docs/benchmarks/hires/run_hires.py | 8 +- .../lotkavolterra/run_lotkavolterra.py | 8 +- docs/benchmarks/pleiades/run_pleiades.py | 8 +- docs/benchmarks/vanderpol/run_vanderpol.py | 8 +- docs/dev_docs/changelog.md | 2 +- .../use_equinox_bounded_while_loop.py | 6 +- .../physics_enhanced_regression_2.py | 6 +- docs/examples_quickstart/easy_example.py | 6 +- .../conditioning-on-zero-residual.py | 6 +- .../posterior_uncertainties.py | 10 +- .../second_order_problems.py | 10 +- docs/getting_started/choosing_a_solver.md | 4 +- .../transitioning_from_other_packages.md | 4 +- mkdocs.yml | 1 - probdiffeq/adaptive.py | 361 ---------------- probdiffeq/ivpsolve.py | 398 +++++++++++++++++- tests/test_ivpsolve/__init__.py | 8 +- .../test_fixed_grid_vs_save_every_step.py | 12 +- .../test_save_at_vs_save_every_step.py | 10 +- tests/test_ivpsolve/test_save_every_step.py | 6 +- tests/test_ivpsolve/test_solution_object.py | 10 +- ...test_terminal_values_vs_save_every_step.py | 10 +- ...test_mle_calibration_vs_calibrationfree.py | 18 +- .../test_warnings_for_wrong_strategies.py | 10 +- .../test_log_marginal_likelihood.py | 6 +- ...log_marginal_likelihood_terminal_values.py | 6 +- tests/test_solvers/test_stats/test_sample.py | 6 +- .../test_solvers/test_strategies/__init__.py | 2 +- .../test_rmse_of_correction.py | 8 +- ...test_smoother_vs_fixedpoint_equivalence.py | 14 +- 31 files changed, 483 insertions(+), 490 deletions(-) delete mode 100644 docs/api_docs/adaptive.md delete mode 100644 probdiffeq/adaptive.py diff --git a/docs/api_docs/adaptive.md b/docs/api_docs/adaptive.md deleted file mode 100644 index 3afed937..00000000 --- a/docs/api_docs/adaptive.md +++ /dev/null @@ -1 +0,0 @@ -::: probdiffeq.adaptive diff --git a/docs/benchmarks/hires/run_hires.py b/docs/benchmarks/hires/run_hires.py index ad9221b5..6fd690d4 100644 --- a/docs/benchmarks/hires/run_hires.py +++ b/docs/benchmarks/hires/run_hires.py @@ -17,7 +17,7 @@ import scipy.integrate import tqdm -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.impl import impl from probdiffeq.solvers import components, solvers from probdiffeq.taylor import autodiff @@ -96,8 +96,8 @@ def param_to_solution(tol): ts1 = components.correction_ts1() strategy = components.strategy_filter(ibm, ts1) solver = solvers.dynamic(strategy) - control = adaptive.control_proportional_integral_clipped() - adaptive_solver = adaptive.adaptive( + control = ivpsolve.control_proportional_integral_clipped() + adaptive_solver = ivpsolve.adaptive( solver, atol=1e-2 * tol, rtol=tol, control=control ) @@ -108,7 +108,7 @@ def param_to_solution(tol): # Solve dt0 = ivpsolve.dt0(vf_auto, (u0,)) - solution = ivpsolve.simulate_terminal_values( + solution = ivpsolve.solve_adaptive_terminal_values( vf_probdiffeq, init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver ) diff --git a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py index f4649677..ce5edfc2 100644 --- a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py +++ b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py @@ -18,7 +18,7 @@ import scipy.integrate import tqdm -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.impl import impl from probdiffeq.solvers import components, solvers from probdiffeq.taylor import autodiff @@ -84,8 +84,8 @@ def param_to_solution(tol): ibm = components.prior_ibm(num_derivatives=num_derivatives) strategy = components.strategy_filter(ibm, correction()) solver = solvers.mle(strategy) - control = adaptive.control_proportional_integral() - adaptive_solver = adaptive.adaptive( + control = ivpsolve.control_proportional_integral() + adaptive_solver = ivpsolve.adaptive( solver, atol=1e-2 * tol, rtol=tol, control=control ) @@ -97,7 +97,7 @@ def param_to_solution(tol): # Solve dt0 = ivpsolve.dt0(vf_auto, (u0,)) - solution = ivpsolve.simulate_terminal_values( + solution = ivpsolve.solve_adaptive_terminal_values( vf_probdiffeq, init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver ) diff --git a/docs/benchmarks/pleiades/run_pleiades.py b/docs/benchmarks/pleiades/run_pleiades.py index 1740b8f0..943b391f 100644 --- a/docs/benchmarks/pleiades/run_pleiades.py +++ b/docs/benchmarks/pleiades/run_pleiades.py @@ -18,7 +18,7 @@ import scipy.integrate import tqdm -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.impl import impl from probdiffeq.solvers import components, solvers from probdiffeq.taylor import autodiff @@ -108,8 +108,8 @@ def param_to_solution(tol): ts0_or_ts1 = correction_fun(ode_order=2) strategy = components.strategy_filter(ibm, ts0_or_ts1) solver = solvers.dynamic(strategy) - control = adaptive.control_proportional_integral() - adaptive_solver = adaptive.adaptive( + control = ivpsolve.control_proportional_integral() + adaptive_solver = ivpsolve.adaptive( solver, atol=1e-3 * tol, rtol=tol, control=control ) @@ -120,7 +120,7 @@ def param_to_solution(tol): # Solve dt0 = ivpsolve.dt0(vf_auto, (u0, du0)) - solution = ivpsolve.simulate_terminal_values( + solution = ivpsolve.solve_adaptive_terminal_values( vf_probdiffeq, init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver ) diff --git a/docs/benchmarks/vanderpol/run_vanderpol.py b/docs/benchmarks/vanderpol/run_vanderpol.py index 9f829262..90d90fcf 100644 --- a/docs/benchmarks/vanderpol/run_vanderpol.py +++ b/docs/benchmarks/vanderpol/run_vanderpol.py @@ -17,7 +17,7 @@ import scipy.integrate import tqdm -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.impl import impl from probdiffeq.solvers import components, solvers from probdiffeq.taylor import autodiff @@ -87,8 +87,8 @@ def param_to_solution(tol): ts0_or_ts1 = components.correction_ts1(ode_order=2) strategy = components.strategy_filter(ibm, ts0_or_ts1) solver = solvers.dynamic(strategy) - control = adaptive.control_proportional_integral_clipped() - adaptive_solver = adaptive.adaptive( + control = ivpsolve.control_proportional_integral_clipped() + adaptive_solver = ivpsolve.adaptive( solver, atol=1e-3 * tol, rtol=tol, control=control ) @@ -99,7 +99,7 @@ def param_to_solution(tol): # Solve dt0 = ivpsolve.dt0(vf_auto, (u0, du0)) - solution = ivpsolve.simulate_terminal_values( + solution = ivpsolve.solve_adaptive_terminal_values( vf_probdiffeq, init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver ) diff --git a/docs/dev_docs/changelog.md b/docs/dev_docs/changelog.md index 0b3c7905..9711b440 100644 --- a/docs/dev_docs/changelog.md +++ b/docs/dev_docs/changelog.md @@ -37,7 +37,7 @@ From now on, this change log will be used properly. This means that the behaviour of, e.g., parameter estimation scripts will change slightly. A related bugfix in computing the whitened residuals implies that the dynamic solver with a ts1() correction and a dense implementation is not exactly equivalent to tornadox.ReferenceEK1 anymore (because the tornadox-version still has the same error). -* The interpolation behaviour of the MLESolver when called in solve_and_save_at() had a small error, which amplified the output scale unnecessarily between steps. +* The interpolation behaviour of the MLESolver when called in solve_adaptive_save_at() had a small error, which amplified the output scale unnecessarily between steps. This has been fixed. As a result, the posterior-uncertainty notebook displays more realistic uncertainty estimates in high-order derivatives. Check it out! ## Prior to v0.2.0 diff --git a/docs/examples_misc/use_equinox_bounded_while_loop.py b/docs/examples_misc/use_equinox_bounded_while_loop.py index 545183ec..9c62dee2 100644 --- a/docs/examples_misc/use_equinox_bounded_while_loop.py +++ b/docs/examples_misc/use_equinox_bounded_while_loop.py @@ -24,7 +24,7 @@ import jax import jax.numpy as jnp -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import control_flow from probdiffeq.impl import impl from probdiffeq.solvers import components, solvers @@ -69,14 +69,14 @@ def vf(y, *, t): # noqa: ARG001 strategy = components.strategy_fixedpoint(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver) + adaptive_solver = ivpsolve.adaptive(solver) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=1) init = solver.initial_condition(tcoeffs, 1.0) def simulate(init_val): """Evaluate the parameter-to-solution function.""" - sol = ivpsolve.simulate_terminal_values( + sol = ivpsolve.solve_adaptive_terminal_values( vf, init_val, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver ) diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py index 6dc783ba..35762cf2 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py @@ -132,7 +132,7 @@ import matplotlib.pyplot as plt from diffeqzoo import backend, ivps -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.impl import impl from probdiffeq.solvers import components, solvers, stats from probdiffeq.taylor import autodiff @@ -211,12 +211,12 @@ def solve_adaptive(theta, *, save_at): ts0 = components.correction_ts0() strategy = components.strategy_filter(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver) + adaptive_solver = ivpsolve.adaptive(solver) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (theta,), num=2) output_scale = 10.0 init = solver.initial_condition(tcoeffs, output_scale) - return ivpsolve.solve_and_save_at( + return ivpsolve.solve_adaptive_save_at( vf, init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1 ) diff --git a/docs/examples_quickstart/easy_example.py b/docs/examples_quickstart/easy_example.py index b9052f50..db636afd 100644 --- a/docs/examples_quickstart/easy_example.py +++ b/docs/examples_quickstart/easy_example.py @@ -22,7 +22,7 @@ import jax import jax.numpy as jnp -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.impl import impl from probdiffeq.solvers import components, solvers from probdiffeq.taylor import autodiff @@ -88,7 +88,7 @@ def vf(y, *, t): # noqa: ARG001 strategy = components.strategy_smoother(ibm, ts0) solver = solvers.solver(strategy) -adaptive_solver = adaptive.adaptive(solver) +adaptive_solver = ivpsolve.adaptive(solver) # - # Why so many layers? @@ -136,7 +136,7 @@ def vf(y, *, t): # noqa: ARG001 # + dt0 = ivpsolve.dt0(lambda y: vf(y, t=t0), (u0,)) # or use e.g. dt0=0.1 -solution = ivpsolve.solve_and_save_every_step( +solution = ivpsolve.solve_adaptive_save_every_step( vf, init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver ) diff --git a/docs/examples_solver_config/conditioning-on-zero-residual.py b/docs/examples_solver_config/conditioning-on-zero-residual.py index ec944dce..dbeff5a3 100644 --- a/docs/examples_solver_config/conditioning-on-zero-residual.py +++ b/docs/examples_solver_config/conditioning-on-zero-residual.py @@ -26,7 +26,7 @@ import matplotlib.pyplot as plt from diffeqzoo import backend -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.impl import impl from probdiffeq.solvers import components, solvers, stats from probdiffeq.taylor import autodiff @@ -86,12 +86,12 @@ def vector_field(y, t): # noqa: ARG001 slr1 = components.correction_ts1() ibm = components.prior_ibm(num_derivatives=NUM_DERIVATIVES) solver = solvers.solver(components.strategy_fixedpoint(ibm, slr1)) -adaptive_solver = adaptive.adaptive(solver, atol=1e-1, rtol=1e-2) +adaptive_solver = ivpsolve.adaptive(solver, atol=1e-1, rtol=1e-2) dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,)) init = solver.initial_condition(tcoeffs, output_scale=1.0) -sol = ivpsolve.solve_and_save_at( +sol = ivpsolve.solve_adaptive_save_at( vector_field, init, save_at=ts, dt0=1.0, adaptive_solver=adaptive_solver ) # posterior = stats.calibrate(sol.posterior, sol.output_scale) diff --git a/docs/examples_solver_config/posterior_uncertainties.py b/docs/examples_solver_config/posterior_uncertainties.py index de9faf01..983ac72e 100644 --- a/docs/examples_solver_config/posterior_uncertainties.py +++ b/docs/examples_solver_config/posterior_uncertainties.py @@ -22,7 +22,7 @@ import matplotlib.pyplot as plt from diffeqzoo import backend, ivps -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.impl import impl from probdiffeq.solvers import components, solvers, stats from probdiffeq.taylor import autodiff @@ -66,7 +66,7 @@ def vf(*ys, t): # noqa: ARG001 ibm = components.prior_ibm(num_derivatives=4) ts0 = components.correction_ts0() solver = solvers.mle(components.strategy_filter(ibm, ts0)) -adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) +adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500) @@ -75,7 +75,7 @@ def vf(*ys, t): # noqa: ARG001 tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) init = solver.initial_condition(tcoeffs, output_scale=1.0) -sol = ivpsolve.solve_and_save_at( +sol = ivpsolve.solve_adaptive_save_at( vf, init, save_at=ts, dt0=dt0, adaptive_solver=adaptive_solver ) @@ -121,13 +121,13 @@ def vf(*ys, t): # noqa: ARG001 ibm = components.prior_ibm(num_derivatives=4) ts0 = components.correction_ts0() solver = solvers.mle(components.strategy_fixedpoint(ibm, ts0)) -adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) +adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500) # + init = solver.initial_condition(tcoeffs, output_scale=1.0) -sol = ivpsolve.solve_and_save_at( +sol = ivpsolve.solve_adaptive_save_at( vf, init, save_at=ts, dt0=dt0, adaptive_solver=adaptive_solver ) diff --git a/docs/examples_solver_config/second_order_problems.py b/docs/examples_solver_config/second_order_problems.py index b002a25b..e8ffe601 100644 --- a/docs/examples_solver_config/second_order_problems.py +++ b/docs/examples_solver_config/second_order_problems.py @@ -22,7 +22,7 @@ import matplotlib.pyplot as plt from diffeqzoo import backend, ivps -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.impl import impl from probdiffeq.solvers import components, solvers from probdiffeq.taylor import autodiff @@ -56,14 +56,14 @@ def vf_1(y, t): # noqa: ARG001 ibm = components.prior_ibm(num_derivatives=4) ts0 = components.correction_ts0() solver_1st = solvers.mle(components.strategy_filter(ibm, ts0)) -adaptive_solver_1st = adaptive.adaptive(solver_1st, atol=1e-5, rtol=1e-5) +adaptive_solver_1st = ivpsolve.adaptive(solver_1st, atol=1e-5, rtol=1e-5) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf_1(y, t=t0), (u0,), num=4) init = solver_1st.initial_condition(tcoeffs, output_scale=1.0) # - -solution = ivpsolve.solve_and_save_every_step( +solution = ivpsolve.solve_adaptive_save_every_step( vf_1, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_1st ) @@ -90,14 +90,14 @@ def vf_2(y, dy, t): # noqa: ARG001 ibm = components.prior_ibm(num_derivatives=4) ts0 = components.correction_ts0(ode_order=2) solver_2nd = solvers.mle(components.strategy_filter(ibm, ts0)) -adaptive_solver_2nd = adaptive.adaptive(solver_2nd, atol=1e-5, rtol=1e-5) +adaptive_solver_2nd = ivpsolve.adaptive(solver_2nd, atol=1e-5, rtol=1e-5) tcoeffs = autodiff.taylor_mode_scan(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3) init = solver_2nd.initial_condition(tcoeffs, output_scale=1.0) # - -solution = ivpsolve.solve_and_save_every_step( +solution = ivpsolve.solve_adaptive_save_every_step( vf_2, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_2nd ) diff --git a/docs/getting_started/choosing_a_solver.md b/docs/getting_started/choosing_a_solver.md index 492cbdf5..8c43412c 100644 --- a/docs/getting_started/choosing_a_solver.md +++ b/docs/getting_started/choosing_a_solver.md @@ -30,8 +30,8 @@ If that does not work: let me know what you come up with... ## Filters vs smoothers Almost always, use a `components.strategy_filter` strategy for `simulate_terminal_values`, -a `components.strategy_smoother` strategy for `solve_and_save_every_step`, -and a `components.strategy_fixedpoint` strategy for `solve_and_save_at`. +a `components.strategy_smoother` strategy for `solve_adaptive_save_every_step`, +and a `components.strategy_fixedpoint` strategy for `solve_adaptive_save_at`. Use either a filter (if you must) or a smoother (recommended) for `solve_fixed_step`. Other combinations are possible, but rather rare (and require some understanding of the underlying statistical concepts). diff --git a/docs/getting_started/transitioning_from_other_packages.md b/docs/getting_started/transitioning_from_other_packages.md index 0c83ceab..a733c417 100644 --- a/docs/getting_started/transitioning_from_other_packages.md +++ b/docs/getting_started/transitioning_from_other_packages.md @@ -36,7 +36,7 @@ ProbDiffEq can reproduce most of the implementations in Tornadox: | `ek1.ReferenceEK1()` | `dynamic(strategy_filter(ibm_adaptive(), ts1()))` | Combine with `impl.select("dense", ...)` | | `ek1.ReferenceEK1ConstantDiffusion()` | `mle(strategy_filter(ibm_adaptive(), ts1()))` | Combine with `impl.select("dense", ...)`. | | `ek1.DiagonalEK1()` | Work in progress. | | -| `solver.solve()` | `solve_and_save_every_step()` | Try `solve_and_save_at()` instead. | +| `solver.solve()` | `solve_adaptive_save_every_step()` | Try `solve_adaptive_save_at()` instead. | | `solver.simulate_final_state()` | `simulate_terminal_values()` | ProbDiffEq compiles the whole loop; it will be much faster. | | `solver.solution_generator()` | Work in progress. | | | `init.TaylorMode()` | `taylor.autodiff.taylor_mode` | Consider `taylor.autodiff.forward_mode_recursive()` for low numbers of derivatives and `taylor.autodiff.taylor_mode_doubling()` for (absurdly) high numbers of derivatives | @@ -161,5 +161,5 @@ Most of the divergences from Diffrax apply. Additionally: * Solution objects in ProbDiffEq are random processes (posterior distributions). Random variable types replace most vectors and matrices. This statistical description is richer than a point estimate but needs to be calibrated and demands a non-trivial interaction with the solution (e.g. via sampling from it instead of simply plotting the point-estimate) -* ProbDiffEq offers different solution methods: `simulate_terminal_values()`, `solve_and_save_every_step()`, or `solve_and_save_at()`. Many conventional ODE solver suites expose this functionality through flags in a single `solve` function. +* ProbDiffEq offers different solution methods: `simulate_terminal_values()`, `solve_adaptive_save_every_step()`, or `solve_adaptive_save_at()`. Many conventional ODE solver suites expose this functionality through flags in a single `solve` function. Expressing different modes of solving differential equations in different functions almost exclusively affects the source-code simplicity; but it also allows matching the solver to the solving mode (e.g., terminal values vs save-at). For example, `simulate_terminal_values()` is best combined with a filter. diff --git a/mkdocs.yml b/mkdocs.yml index 891413cb..727b9acb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -92,7 +92,6 @@ nav: - examples_misc/use_equinox_bounded_while_loop.ipynb - API DOCUMENTATION: - ivpsolve: api_docs/ivpsolve.md - - adaptive: api_docs/adaptive.md - impl: api_docs/impl.md - solvers: - components: api_docs/solvers/components.md diff --git a/probdiffeq/adaptive.py b/probdiffeq/adaptive.py deleted file mode 100644 index 5779398a..00000000 --- a/probdiffeq/adaptive.py +++ /dev/null @@ -1,361 +0,0 @@ -"""Adaptive solvers for initial value problems (IVPs).""" - -from probdiffeq.backend import containers, control_flow, functools, linalg, tree_util -from probdiffeq.backend import numpy as np -from probdiffeq.backend.typing import Any, Callable -from probdiffeq.impl import impl - - -@containers.dataclass -class _Controller: - """Control algorithm.""" - - init: Callable[[float], Any] - """Initialise the controller state.""" - - clip: Callable[[Any, float, float], Any] - """(Optionally) clip the current step to not exceed t1.""" - - apply: Callable[[Any, float, float], Any] - r"""Propose a time-step $\Delta t$.""" - - extract: Callable[[Any], float] - """Extract the time-step from the controller state.""" - - -def control_proportional_integral( - *, - safety=0.95, - factor_min=0.2, - factor_max=10.0, - power_integral_unscaled=0.3, - power_proportional_unscaled=0.4, -) -> _Controller: - """Construct a proportional-integral-controller.""" - init = _proportional_integral_init - apply = functools.partial( - _proportional_integral_apply, - safety=safety, - factor_min=factor_min, - factor_max=factor_max, - power_integral_unscaled=power_integral_unscaled, - power_proportional_unscaled=power_proportional_unscaled, - ) - extract = _proportional_integral_extract - return _Controller(init=init, apply=apply, extract=extract, clip=_no_clip) - - -def control_proportional_integral_clipped( - *, - safety=0.95, - factor_min=0.2, - factor_max=10.0, - power_integral_unscaled=0.3, - power_proportional_unscaled=0.4, -) -> _Controller: - """Construct a proportional-integral-controller with time-clipping.""" - init = _proportional_integral_init - apply = functools.partial( - _proportional_integral_apply, - safety=safety, - factor_min=factor_min, - factor_max=factor_max, - power_integral_unscaled=power_integral_unscaled, - power_proportional_unscaled=power_proportional_unscaled, - ) - extract = _proportional_integral_extract - clip = _proportional_integral_clip - return _Controller(init=init, apply=apply, extract=extract, clip=clip) - - -def _proportional_integral_apply( - state: tuple[float, float], - /, - error_normalised, - *, - error_contraction_rate, - safety, - factor_min, - factor_max, - power_integral_unscaled, - power_proportional_unscaled, -) -> tuple[float, float]: - dt_proposed, error_norm_previously_accepted = state - n1 = power_integral_unscaled / error_contraction_rate - n2 = power_proportional_unscaled / error_contraction_rate - - a1 = (1.0 / error_normalised) ** n1 - a2 = (error_norm_previously_accepted / error_normalised) ** n2 - scale_factor_unclipped = safety * a1 * a2 - - scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max) - scale_factor = np.maximum(factor_min, scale_factor_clipped_min) - error_norm_previously_accepted = np.where( - error_normalised <= 1.0, error_normalised, error_norm_previously_accepted - ) - - dt_proposed = scale_factor * dt_proposed - return dt_proposed, error_norm_previously_accepted - - -def _proportional_integral_init(dt0, /): - return dt0, 1.0 - - -def _proportional_integral_clip( - state: tuple[float, float], /, t, t1 -) -> tuple[float, float]: - dt_proposed, error_norm_previously_accepted = state - dt = dt_proposed - dt_clipped = np.minimum(dt, t1 - t) - return dt_clipped, error_norm_previously_accepted - - -def _proportional_integral_extract(state: tuple[float, float], /): - dt_proposed, _error_norm_previously_accepted = state - return dt_proposed - - -def control_integral(*, safety=0.95, factor_min=0.2, factor_max=10.0) -> _Controller: - """Construct an integral-controller.""" - init = _integral_init - apply = functools.partial( - _integral_apply, safety=safety, factor_min=factor_min, factor_max=factor_max - ) - extract = _integral_extract - return _Controller(init=init, apply=apply, extract=extract, clip=_no_clip) - - -def control_integral_clipped( - *, safety=0.95, factor_min=0.2, factor_max=10.0 -) -> _Controller: - """Construct an integral-controller with time-clipping.""" - init = functools.partial(_integral_init) - apply = functools.partial( - _integral_apply, safety=safety, factor_min=factor_min, factor_max=factor_max - ) - extract = functools.partial(_integral_extract) - return _Controller(init=init, apply=apply, extract=extract, clip=_integral_clip) - - -def _integral_init(dt0, /): - return dt0 - - -def _integral_clip(dt, /, t, t1): - return np.minimum(dt, t1 - t) - - -def _no_clip(dt, /, *_args, **_kwargs): - return dt - - -def _integral_apply( - dt, /, error_normalised, *, error_contraction_rate, safety, factor_min, factor_max -): - error_power = error_normalised ** (-1.0 / error_contraction_rate) - scale_factor_unclipped = safety * error_power - - scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max) - scale_factor = np.maximum(factor_min, scale_factor_clipped_min) - return scale_factor * dt - - -def _integral_extract(dt, /): - return dt - - -# Register the control algorithm as a pytree (temporary?) - - -def _flatten(ctrl): - aux = ctrl.init, ctrl.apply, ctrl.clip, ctrl.extract - return (), aux - - -def _unflatten(aux, _children): - init, apply, clip, extract = aux - return _Controller(init=init, apply=apply, clip=clip, extract=extract) - - -tree_util.register_pytree_node(_Controller, _flatten, _unflatten) - - -def adaptive(solver, atol=1e-4, rtol=1e-2, control=None, norm_ord=None): - """Make an IVP solver adaptive.""" - if control is None: - control = control_proportional_integral() - - return _AdaptiveIVPSolver( - solver, atol=atol, rtol=rtol, control=control, norm_ord=norm_ord - ) - - -class _RejectionState(containers.NamedTuple): - """State for rejection loops. - - (Keep decreasing step-size until error norm is small. - This is one part of an IVP solver step.) - """ - - error_norm_proposed: float - control: Any - proposed: Any - step_from: Any - - -class _AdaptiveState(containers.NamedTuple): - step_from: Any - interp_from: Any - control: Any - stats: Any - - @property - def t(self): - return self.step_from.t - - -class _AdaptiveIVPSolver: - """Adaptive IVP solvers.""" - - def __init__(self, solver, atol, rtol, control, norm_ord): - self.solver = solver - self.atol = atol - self.rtol = rtol - self.control = control - self.norm_ord = norm_ord - - def __repr__(self): - return ( - f"\n{self.__class__.__name__}(" - f"\n\tsolver={self.solver}," - f"\n\tatol={self.atol}," - f"\n\trtol={self.rtol}," - f"\n\tcontrol={self.control}," - f"\n\tnorm_order={self.norm_ord}," - "\n)" - ) - - @functools.jit - def init(self, t, initial_condition, dt0, num_steps): - """Initialise the IVP solver state.""" - state_solver = self.solver.init(t, initial_condition) - state_control = self.control.init(dt0) - return _AdaptiveState(state_solver, state_solver, state_control, num_steps) - - @functools.jit - def rejection_loop(self, state0, *, vector_field, t1): - def cond_fn(s): - return s.error_norm_proposed > 1.0 - - def body_fn(s): - return self._attempt_step(state=s, vector_field=vector_field, t1=t1) - - def init(s0): - larger_than_1 = 1.1 - return _RejectionState( - error_norm_proposed=larger_than_1, - control=s0.control, - proposed=_inf_like(s0.step_from), - step_from=s0.step_from, - ) - - def extract(s): - num_steps = state0.stats + 1 - return _AdaptiveState(s.proposed, s.step_from, s.control, num_steps) - - init_val = init(state0) - state_new = control_flow.while_loop(cond_fn, body_fn, init_val) - return extract(state_new) - - def _attempt_step(self, *, state: _RejectionState, vector_field, t1): - """Attempt a step. - - Perform a step with an IVP solver and - propose a future time-step based on tolerances and error estimates. - """ - # Some controllers like to clip the terminal value instead of interpolating. - # This must happen _before_ the step. - state_control = self.control.clip(state.control, t=state.step_from.t, t1=t1) - - # Perform the actual step. - # todo: error estimate should be a tuple (abs, rel) - error_estimate, state_proposed = self.solver.step( - state=state.step_from, - vector_field=vector_field, - dt=self.control.extract(state_control), - ) - # Normalise the error - u_proposed = impl.hidden_model.qoi(state_proposed.strategy.hidden) - u_step_from = impl.hidden_model.qoi(state_proposed.strategy.hidden) - u = np.maximum(np.abs(u_proposed), np.abs(u_step_from)) - error_normalised = self._normalise_error(error_estimate, u=u) - - # Propose a new step - error_contraction_rate = self.solver.strategy.extrapolation.num_derivatives + 1 - state_control = self.control.apply( - state_control, - error_normalised=error_normalised, - error_contraction_rate=error_contraction_rate, - ) - return _RejectionState( - error_norm_proposed=error_normalised, # new - proposed=state_proposed, # new - control=state_control, # new - step_from=state.step_from, - ) - - def _normalise_error(self, error_estimate, *, u): - error_relative = error_estimate / (self.atol + self.rtol * np.abs(u)) - dim = np.atleast_1d(u).size - return linalg.vector_norm(error_relative, order=self.norm_ord) / np.sqrt(dim) - - def extract(self, state): - solution_solver = self.solver.extract(state.step_from) - solution_control = self.control.extract(state.control) - return solution_solver, solution_control, state.stats - - def right_corner_and_extract(self, state): - interp = self.solver.right_corner(state.interp_from, state.step_from) - accepted, solution, previous = interp - state = _AdaptiveState(accepted, previous, state.control, state.stats) - - solution_solver = self.solver.extract(solution) - solution_control = self.control.extract(state.control) - return state, (solution_solver, solution_control, state.stats) - - def interpolate_and_extract(self, state, t): - interp = self.solver.interpolate(s1=state.step_from, s0=state.interp_from, t=t) - - accepted, solution, previous = interp - state = _AdaptiveState(accepted, previous, state.control, state.stats) - - solution_solver = self.solver.extract(solution) - solution_control = self.control.extract(state.control) - return state, (solution_solver, solution_control, state.stats) - - -# Register outside of class to declutter the AdaptiveIVPSolver source code a bit - - -def _asolver_flatten(asolver): - children = (asolver.solver, asolver.atol, asolver.rtol, asolver.control) - aux = (asolver.norm_ord,) - return children, aux - - -def _asolver_unflatten(aux, children): - solver, atol, rtol, control = children - (norm_ord,) = aux - return _AdaptiveIVPSolver( - solver=solver, atol=atol, rtol=rtol, control=control, norm_ord=norm_ord - ) - - -tree_util.register_pytree_node( - _AdaptiveIVPSolver, flatten_func=_asolver_flatten, unflatten_func=_asolver_unflatten -) - - -def _inf_like(tree): - return tree_util.tree_map(lambda x: np.inf() * np.ones_like(x), tree) diff --git a/probdiffeq/ivpsolve.py b/probdiffeq/ivpsolve.py index d9c60868..d837ae14 100644 --- a/probdiffeq/ivpsolve.py +++ b/probdiffeq/ivpsolve.py @@ -1,6 +1,7 @@ """Routines for estimating solutions of initial value problems.""" from probdiffeq.backend import ( + containers, control_flow, functools, linalg, @@ -9,16 +10,367 @@ warnings, ) from probdiffeq.backend import numpy as np +from probdiffeq.backend.typing import Any, Callable from probdiffeq.impl import impl from probdiffeq.solvers import stats -# todo: change the Solution object to a simple -# named tuple containing (t, full_estimate, u_and_marginals, stats). -# No need to pre/append the initial condition to the solution anymore, -# since the user knows it already. + +@containers.dataclass +class _Controller: + """Control algorithm.""" + + init: Callable[[float], Any] + """Initialise the controller state.""" + + clip: Callable[[Any, float, float], Any] + """(Optionally) clip the current step to not exceed t1.""" + + apply: Callable[[Any, float, float], Any] + r"""Propose a time-step $\Delta t$.""" + + extract: Callable[[Any], float] + """Extract the time-step from the controller state.""" + + +def control_proportional_integral( + *, + safety=0.95, + factor_min=0.2, + factor_max=10.0, + power_integral_unscaled=0.3, + power_proportional_unscaled=0.4, +) -> _Controller: + """Construct a proportional-integral-controller.""" + init = _proportional_integral_init + apply = functools.partial( + _proportional_integral_apply, + safety=safety, + factor_min=factor_min, + factor_max=factor_max, + power_integral_unscaled=power_integral_unscaled, + power_proportional_unscaled=power_proportional_unscaled, + ) + extract = _proportional_integral_extract + return _Controller(init=init, apply=apply, extract=extract, clip=_no_clip) + + +def control_proportional_integral_clipped( + *, + safety=0.95, + factor_min=0.2, + factor_max=10.0, + power_integral_unscaled=0.3, + power_proportional_unscaled=0.4, +) -> _Controller: + """Construct a proportional-integral-controller with time-clipping.""" + init = _proportional_integral_init + apply = functools.partial( + _proportional_integral_apply, + safety=safety, + factor_min=factor_min, + factor_max=factor_max, + power_integral_unscaled=power_integral_unscaled, + power_proportional_unscaled=power_proportional_unscaled, + ) + extract = _proportional_integral_extract + clip = _proportional_integral_clip + return _Controller(init=init, apply=apply, extract=extract, clip=clip) + + +def _proportional_integral_apply( + state: tuple[float, float], + /, + error_normalised, + *, + error_contraction_rate, + safety, + factor_min, + factor_max, + power_integral_unscaled, + power_proportional_unscaled, +) -> tuple[float, float]: + dt_proposed, error_norm_previously_accepted = state + n1 = power_integral_unscaled / error_contraction_rate + n2 = power_proportional_unscaled / error_contraction_rate + + a1 = (1.0 / error_normalised) ** n1 + a2 = (error_norm_previously_accepted / error_normalised) ** n2 + scale_factor_unclipped = safety * a1 * a2 + + scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max) + scale_factor = np.maximum(factor_min, scale_factor_clipped_min) + error_norm_previously_accepted = np.where( + error_normalised <= 1.0, error_normalised, error_norm_previously_accepted + ) + + dt_proposed = scale_factor * dt_proposed + return dt_proposed, error_norm_previously_accepted + + +def _proportional_integral_init(dt0, /): + return dt0, 1.0 + + +def _proportional_integral_clip( + state: tuple[float, float], /, t, t1 +) -> tuple[float, float]: + dt_proposed, error_norm_previously_accepted = state + dt = dt_proposed + dt_clipped = np.minimum(dt, t1 - t) + return dt_clipped, error_norm_previously_accepted + + +def _proportional_integral_extract(state: tuple[float, float], /): + dt_proposed, _error_norm_previously_accepted = state + return dt_proposed + + +def control_integral(*, safety=0.95, factor_min=0.2, factor_max=10.0) -> _Controller: + """Construct an integral-controller.""" + init = _integral_init + apply = functools.partial( + _integral_apply, safety=safety, factor_min=factor_min, factor_max=factor_max + ) + extract = _integral_extract + return _Controller(init=init, apply=apply, extract=extract, clip=_no_clip) + + +def control_integral_clipped( + *, safety=0.95, factor_min=0.2, factor_max=10.0 +) -> _Controller: + """Construct an integral-controller with time-clipping.""" + init = functools.partial(_integral_init) + apply = functools.partial( + _integral_apply, safety=safety, factor_min=factor_min, factor_max=factor_max + ) + extract = functools.partial(_integral_extract) + return _Controller(init=init, apply=apply, extract=extract, clip=_integral_clip) + + +def _integral_init(dt0, /): + return dt0 + + +def _integral_clip(dt, /, t, t1): + return np.minimum(dt, t1 - t) + + +def _no_clip(dt, /, *_args, **_kwargs): + return dt + + +def _integral_apply( + dt, /, error_normalised, *, error_contraction_rate, safety, factor_min, factor_max +): + error_power = error_normalised ** (-1.0 / error_contraction_rate) + scale_factor_unclipped = safety * error_power + + scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max) + scale_factor = np.maximum(factor_min, scale_factor_clipped_min) + return scale_factor * dt + + +def _integral_extract(dt, /): + return dt + + +# Register the control algorithm as a pytree (temporary?) + + +def _flatten(ctrl): + aux = ctrl.init, ctrl.apply, ctrl.clip, ctrl.extract + return (), aux + + +def _unflatten(aux, _children): + init, apply, clip, extract = aux + return _Controller(init=init, apply=apply, clip=clip, extract=extract) + + +tree_util.register_pytree_node(_Controller, _flatten, _unflatten) + + +def adaptive(solver, atol=1e-4, rtol=1e-2, control=None, norm_ord=None): + """Make an IVP solver adaptive.""" + if control is None: + control = control_proportional_integral() + + return _AdaptiveIVPSolver( + solver, atol=atol, rtol=rtol, control=control, norm_ord=norm_ord + ) + + +class _RejectionState(containers.NamedTuple): + """State for rejection loops. + + (Keep decreasing step-size until error norm is small. + This is one part of an IVP solver step.) + """ + + error_norm_proposed: float + control: Any + proposed: Any + step_from: Any + + +class _AdaptiveState(containers.NamedTuple): + step_from: Any + interp_from: Any + control: Any + stats: Any + + @property + def t(self): + return self.step_from.t + + +class _AdaptiveIVPSolver: + """Adaptive IVP solvers.""" + + def __init__(self, solver, atol, rtol, control, norm_ord): + self.solver = solver + self.atol = atol + self.rtol = rtol + self.control = control + self.norm_ord = norm_ord + + def __repr__(self): + return ( + f"\n{self.__class__.__name__}(" + f"\n\tsolver={self.solver}," + f"\n\tatol={self.atol}," + f"\n\trtol={self.rtol}," + f"\n\tcontrol={self.control}," + f"\n\tnorm_order={self.norm_ord}," + "\n)" + ) + + @functools.jit + def init(self, t, initial_condition, dt0, num_steps): + """Initialise the IVP solver state.""" + state_solver = self.solver.init(t, initial_condition) + state_control = self.control.init(dt0) + return _AdaptiveState(state_solver, state_solver, state_control, num_steps) + + @functools.jit + def rejection_loop(self, state0, *, vector_field, t1): + def cond_fn(s): + return s.error_norm_proposed > 1.0 + + def body_fn(s): + return self._attempt_step(state=s, vector_field=vector_field, t1=t1) + + def init(s0): + larger_than_1 = 1.1 + return _RejectionState( + error_norm_proposed=larger_than_1, + control=s0.control, + proposed=_inf_like(s0.step_from), + step_from=s0.step_from, + ) + + def extract(s): + num_steps = state0.stats + 1 + return _AdaptiveState(s.proposed, s.step_from, s.control, num_steps) + + init_val = init(state0) + state_new = control_flow.while_loop(cond_fn, body_fn, init_val) + return extract(state_new) + + def _attempt_step(self, *, state: _RejectionState, vector_field, t1): + """Attempt a step. + + Perform a step with an IVP solver and + propose a future time-step based on tolerances and error estimates. + """ + # Some controllers like to clip the terminal value instead of interpolating. + # This must happen _before_ the step. + state_control = self.control.clip(state.control, t=state.step_from.t, t1=t1) + + # Perform the actual step. + # todo: error estimate should be a tuple (abs, rel) + error_estimate, state_proposed = self.solver.step( + state=state.step_from, + vector_field=vector_field, + dt=self.control.extract(state_control), + ) + # Normalise the error + u_proposed = impl.hidden_model.qoi(state_proposed.strategy.hidden) + u_step_from = impl.hidden_model.qoi(state_proposed.strategy.hidden) + u = np.maximum(np.abs(u_proposed), np.abs(u_step_from)) + error_normalised = self._normalise_error(error_estimate, u=u) + + # Propose a new step + error_contraction_rate = self.solver.strategy.extrapolation.num_derivatives + 1 + state_control = self.control.apply( + state_control, + error_normalised=error_normalised, + error_contraction_rate=error_contraction_rate, + ) + return _RejectionState( + error_norm_proposed=error_normalised, # new + proposed=state_proposed, # new + control=state_control, # new + step_from=state.step_from, + ) + + def _normalise_error(self, error_estimate, *, u): + error_relative = error_estimate / (self.atol + self.rtol * np.abs(u)) + dim = np.atleast_1d(u).size + return linalg.vector_norm(error_relative, order=self.norm_ord) / np.sqrt(dim) + + def extract(self, state): + solution_solver = self.solver.extract(state.step_from) + solution_control = self.control.extract(state.control) + return solution_solver, solution_control, state.stats + + def right_corner_and_extract(self, state): + interp = self.solver.right_corner(state.interp_from, state.step_from) + accepted, solution, previous = interp + state = _AdaptiveState(accepted, previous, state.control, state.stats) + + solution_solver = self.solver.extract(solution) + solution_control = self.control.extract(state.control) + return state, (solution_solver, solution_control, state.stats) + + def interpolate_and_extract(self, state, t): + interp = self.solver.interpolate(s1=state.step_from, s0=state.interp_from, t=t) + + accepted, solution, previous = interp + state = _AdaptiveState(accepted, previous, state.control, state.stats) + + solution_solver = self.solver.extract(solution) + solution_control = self.control.extract(state.control) + return state, (solution_solver, solution_control, state.stats) + + +# Register outside of class to declutter the AdaptiveIVPSolver source code a bit + + +def _asolver_flatten(asolver): + children = (asolver.solver, asolver.atol, asolver.rtol, asolver.control) + aux = (asolver.norm_ord,) + return children, aux + + +def _asolver_unflatten(aux, children): + solver, atol, rtol, control = children + (norm_ord,) = aux + return _AdaptiveIVPSolver( + solver=solver, atol=atol, rtol=rtol, control=control, norm_ord=norm_ord + ) + + +tree_util.register_pytree_node( + _AdaptiveIVPSolver, flatten_func=_asolver_flatten, unflatten_func=_asolver_unflatten +) + + +def _inf_like(tree): + return tree_util.tree_map(lambda x: np.inf() * np.ones_like(x), tree) -class Solution: +class _Solution: """Estimated initial value problem solution.""" def __init__(self, t, u, output_scale, marginals, posterior, num_steps): @@ -87,7 +439,7 @@ def _sol_flatten(sol): def _sol_unflatten(_aux, children): t, u, marginals, posterior, output_scale, n = children - return Solution( + return _Solution( t=t, u=u, marginals=marginals, @@ -97,15 +449,15 @@ def _sol_unflatten(_aux, children): ) -tree_util.register_pytree_node(Solution, _sol_flatten, _sol_unflatten) +tree_util.register_pytree_node(_Solution, _sol_flatten, _sol_unflatten) -def simulate_terminal_values( +def solve_adaptive_terminal_values( vector_field, initial_condition, t0, t1, adaptive_solver, dt0 -) -> Solution: +) -> _Solution: """Simulate the terminal values of an initial value problem.""" save_at = np.asarray([t1]) - (_t, solution_save_at), _, num_steps = _solve_and_save_at( + (_t, solution_save_at), _, num_steps = _solve_adaptive_save_at( tree_util.Partial(vector_field), t0, initial_condition, @@ -122,7 +474,7 @@ def simulate_terminal_values( posterior, output_scale = solution_save_at marginals = posterior.init if isinstance(posterior, stats.MarkovSeq) else posterior u = impl.hidden_model.qoi(marginals) - return Solution( + return _Solution( t=t1, u=u, marginals=marginals, @@ -132,9 +484,9 @@ def simulate_terminal_values( ) -def solve_and_save_at( +def solve_adaptive_save_at( vector_field, initial_condition, save_at, adaptive_solver, dt0 -) -> Solution: +) -> _Solution: """Solve an initial value problem and return the solution at a pre-determined grid. !!! warning "Warning: highly EXPERIMENTAL feature!" @@ -147,11 +499,11 @@ def solve_and_save_at( if not adaptive_solver.solver.strategy.is_suitable_for_save_at: msg = ( f"Strategy {adaptive_solver.solver.strategy} should not " - f"be used in solve_and_save_at. " + f"be used in solve_adaptive_save_at. " ) warnings.warn(msg, stacklevel=1) - (_t, solution_save_at), _, num_steps = _solve_and_save_at( + (_t, solution_save_at), _, num_steps = _solve_adaptive_save_at( tree_util.Partial(vector_field), save_at[0], initial_condition, @@ -167,7 +519,7 @@ def solve_and_save_at( _tmp = _userfriendly_output(posterior=posterior_save_at, posterior_t0=posterior_t0) marginals, posterior = _tmp u = impl.hidden_model.qoi(marginals) - return Solution( + return _Solution( t=save_at, u=u, marginals=marginals, @@ -177,7 +529,7 @@ def solve_and_save_at( ) -def _solve_and_save_at( +def _solve_adaptive_save_at( vector_field, t, initial_condition, *, save_at, adaptive_solver, dt0 ): advance_func = functools.partial( @@ -217,9 +569,9 @@ def body_fun(s): return state, solution -def solve_and_save_every_step( +def solve_adaptive_save_every_step( vector_field, initial_condition, t0, t1, adaptive_solver, dt0 -) -> Solution: +) -> _Solution: """Solve an initial value problem and save every step. This function uses a native-Python while loop. @@ -230,7 +582,7 @@ def solve_and_save_every_step( if not adaptive_solver.solver.strategy.is_suitable_for_save_every_step: msg = ( f"Strategy {adaptive_solver.solver.strategy} should not " - f"be used in solve_and_save_every_step." + f"be used in solve_adaptive_save_every_step." ) warnings.warn(msg, stacklevel=1) @@ -257,7 +609,7 @@ def solve_and_save_every_step( marginals, posterior = _tmp u = impl.hidden_model.qoi(marginals) - return Solution( + return _Solution( t=t, u=u, marginals=marginals, @@ -289,7 +641,7 @@ def _solution_generator( yield solution -def solve_fixed_grid(vector_field, initial_condition, grid, solver) -> Solution: +def solve_fixed_grid(vector_field, initial_condition, grid, solver) -> _Solution: """Solve an initial value problem on a fixed, pre-determined grid.""" # Compute the solution @@ -308,7 +660,7 @@ def body_fn(s, dt): marginals, posterior = _tmp u = impl.hidden_model.qoi(marginals) - return Solution( + return _Solution( t=grid, u=u, marginals=marginals, diff --git a/tests/test_ivpsolve/__init__.py b/tests/test_ivpsolve/__init__.py index c13fe0eb..10689a1e 100644 --- a/tests/test_ivpsolve/__init__.py +++ b/tests/test_ivpsolve/__init__.py @@ -4,7 +4,7 @@ ```python solver = test_util.generate_solver() -solution = solve_and_save_every_step(lotka_volterra, solver) +solution = solve_adaptive_save_every_step(lotka_volterra, solver) ``` If this approximation is accurate (measured in error comparing to e.g. diffrax), @@ -14,10 +14,10 @@ * simulate_terminal_values() with the same arguments should yield the same terminal value as the base case. -* solve_and_save_at() should be identical to -interpolating the solve_and_save_every_step results +* solve_adaptive_save_at() should be identical to +interpolating the solve_adaptive_save_every_step results * solve_fixed_grid() should be identical -if the fixed grid is the solution grid of solve_and_save_every_step +if the fixed grid is the solution grid of solve_adaptive_save_every_step If these tests pass, and assuming that interpolation is correct, the solution routines must be correct. diff --git a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py index 2c132098..bfef0a33 100644 --- a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py @@ -1,6 +1,6 @@ -"""Compare solve_fixed_grid to solve_and_save_every_step.""" +"""Compare solve_fixed_grid to solve_adaptive_save_every_step.""" -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import numpy as np from probdiffeq.backend import testing from probdiffeq.impl import impl @@ -16,8 +16,8 @@ def test_fixed_grid_result_matches_adaptive_grid_result(): ts0 = components.correction_ts0() strategy = components.strategy_filter(ibm, ts0) solver = solvers.mle(strategy) - control = adaptive.control_integral_clipped() # Any clipped controller will do. - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2, control=control) + control = ivpsolve.control_integral_clipped() # Any clipped controller will do. + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, control=control) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = np.ones_like(impl.prototypes.output_scale()) @@ -30,7 +30,9 @@ def test_fixed_grid_result_matches_adaptive_grid_result(): "dt0": 0.1, "adaptive_solver": adaptive_solver, } - solution_adaptive = ivpsolve.solve_and_save_every_step(*args, **adaptive_kwargs) + solution_adaptive = ivpsolve.solve_adaptive_save_every_step( + *args, **adaptive_kwargs + ) grid_adaptive = solution_adaptive.t fixed_kwargs = {"grid": grid_adaptive, "solver": solver} diff --git a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py index 8c9d1349..57f411da 100644 --- a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py @@ -1,6 +1,6 @@ -"""Assert that solve_and_save_at is consistent with solve_with_python_loop().""" +"""Assert that solve_adaptive_save_at is consistent with solve_with_python_loop().""" -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import functools, testing, tree_util from probdiffeq.backend import numpy as np from probdiffeq.impl import impl @@ -19,7 +19,7 @@ def test_save_at_result_matches_interpolated_adaptive_result(): ts0 = components.correction_ts0() strategy = components.strategy_filter(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = np.ones_like(impl.prototypes.output_scale()) @@ -29,7 +29,7 @@ def test_save_at_result_matches_interpolated_adaptive_result(): # Compute an adaptive solution and interpolate ts = np.linspace(t0, t1, num=15, endpoint=True) - solution_adaptive = ivpsolve.solve_and_save_every_step( + solution_adaptive = ivpsolve.solve_adaptive_save_every_step( *problem_args, t0=t0, t1=t1, **adaptive_kwargs ) u_interp, marginals_interp = stats.offgrid_marginals_searchsorted( @@ -37,7 +37,7 @@ def test_save_at_result_matches_interpolated_adaptive_result(): ) # Compute a save-at solution and remove the edge-points - solution_save_at = ivpsolve.solve_and_save_at( + solution_save_at = ivpsolve.solve_adaptive_save_at( *problem_args, save_at=ts, **adaptive_kwargs ) diff --git a/tests/test_ivpsolve/test_save_every_step.py b/tests/test_ivpsolve/test_save_every_step.py index 7df4720d..949030ce 100644 --- a/tests/test_ivpsolve/test_save_every_step.py +++ b/tests/test_ivpsolve/test_save_every_step.py @@ -1,6 +1,6 @@ """Assert that solve_with_python_loop is accurate.""" -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import numpy as np from probdiffeq.backend import ode, testing from probdiffeq.impl import impl @@ -17,7 +17,7 @@ def fixture_python_loop_solution(): ts0 = components.correction_ts0() strategy = components.strategy_filter(ibm, ts0) solver = solvers.mle(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) dt0 = ivpsolve.dt0_adaptive( vf, u0, t0=t0, atol=1e-2, rtol=1e-2, error_contraction_rate=5 @@ -29,7 +29,7 @@ def fixture_python_loop_solution(): args = (vf, init) kwargs = {"t0": t0, "t1": t1, "adaptive_solver": adaptive_solver, "dt0": dt0} - return ivpsolve.solve_and_save_every_step(*args, **kwargs) + return ivpsolve.solve_adaptive_save_every_step(*args, **kwargs) @testing.fixture(name="reference_solution") diff --git a/tests/test_ivpsolve/test_solution_object.py b/tests/test_ivpsolve/test_solution_object.py index a7704ba1..b14aaf0c 100644 --- a/tests/test_ivpsolve/test_solution_object.py +++ b/tests/test_ivpsolve/test_solution_object.py @@ -1,6 +1,6 @@ """Tests for interaction with the solution object.""" -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import functools, testing from probdiffeq.backend import numpy as np from probdiffeq.impl import impl @@ -18,13 +18,13 @@ def fixture_approximate_solution(): ts0 = components.correction_ts0() strategy = components.strategy_filter(ibm, ts0) solver = solvers.mle(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=1) init = solver.initial_condition(tcoeffs, output_scale) - return ivpsolve.solve_and_save_every_step( + return ivpsolve.solve_adaptive_save_every_step( vf, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver ) @@ -65,7 +65,7 @@ def fixture_approximate_solution_batched(): ts0 = components.correction_ts0() strategy = components.strategy_filter(ibm, ts0) solver = solvers.mle(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) save_at = np.linspace(t0, t1, endpoint=True, num=4) @@ -74,7 +74,7 @@ def fixture_approximate_solution_batched(): def solve(init): tcoeffs = (init, vf(init, t=None)) initcond = solver.initial_condition(tcoeffs, output_scale=output_scale) - return ivpsolve.solve_and_save_at( + return ivpsolve.solve_adaptive_save_at( vf, initcond, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1 ) diff --git a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py index c3d604f7..47edd10d 100644 --- a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py @@ -1,6 +1,6 @@ -"""Compare simulate_terminal_values to solve_and_save_every_step.""" +"""Compare simulate_terminal_values to solve_adaptive_save_every_step.""" -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import numpy as np from probdiffeq.backend import testing, tree_util from probdiffeq.impl import impl @@ -18,7 +18,7 @@ def fixture_problem_args_kwargs(): ts0 = components.correction_ts0() strategy = components.strategy_filter(ibm, ts0) solver = solvers.mle(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = np.ones_like(impl.prototypes.output_scale()) @@ -33,13 +33,13 @@ def fixture_problem_args_kwargs(): def fixture_solution_with_python_while_loop(problem_args_kwargs): args, kwargs = problem_args_kwargs - return ivpsolve.solve_and_save_every_step(*args, **kwargs) + return ivpsolve.solve_adaptive_save_every_step(*args, **kwargs) @testing.fixture(name="simulation_terminal_values") def fixture_simulation_terminal_values(problem_args_kwargs): args, kwargs = problem_args_kwargs - return ivpsolve.simulate_terminal_values(*args, **kwargs) + return ivpsolve.solve_adaptive_terminal_values(*args, **kwargs) def test_terminal_values_identical(solution_python_loop, simulation_terminal_values): diff --git a/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py b/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py index 3d979cbc..d2290621 100644 --- a/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py +++ b/tests/test_solvers/test_calibrated/test_mle_calibration_vs_calibrationfree.py @@ -5,7 +5,7 @@ After applying stats.calibrate(), the posterior is different. """ -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import numpy as np from probdiffeq.backend import testing from probdiffeq.impl import impl @@ -29,7 +29,7 @@ def solver_to_solution(solver): @testing.case() -def case_solve_and_save_at(): +def case_solve_adaptive_save_at(): vf, u0, (t0, t1) = setup.ode() dt0 = ivpsolve.dt0(lambda y: vf(y, t=t0), u0) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) @@ -38,8 +38,8 @@ def case_solve_and_save_at(): def solver_to_solution(solver): init = solver.initial_condition(tcoeffs, output_scale) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) - return ivpsolve.solve_and_save_at( + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) + return ivpsolve.solve_adaptive_save_at( vf, init, adaptive_solver=adaptive_solver, **kwargs ) @@ -47,7 +47,7 @@ def solver_to_solution(solver): @testing.case() -def case_solve_and_save_every_step(): +def case_solve_adaptive_save_every_step(): vf, u0, (t0, t1) = setup.ode() dt0 = ivpsolve.dt0(lambda y: vf(y, t=t0), u0) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=4) @@ -56,8 +56,8 @@ def case_solve_and_save_every_step(): def solver_to_solution(solver): init = solver.initial_condition(tcoeffs, output_scale) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) - return ivpsolve.solve_and_save_every_step( + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) + return ivpsolve.solve_adaptive_save_every_step( vf, init, adaptive_solver=adaptive_solver, **kwargs ) @@ -74,8 +74,8 @@ def case_simulate_terminal_values(): def solver_to_solution(solver): init = solver.initial_condition(tcoeffs, output_scale) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) - return ivpsolve.simulate_terminal_values( + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) + return ivpsolve.solve_adaptive_terminal_values( vf, init, adaptive_solver=adaptive_solver, **kwargs ) diff --git a/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py b/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py index 64a70d9d..74544e39 100644 --- a/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py +++ b/tests/test_solvers/test_misc/test_warnings_for_wrong_strategies.py @@ -1,6 +1,6 @@ """Some strategies don't work with all solution routines.""" -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import numpy as np from probdiffeq.backend import testing from probdiffeq.impl import impl @@ -16,14 +16,14 @@ def test_warning_for_fixedpoint_in_save_every_step_mode(): ts0 = components.correction_ts0() strategy = components.strategy_fixedpoint(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) with testing.warns(): - _ = ivpsolve.solve_and_save_every_step( + _ = ivpsolve.solve_adaptive_save_every_step( vf, init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1 ) @@ -35,14 +35,14 @@ def test_warning_for_smoother_in_save_at_mode(): ts0 = components.correction_ts0() strategy = components.strategy_smoother(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) with testing.warns(): - _ = ivpsolve.solve_and_save_at( + _ = ivpsolve.solve_adaptive_save_at( vf, init, save_at=np.linspace(t0, t1), diff --git a/tests/test_solvers/test_stats/test_log_marginal_likelihood.py b/tests/test_solvers/test_stats/test_log_marginal_likelihood.py index 1686d53a..1068b2d0 100644 --- a/tests/test_solvers/test_stats/test_log_marginal_likelihood.py +++ b/tests/test_solvers/test_stats/test_log_marginal_likelihood.py @@ -1,6 +1,6 @@ """Tests for log-marginal-likelihood functionality.""" -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import numpy as np from probdiffeq.backend import testing, tree_util from probdiffeq.impl import impl @@ -17,14 +17,14 @@ def fixture_sol(): ts0 = components.correction_ts0() strategy = components.strategy_fixedpoint(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) save_at = np.linspace(t0, t1, endpoint=True, num=4) - return ivpsolve.solve_and_save_at( + return ivpsolve.solve_adaptive_save_at( vf, init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1 ) diff --git a/tests/test_solvers/test_stats/test_log_marginal_likelihood_terminal_values.py b/tests/test_solvers/test_stats/test_log_marginal_likelihood_terminal_values.py index ff623ed7..0f2c5cb8 100644 --- a/tests/test_solvers/test_stats/test_log_marginal_likelihood_terminal_values.py +++ b/tests/test_solvers/test_stats/test_log_marginal_likelihood_terminal_values.py @@ -1,6 +1,6 @@ """Tests for marginal log likelihood functionality (terminal values).""" -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import numpy as np from probdiffeq.backend import testing from probdiffeq.impl import impl @@ -33,12 +33,12 @@ def fixture_sol(strategy_func): ts0 = components.correction_ts0() strategy = strategy_func(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4) output_scale = np.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale) - return ivpsolve.simulate_terminal_values( + return ivpsolve.solve_adaptive_terminal_values( vf, init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1 ) diff --git a/tests/test_solvers/test_stats/test_sample.py b/tests/test_solvers/test_stats/test_sample.py index 3e5a44f1..c4508820 100644 --- a/tests/test_solvers/test_stats/test_sample.py +++ b/tests/test_solvers/test_stats/test_sample.py @@ -1,6 +1,6 @@ """Tests for sampling behaviour.""" -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import numpy as np from probdiffeq.backend import random, testing, tree_util from probdiffeq.impl import impl @@ -17,12 +17,12 @@ def fixture_approximation(): ts0 = components.correction_ts0() strategy = components.strategy_smoother(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) output_scale = np.ones_like(impl.prototypes.output_scale()) tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=2) init = solver.initial_condition(tcoeffs, output_scale) - return ivpsolve.solve_and_save_every_step( + return ivpsolve.solve_adaptive_save_every_step( vf, init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1 ) diff --git a/tests/test_solvers/test_strategies/__init__.py b/tests/test_solvers/test_strategies/__init__.py index d028a2f4..cdfd34ea 100644 --- a/tests/test_solvers/test_strategies/__init__.py +++ b/tests/test_solvers/test_strategies/__init__.py @@ -6,7 +6,7 @@ using the same configuration (e.g. fixed grid solutions). Both should yield a reasonable approximation of the ODE solution. (This is a slightly unpredictable test because it depends highly on parameter choices.) -* The result of the fixed-point smoother in solve_and_save_at should be *identical* +* The result of the fixed-point smoother in solve_adaptive_save_at should be *identical* to interpolating the smoother results (we can reuse the solution from earlier). This is a strict test, and one that has failed many times in the past. diff --git a/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py b/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py index d53cff01..68187559 100644 --- a/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py +++ b/tests/test_solvers/test_strategies/test_corrections/test_rmse_of_correction.py @@ -1,6 +1,6 @@ """Assert that every recipe yields a decent ODE approximation.""" -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import numpy as np from probdiffeq.backend import ode, testing from probdiffeq.impl import impl @@ -67,14 +67,16 @@ def fixture_solution(correction_impl): ibm = components.prior_ibm(num_derivatives=2) strategy = components.strategy_filter(ibm, correction_impl) solver = solvers.mle(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2) adaptive_kwargs = {"adaptive_solver": adaptive_solver, "dt0": 0.1} tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), u0, num=2) output_scale = np.ones_like(impl.prototypes.output_scale()) init = solver.initial_condition(tcoeffs, output_scale) - return ivpsolve.simulate_terminal_values(vf, init, t0=t0, t1=t1, **adaptive_kwargs) + return ivpsolve.solve_adaptive_terminal_values( + vf, init, t0=t0, t1=t1, **adaptive_kwargs + ) @testing.fixture(name="reference_solution") diff --git a/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py b/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py index 7b56a2a5..3bb426bf 100644 --- a/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py +++ b/tests/test_solvers/test_strategies/test_smoother_vs_fixedpoint_equivalence.py @@ -3,7 +3,7 @@ That is, when called with correct adaptive- and checkpoint-setups. """ -from probdiffeq import adaptive, ivpsolve +from probdiffeq import ivpsolve from probdiffeq.backend import functools, testing, tree_util from probdiffeq.backend import numpy as np from probdiffeq.impl import impl @@ -34,11 +34,11 @@ def fixture_solution_smoother(solver_setup): ts0 = components.correction_ts0() strategy = components.strategy_smoother(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-3, rtol=1e-3) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3) tcoeffs, output_scale = solver_setup["tcoeffs"], solver_setup["output_scale"] init = solver.initial_condition(tcoeffs, output_scale) - return ivpsolve.solve_and_save_every_step( + return ivpsolve.solve_adaptive_save_every_step( solver_setup["vf"], init, t0=solver_setup["t0"], @@ -54,14 +54,14 @@ def test_fixedpoint_smoother_equivalent_same_grid(solver_setup, solution_smoothe ts0 = components.correction_ts0() strategy = components.strategy_fixedpoint(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-3, rtol=1e-3) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3) save_at = solution_smoother.t tcoeffs, output_scale = solver_setup["tcoeffs"], solver_setup["output_scale"] init = solver.initial_condition(tcoeffs, output_scale) - solution_fixedpoint = ivpsolve.solve_and_save_at( + solution_fixedpoint = ivpsolve.solve_adaptive_save_at( solver_setup["vf"], init, save_at=save_at, @@ -92,11 +92,11 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm ts0 = components.correction_ts0() strategy = components.strategy_fixedpoint(ibm, ts0) solver = solvers.solver(strategy) - adaptive_solver = adaptive.adaptive(solver, atol=1e-3, rtol=1e-3) + adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3) tcoeffs, output_scale = solver_setup["tcoeffs"], solver_setup["output_scale"] init = solver.initial_condition(tcoeffs, output_scale) - solution_fixedpoint = ivpsolve.solve_and_save_at( + solution_fixedpoint = ivpsolve.solve_adaptive_save_at( solver_setup["vf"], init, save_at=ts, adaptive_solver=adaptive_solver, dt0=0.1 )