Skip to content

Commit

Permalink
Add initial draft
Browse files Browse the repository at this point in the history
TODO: add prepare_u0 to all problems (currently just ODE and SDE)
TODO: add tests
TODO: add package extension boilerplate (e.g. update Project.toml)
  • Loading branch information
Lilith Hafner authored and Lilith Hafner committed Oct 5, 2023
1 parent b974ce5 commit be6174a
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 53 deletions.
23 changes: 23 additions & 0 deletions ext/PythonCallExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module PythonCallExt

using PythonCall: Py, PyList, pyimport, hasproperty, pyconvert, pyisinstance, pybuiltins
using SciMLBase: SciMLBase

# SciML uses a function's arity (number of arguments) to determine if it operates in place.
# PythonCall does not preserve arity, so we inspect Python functions to find their arity.
function SciMLBase.numargs(f::Py)
inspect = pyimport("inspect")
f2 = hasproperty(f, :py_func) ? f.py_func : f

Check warning on line 10 in ext/PythonCallExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/PythonCallExt.jl#L8-L10

Added lines #L8 - L10 were not covered by tests
# if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
# `self` in the `args` list. So, we subtract 1 in that case:
pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2))

Check warning on line 13 in ext/PythonCallExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/PythonCallExt.jl#L13

Added line #L13 was not covered by tests
end

_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x

Check warning on line 18 in ext/PythonCallExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/PythonCallExt.jl#L16-L18

Added lines #L16 - L18 were not covered by tests

SciMLBase.prepare_initial_state(u0::Union{Py, PyList}) = _pyconvert(u0)
SciMLBase.prepare_function(f::Py) = _pyconvert f

Check warning on line 21 in ext/PythonCallExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/PythonCallExt.jl#L20-L21

Added lines #L20 - L21 were not covered by tests

end
18 changes: 11 additions & 7 deletions src/problems/ode_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,14 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <:
u0, tspan, p = NullParameters(),
problem_type = StandardODEProblem();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan),
new{typeof(_u0), typeof(_tspan),
isinplace(f), typeof(p), typeof(f),
typeof(kwargs),
typeof(problem_type)}(f,
u0,
_u0,
_tspan,
p,
kwargs,
Expand All @@ -133,9 +134,10 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <:
This is determined automatically, but not inferred.
"""
function ODEProblem{iip}(f, u0, tspan, p = NullParameters(); kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
ODEProblem(_f, u0, _tspan, p; kwargs...)
ODEProblem(_f, _u0, _tspan, p; kwargs...)
end

@add_kwonly function ODEProblem{iip, recompile}(f, u0, tspan, p = NullParameters();
Expand All @@ -145,19 +147,20 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <:

function ODEProblem{iip, FunctionWrapperSpecialize}(f, u0, tspan, p = NullParameters();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
if !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
if iip
ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f,
(u0, u0, p,
(_u0, _u0, p,
_tspan[1])))
else
ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f,
(u0, p,
(_u0, p,
_tspan[1])))
end
end
ODEProblem{iip}(ff, u0, _tspan, p; kwargs...)
ODEProblem{iip}(ff, _u0, _tspan, p; kwargs...)
end
end
TruncatedStacktraces.@truncate_stacktrace ODEProblem 3 1 2
Expand All @@ -183,9 +186,10 @@ end

function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
ODEProblem(_f, u0, _tspan, p; kwargs...)
ODEProblem(_f, _u0, _tspan, p; kwargs...)
end

"""
Expand Down
5 changes: 3 additions & 2 deletions src/problems/sde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,14 @@ struct SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND} <:
noise_rate_prototype = nothing,
noise = nothing, seed = UInt64(0),
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan),
new{typeof(_u0), typeof(_tspan),
isinplace(f), typeof(p),
typeof(noise), typeof(f), typeof(f.g),
typeof(kwargs),
typeof(noise_rate_prototype)}(f, f.g, u0, _tspan, p,
typeof(noise_rate_prototype)}(f, f.g, _u0, _tspan, p,
noise, kwargs,
noise_rate_prototype, seed)
end
Expand Down
Loading

0 comments on commit be6174a

Please sign in to comment.