diff --git a/docs/benchmarks/taylor_node/run_taylor_node.py b/docs/benchmarks/taylor_node/run_taylor_node.py index e8233743..abf085aa 100644 --- a/docs/benchmarks/taylor_node/run_taylor_node.py +++ b/docs/benchmarks/taylor_node/run_taylor_node.py @@ -166,8 +166,7 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict: if __name__ == "__main__": set_jax_config() - if not backend.has_been_selected: - backend.select("jax") + backend.select("jax") algorithms = { r"Forward-mode": odejet_via_jvp(), diff --git a/probdiffeq/backend/ode.py b/probdiffeq/backend/ode.py index f36bd824..2436ec64 100644 --- a/probdiffeq/backend/ode.py +++ b/probdiffeq/backend/ode.py @@ -1,14 +1,8 @@ """ODE stuff.""" -import diffeqzoo.ivps -import diffrax import jax import jax.experimental.ode import jax.numpy as jnp -from diffeqzoo import backend - -# ODE examples must be in JAX -backend.select("jax") def odeint_and_save_at(vf, y0: tuple, /, save_at, *, atol, rtol): @@ -22,6 +16,9 @@ def vf_wrapped(y, t): def odeint_dense(vf, y0: tuple, /, t0, t1, *, atol, rtol): + # Local import because diffrax is not an official dependency + import diffrax + assert isinstance(y0, (tuple, list)) assert len(y0) == 1 @@ -53,7 +50,13 @@ def solution(t): def ivp_logistic(): - f, u0, (t0, _), f_args = diffeqzoo.ivps.logistic() + # Local imports because diffeqzoo is not an official dependency + from diffeqzoo import backend, ivps + + if not backend.has_been_selected: + backend.select("jax") + + f, u0, (t0, _), f_args = ivps.logistic() t1 = 0.75 @jax.jit @@ -64,7 +67,13 @@ def vf(x, *, t): # noqa: ARG001 def ivp_lotka_volterra(): - f, u0, (t0, _), f_args = diffeqzoo.ivps.lotka_volterra() + # Local imports because diffeqzoo is not an official dependency + from diffeqzoo import backend, ivps + + if not backend.has_been_selected: + backend.select("jax") + + f, u0, (t0, _), f_args = ivps.lotka_volterra() t1 = 2.0 # Short time-intervals are sufficient for this test. @jax.jit @@ -103,7 +112,13 @@ def solution(t): def ivp_three_body_1st(): - f, u0, (t0, t1), f_args = diffeqzoo.ivps.three_body_restricted_first_order() + # Local imports because diffeqzoo is not an official dependency + from diffeqzoo import backend, ivps + + if not backend.has_been_selected: + backend.select("jax") + + f, u0, (t0, t1), f_args = ivps.three_body_restricted_first_order() def vf(u, *, t): # noqa: ARG001 return f(u, *f_args) @@ -112,7 +127,13 @@ def vf(u, *, t): # noqa: ARG001 def ivp_van_der_pol_2nd(): - f, (u0, du0), (t0, t1), f_args = diffeqzoo.ivps.van_der_pol() + # Local imports because diffeqzoo is not an official dependency + from diffeqzoo import backend, ivps + + if not backend.has_been_selected: + backend.select("jax") + + f, (u0, du0), (t0, t1), f_args = ivps.van_der_pol() def vf(u, du, *, t): # noqa: ARG001 return f(u, du, *f_args)