Skip to content

Commit

Permalink
Merge pull request #519 from LilithHafner/lh/PythonCall-extension
Browse files Browse the repository at this point in the history
Add PythonCall Extension
  • Loading branch information
ChrisRackauckas authored Oct 7, 2023
2 parents 5d0d7e0 + bfd023d commit a8b39e9
Show file tree
Hide file tree
Showing 19 changed files with 267 additions and 74 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ Manifest.toml
# vscode stuff
.vscode
.vscode/*

# python extensions
.CondaPkg
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
PyCallExt = "PyCall"
PythonCallExt = "PythonCall"
ZygoteExt = "Zygote"

[compat]
Expand Down Expand Up @@ -72,11 +74,12 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "PyCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote"]
test = ["Pkg", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote"]
23 changes: 23 additions & 0 deletions ext/PythonCallExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module PythonCallExt

using PythonCall: Py, PyList, pyimport, hasproperty, pyconvert, pyisinstance, pybuiltins
using SciMLBase: SciMLBase

# SciML uses a function's arity (number of arguments) to determine if it operates in place.
# PythonCall does not preserve arity, so we inspect Python functions to find their arity.
function SciMLBase.numargs(f::Py)
inspect = pyimport("inspect")
f2 = hasproperty(f, :py_func) ? f.py_func : f
# if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
# `self` in the `args` list. So, we subtract 1 in that case:
pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2))
end

_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x

SciMLBase.prepare_initial_state(u0::Union{Py, PyList}) = _pyconvert(u0)
SciMLBase.prepare_function(f::Py) = _pyconvert f

end
5 changes: 3 additions & 2 deletions src/problems/analytical_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ struct AnalyticalProblem{uType, tType, isinplace, P, F, K} <:
kwargs::K
@add_kwonly function AnalyticalProblem{iip}(f, u0, tspan, p = NullParameters();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan), iip, typeof(p),
new{typeof(_u0), typeof(_tspan), iip, typeof(p),
typeof(f), typeof(kwargs)}(f,
u0,
_u0,
_tspan,
p,
kwargs)
Expand Down
5 changes: 3 additions & 2 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan,
p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
prob_type = TP ? TwoPointBVProblem{iip}() : StandardBVProblem()
Expand All @@ -125,8 +126,8 @@ struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:
else
@assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
end
return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f),
typeof(problem_type), typeof(kwargs)}(f, u0, _tspan, p, problem_type, kwargs)
return new{typeof(_u0), typeof(_tspan), iip, typeof(p), typeof(f),
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs)
end

function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
Expand Down
12 changes: 7 additions & 5 deletions src/problems/dae_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,23 @@ struct DAEProblem{uType, duType, tType, isinplace, P, F, K, D} <:
du0, u0, tspan, p = NullParameters();
differential_vars = nothing,
kwargs...) where {iip}
if !isnothing(u0)
_u0 = prepare_initial_state(u0)
_du0 = prepare_initial_state(du0)
if !isnothing(_u0)
# Defend against external solvers like Sundials breaking on non-uniform input dimensions.
size(du0) == size(u0) ||
size(_du0) == size(_u0) ||
throw(ArgumentError("Sizes of u0 and du0 must be the same."))
if !isnothing(differential_vars)
size(u0) == size(differential_vars) ||
size(_u0) == size(differential_vars) ||
throw(ArgumentError("Sizes of u0 and differential_vars must be the same."))
end
end
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(du0), typeof(_tspan),
new{typeof(_u0), typeof(_du0), typeof(_tspan),
isinplace(f), typeof(p),
typeof(f), typeof(kwargs),
typeof(differential_vars)}(f, du0, u0, _tspan, p,
typeof(differential_vars)}(f, _du0, _u0, _tspan, p,
kwargs, differential_vars)
end

Expand Down
6 changes: 4 additions & 2 deletions src/problems/dde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,13 @@ struct DDEProblem{uType, tType, lType, lType2, isinplace, P, F, H, K, PT} <:
order_discontinuity_t0 = 0,
problem_type = StandardDDEProblem(),
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan), typeof(constant_lags), typeof(dependent_lags),
new{typeof(_u0), typeof(_tspan), typeof(constant_lags), typeof(dependent_lags),
isinplace(f),
typeof(p), typeof(f), typeof(h), typeof(kwargs), typeof(problem_type)}(f, u0, h,
typeof(p), typeof(f), typeof(h), typeof(kwargs), typeof(problem_type)}(f, _u0,
h,
_tspan,
p,
constant_lags,
Expand Down
5 changes: 3 additions & 2 deletions src/problems/discrete_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,13 @@ struct DiscreteProblem{uType, tType, isinplace, P, F, K} <:
@add_kwonly function DiscreteProblem{iip}(f::AbstractDiscreteFunction{iip},
u0, tspan::Tuple, p = NullParameters();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan), isinplace(f, 4),
new{typeof(_u0), typeof(_tspan), isinplace(f, 4),
typeof(p),
typeof(f), typeof(kwargs)}(f,
u0,
_u0,
_tspan,
p,
kwargs)
Expand Down
5 changes: 3 additions & 2 deletions src/problems/implicit_discrete_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,13 @@ struct ImplicitDiscreteProblem{uType, tType, isinplace, P, F, K} <:
u0, tspan::Tuple,
p = NullParameters();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan), isinplace(f, 6),
new{typeof(_u0), typeof(_tspan), isinplace(f, 6),
typeof(p),
typeof(f), typeof(kwargs)}(f,
u0,
_u0,
_tspan,
p,
kwargs)
Expand Down
18 changes: 11 additions & 7 deletions src/problems/ode_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,14 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <:
u0, tspan, p = NullParameters(),
problem_type = StandardODEProblem();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan),
new{typeof(_u0), typeof(_tspan),
isinplace(f), typeof(p), typeof(f),
typeof(kwargs),
typeof(problem_type)}(f,
u0,
_u0,
_tspan,
p,
kwargs,
Expand All @@ -133,9 +134,10 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <:
This is determined automatically, but not inferred.
"""
function ODEProblem{iip}(f, u0, tspan, p = NullParameters(); kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
ODEProblem(_f, u0, _tspan, p; kwargs...)
ODEProblem(_f, _u0, _tspan, p; kwargs...)
end

@add_kwonly function ODEProblem{iip, recompile}(f, u0, tspan, p = NullParameters();
Expand All @@ -145,19 +147,20 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <:

function ODEProblem{iip, FunctionWrapperSpecialize}(f, u0, tspan, p = NullParameters();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
if !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
if iip
ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f,
(u0, u0, p,
(_u0, _u0, p,
_tspan[1])))
else
ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f,
(u0, p,
(_u0, p,
_tspan[1])))
end
end
ODEProblem{iip}(ff, u0, _tspan, p; kwargs...)
ODEProblem{iip}(ff, _u0, _tspan, p; kwargs...)
end
end
TruncatedStacktraces.@truncate_stacktrace ODEProblem 3 1 2
Expand All @@ -183,9 +186,10 @@ end

function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
ODEProblem(_f, u0, _tspan, p; kwargs...)
ODEProblem(_f, _u0, _tspan, p; kwargs...)
end

"""
Expand Down
5 changes: 3 additions & 2 deletions src/problems/rode_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ mutable struct RODEProblem{uType, tType, isinplace, P, NP, F, K, ND} <:
rand_prototype = nothing,
noise = nothing, seed = UInt64(0),
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan),
new{typeof(_u0), typeof(_tspan),
isinplace(f), typeof(p),
typeof(noise), typeof(f), typeof(kwargs),
typeof(rand_prototype)}(f, u0, _tspan, p, noise, kwargs,
typeof(rand_prototype)}(f, _u0, _tspan, p, noise, kwargs,
rand_prototype, seed)
end
function RODEProblem{iip}(f, u0, tspan, p = NullParameters(); kwargs...) where {iip}
Expand Down
5 changes: 3 additions & 2 deletions src/problems/sdde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,13 @@ struct SDDEProblem{uType, tType, lType, lType2, isinplace, P, NP, F, G, H, K, ND
det(f.mass_matrix) != 1,
order_discontinuity_t0 = 0 // 1,
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan), typeof(constant_lags), typeof(dependent_lags),
new{typeof(_u0), typeof(_tspan), typeof(constant_lags), typeof(dependent_lags),
isinplace(f),
typeof(p), typeof(noise), typeof(f), typeof(g), typeof(h), typeof(kwargs),
typeof(noise_rate_prototype)}(f, g, u0, h, _tspan, p, noise, constant_lags,
typeof(noise_rate_prototype)}(f, g, _u0, h, _tspan, p, noise, constant_lags,
dependent_lags, kwargs, noise_rate_prototype,
seed, neutral, order_discontinuity_t0)
end
Expand Down
5 changes: 3 additions & 2 deletions src/problems/sde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,14 @@ struct SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND} <:
noise_rate_prototype = nothing,
noise = nothing, seed = UInt64(0),
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan),
new{typeof(_u0), typeof(_tspan),
isinplace(f), typeof(p),
typeof(noise), typeof(f), typeof(f.g),
typeof(kwargs),
typeof(noise_rate_prototype)}(f, f.g, u0, _tspan, p,
typeof(noise_rate_prototype)}(f, f.g, _u0, _tspan, p,
noise, kwargs,
noise_rate_prototype, seed)
end
Expand Down
3 changes: 2 additions & 1 deletion src/problems/steady_state_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ struct SteadyStateProblem{uType, isinplace, P, F, K} <:
@add_kwonly function SteadyStateProblem{iip}(f::AbstractODEFunction{iip},
u0, p = NullParameters();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
warn_paramtype(p)
new{typeof(u0), isinplace(f), typeof(p), typeof(f), typeof(kwargs)}(f, u0, p,
new{typeof(_u0), isinplace(f), typeof(p), typeof(f), typeof(kwargs)}(f, _u0, p,
kwargs)
end

Expand Down
Loading

0 comments on commit a8b39e9

Please sign in to comment.