Skip to content

Commit

Permalink
Merge pull request #119 from LilithHafner/lh/jit
Browse files Browse the repository at this point in the history
Bundle ModelingToolkit and add de.jit
  • Loading branch information
ChrisRackauckas authored Oct 10, 2023
2 parents 92bbdd6 + fad6b6f commit fc32432
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 41 deletions.
54 changes: 21 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,7 @@ interpreter then run:
>>> diffeqpy.install()
```

and you're good! In addition, to improve the performance of your code it is
recommended that you use Numba to JIT compile your derivative functions. To
install Numba, use:

```
pip install numba
```
and you're good!

## General Flow

Expand Down Expand Up @@ -150,32 +144,30 @@ sol = de.solve(prob,de.Vern9(),saveat=0.1,abstol=1e-10,reltol=1e-10)
The set of algorithms for ODEs is described
[at the ODE solvers page](http://diffeq.sciml.ai/dev/solvers/ode_solve).

### Compilation with Numba and Julia
### Compilation with `de.jit` and Julia

When solving a differential equation, it's pertinent that your derivative
function `f` is fast since it occurs in the inner loop of the solver. We can
utilize Numba to JIT compile our derivative functions to improve the efficiency
of the solver:
convert the entire ode problem to symbolic form, optimize that symbolic form,
and emit efficient native code to simulate it using `de.jit` to improve the
efficiency of the solver at the expense of added setup time:

```py
import numba
numba_f = numba.jit(f)

prob = de.ODEProblem(numba_f, u0, tspan)
sol = de.solve(prob) # ERROR
fast_prob = de.jit(prob)
sol = de.solve(fast_prob)
```

Additionally, you can directly define the functions in Julia. This will allow
for more specialization and could be helpful to increase the efficiency over
the Numba version for repeat or long calls. This is done via `seval`:
Additionally, you can directly define the functions in Julia. This will also
allow for specialization and could be helpful to increase the efficiency for
repeat or long calls. This is done via `seval`:

```py
jul_f = de.seval("(u,p,t)->-u") # Define the anonymous function in Julia
prob = de.ODEProblem(jul_f, u0, tspan)
sol = de.solve(prob)
```

#### Note that when using Numba, one must avoid Python lists and pass state and parameters as NumPy arrays!
#### Note that when using `de.jit`, certain undocumented restrictions apply!!

### Systems of ODEs: Lorenz Equations

Expand Down Expand Up @@ -228,12 +220,12 @@ def f(du,u,p,t):
du[1] = x * (rho - z) - y
du[2] = x * y - beta * z

numba_f = numba.jit(f)
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,2.66]
prob = de.ODEProblem(numba_f, u0, tspan, p)
sol = de.solve(prob)
prob = de.ODEProblem(f, u0, tspan, p)
jit_prob = de.jit(prob)
sol = de.solve(jit_prob)
```

or using a Julia function:
Expand Down Expand Up @@ -299,12 +291,10 @@ def g(du,u,p,t):
du[1] = 0.3*u[1]
du[2] = 0.3*u[2]

numba_f = numba.jit(f)
numba_g = numba.jit(g)
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,2.66]
prob = de.SDEProblem(numba_f, numba_g, u0, tspan, p)
prob = de.jit(de.SDEProblem(f, g, u0, tspan, p))
sol = de.solve(prob)

# Now let's draw a phase plot
Expand Down Expand Up @@ -351,10 +341,9 @@ u0 = [1.0,0.0,0.0]
tspan = (0.0,100.0)
p = [10.0,28.0,2.66]
nrp = numpy.zeros((3,2))
numba_f = numba.jit(f)
numba_g = numba.jit(g)
prob = de.SDEProblem(numba_f,numba_g,u0,tspan,p,noise_rate_prototype=nrp)
sol = de.solve(prob,saveat=0.005)
prob = de.SDEProblem(f,g,u0,tspan,p,noise_rate_prototype=nrp)
jit_prob = de.jit(prob)
sol = de.solve(jit_prob,saveat=0.005)

# Now let's draw a phase plot

Expand Down Expand Up @@ -409,9 +398,9 @@ def f(resid,du,u,p,t):
resid[1] = + 0.04*u[0] - 3e7*u[1]**2 - 1e4*u[1]*u[2] - du[1]
resid[2] = u[0] + u[1] + u[2] - 1.0

numba_f = numba.jit(f)
prob = de.DAEProblem(numba_f,du0,u0,tspan,differential_vars=differential_vars)
sol = de.solve(prob) # ERROR
prob = de.DAEProblem(f,du0,u0,tspan,differential_vars=differential_vars)
jit_prob = de.jit(prob) # Error: no method matching matching modelingtoolkitize(::SciMLBase.DAEProblem{...})
sol = de.solve(jit_prob)
```

## Delay Differential Equations
Expand Down Expand Up @@ -476,7 +465,6 @@ Unit tests can be run by [`tox`](http://tox.readthedocs.io).

```sh
tox
tox -e py3-numba # test with Numba
```

### Troubleshooting
Expand Down
4 changes: 2 additions & 2 deletions diffeqpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def _ensure_installed(*kwargs):
install(*kwargs)

# TODO: upstream this function or an alternative into juliacall
def load_julia_package(name):
def load_julia_packages(names):
# This is terrifying to many people. However, it seems SciML takes pragmatic approach.
_ensure_installed()

# Must be loaded after `_ensure_installed()`
from juliacall import Main
return Main.seval(f"import Pkg; Pkg.activate(\"diffeqpy\", shared=true); import {name}; {name}")
return Main.seval(f"import Pkg; Pkg.activate(\"diffeqpy\", shared=true); import {names}; {names}")
7 changes: 5 additions & 2 deletions diffeqpy/de.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import sys
from . import load_julia_package
sys.modules[__name__] = load_julia_package("DifferentialEquations") # mutate myself
from . import load_julia_packages
de, _ = load_julia_packages("DifferentialEquations, ModelingToolkit")
from juliacall import Main
de.jit = Main.seval("jit(x) = typeof(x).name.wrapper(ModelingToolkit.modelingtoolkitize(x), x.u0, x.tspan, x.p)") # kinda hackey
sys.modules[__name__] = de # mutate myself
4 changes: 2 additions & 2 deletions diffeqpy/install.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Pkg
Pkg.activate("diffeqpy", shared=true)
Pkg.add(["DifferentialEquations", "OrdinaryDiffEq", "PythonCall"])
using DifferentialEquations, OrdinaryDiffEq, PythonCall # Precompile
Pkg.add(["DifferentialEquations", "ModelingToolkit", "OrdinaryDiffEq","PythonCall"])
using DifferentialEquations, ModelingToolkit, OrdinaryDiffEq, PythonCall # Precompile
4 changes: 2 additions & 2 deletions diffeqpy/ode.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import sys
from . import load_julia_package
sys.modules[__name__] = load_julia_package("OrdinaryDiffEq") # mutate myself
from . import load_julia_packages
sys.modules[__name__] = load_julia_packages("OrdinaryDiffEq") # mutate myself

0 comments on commit fc32432

Please sign in to comment.