Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PyCall extension #502

Merged
merged 18 commits into from
Oct 5, 2023
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
group:
- Core
- Downstream
- Python
version:
- '1'
steps:
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
PyCallExt = "PyCall"
ZygoteExt = "Zygote"

[compat]
Expand Down Expand Up @@ -69,11 +71,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote"]
test = ["Pkg", "PyCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote"]
20 changes: 20 additions & 0 deletions ext/PyCallExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module PyCallExt

using PyCall: PyCall, PyObject, PyAny, pyfunctionret, pyimport, hasproperty
using SciMLBase: SciMLBase, solve

# SciML uses a function's arity (number of arguments) to determine if it operates in place.
# PyCall does not preserve arity, so we inspect Python functions to find their arity.
function SciMLBase.numargs(f::PyObject)
inspect = pyimport("inspect")
f2 = hasproperty(f, :py_func) ? f.py_func : f

Check warning on line 10 in ext/PyCallExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/PyCallExt.jl#L8-L10

Added lines #L8 - L10 were not covered by tests
# if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
# `self` in the `args` list. So, we subtract 1 in that case:
length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2)

Check warning on line 13 in ext/PyCallExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/PyCallExt.jl#L13

Added line #L13 was not covered by tests
end

# differential equation solutions can be converted to lists, this tells PyCall not
# to perform that conversion automatically when a solution is returned from `solve`
PyCall.PyObject(::typeof(solve)) = pyfunctionret(solve, Any, Vararg{PyAny})

Check warning on line 18 in ext/PyCallExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/PyCallExt.jl#L18

Added line #L18 was not covered by tests

end
9 changes: 9 additions & 0 deletions test/python/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[deps]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[compat]
OrdinaryDiffEq = "6.33"
PyCall = "1.96"
SciMLBase = "2"
45 changes: 45 additions & 0 deletions test/python/pycall.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using PyCall, SciMLBase, OrdinaryDiffEq

py""" # This is a mess because normal site-packages is not writeable in CI
import subprocess, sys, site
subprocess.run([sys.executable, '-m', 'pip', 'install', '--user', 'julia'])
sys.path.append(site.getusersitepackages())
"""

@testset "numargs" begin
py"""
def three_arg(a, b, c):
return a + b + c

def four_arg(a, b, c, d):
return a + b + c + d

class MyClass:
def three_arg_method(self, a, b, c):
return a + b + c

def four_arg_method(self, a, b, c, d):
return a + b + c + d
"""

@test SciMLBase.numargs(py"three_arg") === 3
@test SciMLBase.numargs(py"four_arg") === 4
x = py"MyClass()"
@test SciMLBase.numargs(x.three_arg_method) === 3
@test SciMLBase.numargs(x.four_arg_method) === 4
end

@testset "solution handling" begin
py"""
from julia import OrdinaryDiffEq as ode

def f(u,p,t):
return -u

u0 = 0.5
tspan = (0., 1.)
prob = ode.ODEProblem(f, u0, tspan)
sol = ode.solve(prob, ode.Tsit5())
"""
@test py"sol" isa ODESolution
end
13 changes: 13 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ function activate_downstream_env()
Pkg.instantiate()
end

function activate_python_env()
Pkg.activate("python")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()
end

@time begin
if GROUP == "Core" || GROUP == "All"
@time @safetestset "Aqua" begin
Expand Down Expand Up @@ -93,4 +99,11 @@ end
include("downstream/remake_autodiff.jl")
end
end

if !is_APPVEYOR && GROUP == "Python"
activate_python_env()
@time @safetestset "PyCall" begin
include("python/pycall.jl")
end
end
end
Loading