Skip to content

Commit

Permalink
Merge pull request #502 from LilithHafner/lh/PyCall-extension
Browse files Browse the repository at this point in the history
Add PyCall extension
  • Loading branch information
ChrisRackauckas authored Oct 5, 2023
2 parents 75b1925 + 931449c commit aeb491a
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
group:
- Core
- Downstream
- Python
version:
- '1'
steps:
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
PyCallExt = "PyCall"
ZygoteExt = "Zygote"

[compat]
Expand Down Expand Up @@ -69,11 +71,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote"]
test = ["Pkg", "PyCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote"]
20 changes: 20 additions & 0 deletions ext/PyCallExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module PyCallExt

using PyCall: PyCall, PyObject, PyAny, pyfunctionret, pyimport, hasproperty
using SciMLBase: SciMLBase, solve

# SciML uses a function's arity (number of arguments) to determine if it operates in place.
# PyCall does not preserve arity, so we inspect Python functions to find their arity.
function SciMLBase.numargs(f::PyObject)
inspect = pyimport("inspect")
f2 = hasproperty(f, :py_func) ? f.py_func : f
# if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
# `self` in the `args` list. So, we subtract 1 in that case:
length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2)
end

# differential equation solutions can be converted to lists, this tells PyCall not
# to perform that conversion automatically when a solution is returned from `solve`
PyCall.PyObject(::typeof(solve)) = pyfunctionret(solve, Any, Vararg{PyAny})

end
9 changes: 9 additions & 0 deletions test/python/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[deps]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[compat]
OrdinaryDiffEq = "6.33"
PyCall = "1.96"
SciMLBase = "2"
45 changes: 45 additions & 0 deletions test/python/pycall.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using PyCall, SciMLBase, OrdinaryDiffEq

py""" # This is a mess because normal site-packages is not writeable in CI
import subprocess, sys, site
subprocess.run([sys.executable, '-m', 'pip', 'install', '--user', 'julia'])
sys.path.append(site.getusersitepackages())
"""

@testset "numargs" begin
py"""
def three_arg(a, b, c):
return a + b + c
def four_arg(a, b, c, d):
return a + b + c + d
class MyClass:
def three_arg_method(self, a, b, c):
return a + b + c
def four_arg_method(self, a, b, c, d):
return a + b + c + d
"""

@test SciMLBase.numargs(py"three_arg") === 3
@test SciMLBase.numargs(py"four_arg") === 4
x = py"MyClass()"
@test SciMLBase.numargs(x.three_arg_method) === 3
@test SciMLBase.numargs(x.four_arg_method) === 4
end

@testset "solution handling" begin
py"""
from julia import OrdinaryDiffEq as ode
def f(u,p,t):
return -u
u0 = 0.5
tspan = (0., 1.)
prob = ode.ODEProblem(f, u0, tspan)
sol = ode.solve(prob, ode.Tsit5())
"""
@test py"sol" isa ODESolution
end
13 changes: 13 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ function activate_downstream_env()
Pkg.instantiate()
end

function activate_python_env()
Pkg.activate("python")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()
end

@time begin
if GROUP == "Core" || GROUP == "All"
@time @safetestset "Aqua" begin
Expand Down Expand Up @@ -93,4 +99,11 @@ end
include("downstream/remake_autodiff.jl")
end
end

if !is_APPVEYOR && GROUP == "Python"
activate_python_env()
@time @safetestset "PyCall" begin
include("python/pycall.jl")
end
end
end

0 comments on commit aeb491a

Please sign in to comment.