Skip to content

Commit

Permalink
Move solver-components from package to a single module because we alw…
Browse files Browse the repository at this point in the history
…ays need one of each (#753)

* Collect all solver components in a module

* Delete the outdated (and empty) components subpackage

* Hide the implementation details in components.py

* Add the correction_ prefix to ts0, ts1, slr0, slr1 to distinguish them from other components

* Add prior_* prefix to all priors to distinguish them from other components

* Move cubature rules around (inside the module) to fix typing issues
  • Loading branch information
pnkraemer authored Jun 13, 2024
1 parent ede2d82 commit 1a821cd
Show file tree
Hide file tree
Showing 40 changed files with 393 additions and 434 deletions.
1 change: 1 addition & 0 deletions docs/api_docs/solvers/components.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: probdiffeq.solvers.components
1 change: 0 additions & 1 deletion docs/api_docs/solvers/components/corrections.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/api_docs/solvers/components/cubature.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/api_docs/solvers/components/priors.md

This file was deleted.

7 changes: 3 additions & 4 deletions docs/benchmarks/hires/run_hires.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

from probdiffeq import adaptive, ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solvers, strategies
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info

Expand Down Expand Up @@ -93,8 +92,8 @@ def vf_probdiffeq(u, *, t): # noqa: ARG001
@jax.jit
def param_to_solution(tol):
# Build a solver
ibm = priors.ibm_adaptive(num_derivatives=num_derivatives)
ts1 = corrections.ts1()
ibm = components.prior_ibm(num_derivatives=num_derivatives)
ts1 = components.correction_ts1()
strategy = strategies.filter_adaptive(ibm, ts1)
solver = solvers.dynamic(strategy)
control = adaptive.control_proportional_integral_clipped()
Expand Down
7 changes: 3 additions & 4 deletions docs/benchmarks/lotkavolterra/run_lotkavolterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

from probdiffeq import adaptive, ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solvers, strategies
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info

Expand Down Expand Up @@ -82,7 +81,7 @@ def vf_probdiffeq(y, *, t): # noqa: ARG001
def param_to_solution(tol):
impl.select(implementation, ode_shape=(2,))
# Build a solver
ibm = priors.ibm_adaptive(num_derivatives=num_derivatives)
ibm = components.prior_ibm(num_derivatives=num_derivatives)
strategy = strategies.filter_adaptive(ibm, correction())
solver = solvers.mle(strategy)
control = adaptive.control_proportional_integral()
Expand Down Expand Up @@ -249,7 +248,7 @@ def parameter_list_to_workprecision(list_of_args, /):
timeit_fun = timeit_fun_from_args(args)

# Assemble algorithms
ts0, ts1 = corrections.ts0, corrections.ts1
ts0, ts1 = components.correction_ts0, components.correction_ts1
ts0_iso = solver_probdiffeq(5, correction=ts0, implementation="isotropic")
ts0_bd = solver_probdiffeq(5, correction=ts0, implementation="blockdiag")
ts1_dense = solver_probdiffeq(8, correction=ts1, implementation="dense")
Expand Down
9 changes: 4 additions & 5 deletions docs/benchmarks/pleiades/run_pleiades.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

from probdiffeq import adaptive, ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solvers, strategies
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info

Expand Down Expand Up @@ -105,7 +104,7 @@ def vf_probdiffeq(u, du, *, t): # noqa: ARG001
@jax.jit
def param_to_solution(tol):
# Build a solver
ibm = priors.ibm_adaptive(num_derivatives=num_derivatives)
ibm = components.prior_ibm(num_derivatives=num_derivatives)
ts0_or_ts1 = correction_fun(ode_order=2)
strategy = strategies.filter_adaptive(ibm, ts0_or_ts1)
solver = solvers.dynamic(strategy)
Expand Down Expand Up @@ -333,10 +332,10 @@ def parameter_list_to_workprecision(list_of_args, /):
"Diffrax: Tsit5()": solver_diffrax(solver=diffrax.Tsit5()),
"Diffrax: Dopri8()": solver_diffrax(solver=diffrax.Dopri8()),
r"ProbDiffEq: TS0($5$)": solver_probdiffeq(
num_derivatives=5, correction_fun=corrections.ts0
num_derivatives=5, correction_fun=components.correction_ts0
),
r"ProbDiffEq: TS0($8$)": solver_probdiffeq(
num_derivatives=8, correction_fun=corrections.ts0
num_derivatives=8, correction_fun=components.correction_ts0
),
}

Expand Down
7 changes: 3 additions & 4 deletions docs/benchmarks/vanderpol/run_vanderpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

from probdiffeq import adaptive, ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solvers, strategies
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info

