Skip to content

Commit

Permalink
Use local imports in backend.ode to avoid diffeqzoo & diffrax depende…
Browse files Browse the repository at this point in the history
…nceis (#763)
  • Loading branch information
pnkraemer authored Jun 13, 2024
1 parent 13c41fd commit bd08a9a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
3 changes: 1 addition & 2 deletions docs/benchmarks/taylor_node/run_taylor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,7 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:
if __name__ == "__main__":
set_jax_config()

if not backend.has_been_selected:
backend.select("jax")
backend.select("jax")

algorithms = {
r"Forward-mode": odejet_via_jvp(),
Expand Down
41 changes: 31 additions & 10 deletions probdiffeq/backend/ode.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
"""ODE stuff."""

import diffeqzoo.ivps
import diffrax
import jax
import jax.experimental.ode
import jax.numpy as jnp
from diffeqzoo import backend

# ODE examples must be in JAX
backend.select("jax")


def odeint_and_save_at(vf, y0: tuple, /, save_at, *, atol, rtol):
Expand All @@ -22,6 +16,9 @@ def vf_wrapped(y, t):


def odeint_dense(vf, y0: tuple, /, t0, t1, *, atol, rtol):
# Local import because diffrax is not an official dependency
import diffrax

assert isinstance(y0, (tuple, list))
assert len(y0) == 1

Expand Down Expand Up @@ -53,7 +50,13 @@ def solution(t):


def ivp_logistic():
f, u0, (t0, _), f_args = diffeqzoo.ivps.logistic()
# Local imports because diffeqzoo is not an official dependency
from diffeqzoo import backend, ivps

if not backend.has_been_selected:
backend.select("jax")

f, u0, (t0, _), f_args = ivps.logistic()
t1 = 0.75

@jax.jit
Expand All @@ -64,7 +67,13 @@ def vf(x, *, t): # noqa: ARG001


def ivp_lotka_volterra():
f, u0, (t0, _), f_args = diffeqzoo.ivps.lotka_volterra()
# Local imports because diffeqzoo is not an official dependency
from diffeqzoo import backend, ivps

if not backend.has_been_selected:
backend.select("jax")

f, u0, (t0, _), f_args = ivps.lotka_volterra()
t1 = 2.0 # Short time-intervals are sufficient for this test.

@jax.jit
Expand Down Expand Up @@ -103,7 +112,13 @@ def solution(t):


def ivp_three_body_1st():
f, u0, (t0, t1), f_args = diffeqzoo.ivps.three_body_restricted_first_order()
# Local imports because diffeqzoo is not an official dependency
from diffeqzoo import backend, ivps

if not backend.has_been_selected:
backend.select("jax")

f, u0, (t0, t1), f_args = ivps.three_body_restricted_first_order()

def vf(u, *, t): # noqa: ARG001
return f(u, *f_args)
Expand All @@ -112,7 +127,13 @@ def vf(u, *, t): # noqa: ARG001


def ivp_van_der_pol_2nd():
f, (u0, du0), (t0, t1), f_args = diffeqzoo.ivps.van_der_pol()
# Local imports because diffeqzoo is not an official dependency
from diffeqzoo import backend, ivps

if not backend.has_been_selected:
backend.select("jax")

f, (u0, du0), (t0, t1), f_args = ivps.van_der_pol()

def vf(u, du, *, t): # noqa: ARG001
return f(u, du, *f_args)
Expand Down

0 comments on commit bd08a9a

Please sign in to comment.