Skip to content

Commit

Permalink
Merge pull request #527 from LilithHafner/lh/promote
Browse files Browse the repository at this point in the history
Promote initial state when converting from non-homogeneous python lists.
  • Loading branch information
ChrisRackauckas authored Oct 21, 2023
2 parents 187e884 + 9cb74b4 commit eab57db
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
12 changes: 10 additions & 2 deletions ext/SciMLBasePythonCallExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@ function SciMLBase.numargs(f::Py)
pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2))
end

_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? _promoting_collect(_pyconvert(x) for x in x) : pyconvert(Any, x)
_pyconvert(x::PyList) = _promoting_collect(_pyconvert(x) for x in x)
_pyconvert(x) = x

# _promoting_collect might copy its input
_promoting_collect(x) = _promoting_collect(collect(x))
function _promoting_collect(x::AbstractArray)
isconcretetype(eltype(x)) && return x
T = mapreduce(typeof, promote_type, x)
T == eltype(x) ? x : T.(x)
end

SciMLBase.prepare_initial_state(u0::Union{Py, PyList}) = _pyconvert(u0)
SciMLBase.prepare_function(f::Py) = _pyconvert f

Expand Down
5 changes: 5 additions & 0 deletions test/python/pythoncall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,8 @@ using DifferentialEquations, PythonCall
sol = de.solve(prob,reltol=1e-3,abstol=1e-3)
""", @__MODULE__)
end

@testset "promotion" begin
_u0 = pyconvert(Any, pyeval("""de.SciMLBase.prepare_initial_state([1.0, 0, 0])""", @__MODULE__))
@test _u0 isa Vector{Float64}
end

0 comments on commit eab57db

Please sign in to comment.