Expand Down Expand Up @@ -84,8 +83,8 @@ def vf_probdiffeq(u, du, *, t): # noqa: ARG001
@jax.jit
def param_to_solution(tol):
# Build a solver
ibm = priors.ibm_adaptive(num_derivatives=num_derivatives)
ts0_or_ts1 = corrections.ts1(ode_order=2)
ibm = components.prior_ibm(num_derivatives=num_derivatives)
ts0_or_ts1 = components.correction_ts1(ode_order=2)
strategy = strategies.filter_adaptive(ibm, ts0_or_ts1)
solver = solvers.dynamic(strategy)
control = adaptive.control_proportional_integral_clipped()
Expand Down
7 changes: 3 additions & 4 deletions docs/examples_misc/use_equinox_bounded_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
from probdiffeq import adaptive, ivpsolve
from probdiffeq.backend import control_flow
from probdiffeq.impl import impl
from probdiffeq.solvers import solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solvers, strategies
from probdiffeq.taylor import autodiff

jax.config.update("jax_platform_name", "cpu")
Expand Down Expand Up @@ -65,8 +64,8 @@ def vf(y, *, t): # noqa: ARG001
t0, t1 = 0.0, 1.0
u0 = jnp.asarray([0.1])

ibm = priors.ibm_adaptive(num_derivatives=1)
ts0 = corrections.ts0(ode_order=1)
ibm = components.prior_ibm(num_derivatives=1)
ts0 = components.correction_ts0(ode_order=1)

strategy = strategies.fixedpoint_adaptive(ibm, ts0)
solver = solvers.solver(strategy)
Expand Down
7 changes: 3 additions & 4 deletions docs/examples_parameter_estimation/neural_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@

from probdiffeq import ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solution, solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solution, solvers, strategies
from probdiffeq.util.doc_util import notebook

# -
Expand Down Expand Up @@ -118,8 +117,8 @@ def vf(y, *, t, p):


# Make a solver
ibm = priors.ibm_adaptive(num_derivatives=1)
ts0 = corrections.ts0()
ibm = components.prior_ibm(num_derivatives=1)
ts0 = components.correction_ts0()
strategy = strategies.smoother_adaptive(ibm, ts0)
solver_ts0 = solvers.solver(strategy)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@

from probdiffeq import ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solution, solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solution, solvers, strategies
from probdiffeq.util.doc_util import notebook

# -
Expand Down Expand Up @@ -70,8 +69,8 @@ def vf(y, t, *, p): # noqa: ARG001

def solve(p):
"""Evaluate the parameter-to-solution map."""
ibm = priors.ibm_adaptive(num_derivatives=1)
ts0 = corrections.ts0()
ibm = components.prior_ibm(num_derivatives=1)
ts0 = components.correction_ts0()
strategy = strategies.smoother_adaptive(ibm, ts0)
solver = solvers.solver(strategy)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@

from probdiffeq import adaptive, ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solution, solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solution, solvers, strategies
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import notebook

Expand Down Expand Up @@ -191,8 +190,8 @@ def plot_solution(sol, *, ax, marker=".", **plotting_kwargs):
def solve_fixed(theta, *, ts):
"""Evaluate the parameter-to-solution map, solving on a fixed grid."""
# Create a probabilistic solver
ibm = priors.ibm_adaptive(num_derivatives=2)
ts0 = corrections.ts0()
ibm = components.prior_ibm(num_derivatives=2)
ts0 = components.correction_ts0()
strategy = strategies.filter_adaptive(ibm, ts0)
solver = solvers.solver(strategy)

Expand All @@ -208,8 +207,8 @@ def solve_fixed(theta, *, ts):
def solve_adaptive(theta, *, save_at):
"""Evaluate the parameter-to-solution map, solving on an adaptive grid."""
# Create a probabilistic solver
ibm = priors.ibm_adaptive(num_derivatives=2)
ts0 = corrections.ts0()
ibm = components.prior_ibm(num_derivatives=2)
ts0 = components.correction_ts0()
strategy = strategies.filter_adaptive(ibm, ts0)
solver = solvers.solver(strategy)
adaptive_solver = adaptive.adaptive(solver)
Expand Down
7 changes: 3 additions & 4 deletions docs/examples_quickstart/easy_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

from probdiffeq import adaptive, ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solvers, strategies
from probdiffeq.taylor import autodiff

jax.config.update("jax_platform_name", "cpu")
Expand Down Expand Up @@ -84,8 +83,8 @@ def vf(y, *, t): # noqa: ARG001
#

# +
ibm = priors.ibm_adaptive(num_derivatives=4)
ts0 = corrections.ts1(ode_order=1)
ibm = components.prior_ibm(num_derivatives=4)
ts0 = components.correction_ts1(ode_order=1)

strategy = strategies.smoother_adaptive(ibm, ts0)
solver = solvers.solver(strategy)
Expand Down
9 changes: 4 additions & 5 deletions docs/examples_solver_config/conditioning-on-zero-residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@

from probdiffeq import adaptive, ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import markov, solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, markov, solvers, strategies
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import notebook

