Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch backend from PyCall to PythonCall and improve package management #118

Merged
merged 15 commits into from
Oct 10, 2023
31 changes: 15 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,15 @@ import numba
numba_f = numba.jit(f)

prob = de.ODEProblem(numba_f, u0, tspan)
sol = de.solve(prob)
sol = de.solve(prob) # ERROR
```

Additionally, you can directly define the functions in Julia. This will allow
for more specialization and could be helpful to increase the efficiency over
the Numba version for repeat or long calls. This is done via `julia.Main.eval`:
the Numba version for repeat or long calls. This is done via `seval`:

```py
from julia import Main
jul_f = Main.eval("(u,p,t)->-u") # Define the anonymous function in Julia
jul_f = de.seval("(u,p,t)->-u") # Define the anonymous function in Julia
prob = de.ODEProblem(jul_f, u0, tspan)
sol = de.solve(prob)
```
Expand All @@ -195,7 +194,7 @@ p = [10.0,28.0,8/3]
prob = de.ODEProblem(f, u0, tspan, p)
sol = de.solve(prob,saveat=0.01)

plt.plot(sol.t,sol.u)
plt.plot(sol.t,de.transpose(de.stack(sol.u)))
plt.show()
```

Expand All @@ -204,11 +203,11 @@ plt.show()
or we can draw the phase plot:

```py
ut = numpy.transpose(sol.u)
us = de.stack(sol.u)
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(ut[0,:],ut[1,:],ut[2,:])
ax.plot(us[0,:],us[1,:],us[2,:])
plt.show()
```

Expand Down Expand Up @@ -240,7 +239,7 @@ sol = de.solve(prob)
or using a Julia function:

```py
jul_f = Main.eval("""
jul_f = de.seval("""
function f(du,u,p,t)
x, y, z = u
sigma, rho, beta = p
Expand Down Expand Up @@ -275,7 +274,7 @@ tspan = (0.0,1.0)
prob = de.SDEProblem(f,g,u0,tspan)
sol = de.solve(prob,reltol=1e-3,abstol=1e-3)

plt.plot(sol.t,sol.u)
plt.plot(sol.t,de.stack(sol.u))
plt.show()
```

Expand Down Expand Up @@ -310,11 +309,11 @@ sol = de.solve(prob)

# Now let's draw a phase plot

ut = numpy.transpose(sol.u)
us = de.stack(sol.u)
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(ut[0,:],ut[1,:],ut[2,:])
ax.plot(us[0,:],us[1,:],us[2,:])
plt.show()
```

Expand Down Expand Up @@ -359,11 +358,11 @@ sol = de.solve(prob,saveat=0.005)

# Now let's draw a phase plot

ut = numpy.transpose(sol.u)
us = de.stack(sol.u)
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(ut[0,:],ut[1,:],ut[2,:])
ax.plot(us[0,:],us[1,:],us[2,:])
plt.show()
```

Expand Down Expand Up @@ -412,7 +411,7 @@ def f(resid,du,u,p,t):

numba_f = numba.jit(f)
prob = de.DAEProblem(numba_f,du0,u0,tspan,differential_vars=differential_vars)
sol = de.solve(prob)
sol = de.solve(prob) # ERROR
```

## Delay Differential Equations
Expand All @@ -430,14 +429,14 @@ the solver accuracy by accurately stepping at the points of discontinuity.
Together this is:

```py
f = Main.eval("""
f = de.seval("""
function f(du, u, h, p, t)
du[1] = 1.1/(1 + sqrt(10)*(h(p, t-20)[1])^(5/4)) - 10*u[1]/(1 + 40*u[2])
du[2] = 100*u[1]/(1 + 40*u[2]) - 2.43*u[2]
end""")
u0 = [1.05767027/3, 1.030713491/3]

h = Main.eval("""
h = de.seval("""
function h(p,t)
[1.05767027/3, 1.030713491/3]
end
Expand Down
9 changes: 9 additions & 0 deletions diffeqpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,12 @@ def _ensure_installed(*kwargs):
if not _find_julia():
# TODO: this should probably ensure that packages are installed too
install(*kwargs)

# TODO: upstream this function or an alternative into juliacall
def load_julia_package(name):
# This is terrifying to many people. However, it seems SciML takes pragmatic approach.
_ensure_installed()

# Must be loaded after `_ensure_installed()`
from juliacall import Main
return Main.seval(f"import Pkg; Pkg.activate(\"diffeqpy\", shared=true); import {name}; {name}")
17 changes: 2 additions & 15 deletions diffeqpy/de.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
import os
import sys

from . import _ensure_installed

# This is terrifying to many people. However, it seems SciML takes pragmatic approach.
_ensure_installed()

# PyJulia have to be loaded after `_ensure_installed()`
from julia import Main

script_dir = os.path.dirname(os.path.realpath(__file__))
Main.include(os.path.join(script_dir, "setup.jl"))

from julia import DifferentialEquations
sys.modules[__name__] = DifferentialEquations # mutate myself
from . import load_julia_package
sys.modules[__name__] = load_julia_package("DifferentialEquations") # mutate myself
10 changes: 3 additions & 7 deletions diffeqpy/install.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
using Pkg
Pkg.add("DifferentialEquations")
Pkg.add("OrdinaryDiffEq")
Pkg.add("DiffEqBase")
Pkg.add("PyCall")
Pkg.build("PyCall")
using DifferentialEquations
using PyCall
Pkg.activate("diffeqpy", shared=true)
Pkg.add(["DifferentialEquations", "OrdinaryDiffEq", "PythonCall"])
using DifferentialEquations, OrdinaryDiffEq, PythonCall # Precompile
11 changes: 2 additions & 9 deletions diffeqpy/ode.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
import os
import sys

from julia import Main

script_dir = os.path.dirname(os.path.realpath(__file__))
Main.include(os.path.join(script_dir, "setup.jl"))

from julia import OrdinaryDiffEq
sys.modules[__name__] = OrdinaryDiffEq # mutate myself
from . import load_julia_package
sys.modules[__name__] = load_julia_package("OrdinaryDiffEq") # mutate myself
34 changes: 0 additions & 34 deletions diffeqpy/setup.jl

This file was deleted.

9 changes: 3 additions & 6 deletions diffeqpy/tests/test_dde.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
from julia import Main

from .. import de


def test():
f = Main.eval("""
f = de.seval("""
function f(du, u, h, p, t)
du[1] = 1.1/(1 + sqrt(10)*(h(p, t-20)[1])^(5/4)) - 10*u[1]/(1 + 40*u[2])
du[2] = 100*u[1]/(1 + 40*u[2]) - 2.43*u[2]
end""")
u0 = [1.05767027/3, 1.030713491/3]

h = Main.eval("""
h = de.seval("""
function h(p,t)
[1.05767027/3, 1.030713491/3]
end
""")
end""")

tspan = (0.0, 100.0)
constant_lags = [20.0]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def readme():
author_email='[email protected]',
license='MIT',
packages=['diffeqpy','diffeqpy.tests'],
install_requires=['julia>=0.2', 'jill'],
install_requires=['juliacall>=0.9.14', 'jill'],
include_package_data=True,
zip_safe=False)
Loading