Expand Down Expand Up @@ -66,7 +65,7 @@ def vector_field(y, t): # noqa: ARG001

NUM_DERIVATIVES = 2
ts = jnp.linspace(t0, t1, num=500, endpoint=True)
init_raw, transitions = priors.ibm_discretised(
init_raw, transitions = components.prior_ibm_discrete(
ts, num_derivatives=NUM_DERIVATIVES, output_scale=100.0
)

Expand All @@ -84,8 +83,8 @@ def vector_field(y, t): # noqa: ARG001
# +
# Compute the posterior

slr1 = corrections.ts1()
ibm = priors.ibm_adaptive(num_derivatives=NUM_DERIVATIVES)
slr1 = components.correction_ts1()
ibm = components.prior_ibm(num_derivatives=NUM_DERIVATIVES)
solver = solvers.solver(strategies.fixedpoint_adaptive(ibm, slr1))
adaptive_solver = adaptive.adaptive(solver, atol=1e-1, rtol=1e-2)

Expand Down
7 changes: 3 additions & 4 deletions docs/examples_solver_config/dynamic_output_scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@

from probdiffeq import ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solvers, strategies
from probdiffeq.util.doc_util import notebook

# -
Expand Down Expand Up @@ -73,8 +72,8 @@ def vf(*ys, t): # noqa: ARG001
# +
num_derivatives = 1

ibm = priors.ibm_adaptive(num_derivatives=1)
ts1 = corrections.ts1()
ibm = components.prior_ibm(num_derivatives=1)
ts1 = components.correction_ts1()
strategy = strategies.filter_adaptive(ibm, ts1)
dynamic = solvers.dynamic(strategy)
mle = solvers.mle(strategy)
Expand Down
11 changes: 5 additions & 6 deletions docs/examples_solver_config/posterior_uncertainties.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

from probdiffeq import adaptive, ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import markov, solution, solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, markov, solution, solvers, strategies
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import notebook

Expand Down Expand Up @@ -64,8 +63,8 @@ def vf(*ys, t): # noqa: ARG001
# ## Filter

# +
ibm = priors.ibm_adaptive(num_derivatives=4)
ts0 = corrections.ts0()
ibm = components.prior_ibm(num_derivatives=4)
ts0 = components.correction_ts0()
solver = solvers.mle(strategies.filter_adaptive(ibm, ts0))
adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2)

Expand Down Expand Up @@ -119,8 +118,8 @@ def vf(*ys, t): # noqa: ARG001
# ## Smoother

# +
ibm = priors.ibm_adaptive(num_derivatives=4)
ts0 = corrections.ts0()
ibm = components.prior_ibm(num_derivatives=4)
ts0 = components.correction_ts0()
solver = solvers.mle(strategies.fixedpoint_adaptive(ibm, ts0))
adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2)

Expand Down
11 changes: 5 additions & 6 deletions docs/examples_solver_config/second_order_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

from probdiffeq import adaptive, ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import solvers, strategies
from probdiffeq.solvers.components import corrections, priors
from probdiffeq.solvers import components, solvers, strategies
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import notebook

Expand Down Expand Up @@ -54,8 +53,8 @@ def vf_1(y, t): # noqa: ARG001
return f(y, *f_args)


ibm = priors.ibm_adaptive(num_derivatives=4)
ts0 = corrections.ts0()
ibm = components.prior_ibm(num_derivatives=4)
ts0 = components.correction_ts0()
solver_1st = solvers.mle(strategies.filter_adaptive(ibm, ts0))
adaptive_solver_1st = adaptive.adaptive(solver_1st, atol=1e-5, rtol=1e-5)

Expand Down Expand Up @@ -88,8 +87,8 @@ def vf_2(y, dy, t): # noqa: ARG001


# One derivative more than above because we don't transform to first order
ibm = priors.ibm_adaptive(num_derivatives=4)
ts0 = corrections.ts0(ode_order=2)
ibm = components.prior_ibm(num_derivatives=4)
ts0 = components.correction_ts0(ode_order=2)
solver_2nd = solvers.mle(strategies.filter_adaptive(ibm, ts0))
adaptive_solver_2nd = adaptive.adaptive(solver_2nd, atol=1e-5, rtol=1e-5)

Expand Down
5 changes: 1 addition & 4 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,8 @@ nav:
- adaptive: api_docs/adaptive.md
- impl: api_docs/impl.md
- solvers:
- components:
- priors: api_docs/solvers/components/priors.md
- corrections: api_docs/solvers/components/corrections.md
- cubature: api_docs/solvers/components/cubature.md
- strategies: api_docs/solvers/strategies.md
- components: api_docs/solvers/components.md
- solvers: api_docs/solvers/solvers.md
- solution: api_docs/solvers/solution.md
- markov: api_docs/solvers/markov.md
Expand Down
Loading

0 comments on commit 1a821cd

Please sign in to comment.