From 943eba6dc2f682243b98f9488bd35826e8957fe8 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Fri, 20 Oct 2023 10:09:29 -0500 Subject: [PATCH 1/4] add test --- test/python/pythoncall.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/python/pythoncall.jl b/test/python/pythoncall.jl index 3b13cac0c..f2e8ec427 100644 --- a/test/python/pythoncall.jl +++ b/test/python/pythoncall.jl @@ -77,3 +77,8 @@ using DifferentialEquations, PythonCall sol = de.solve(prob,reltol=1e-3,abstol=1e-3) """, @__MODULE__) end + +@testset "promotion" begin + _u0 = pyeval("""de.prepare_initial_state([1.0, 0, 0])""", @__MODULE__) + @test _u0 isa Vector{Float64} +end From 6b8aa170b01eafe5befe81646c1816074eab12db Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Fri, 20 Oct 2023 10:13:16 -0500 Subject: [PATCH 2/4] implement fix --- ext/SciMLBasePythonCallExt.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ext/SciMLBasePythonCallExt.jl b/ext/SciMLBasePythonCallExt.jl index da7bf262e..7426e4037 100644 --- a/ext/SciMLBasePythonCallExt.jl +++ b/ext/SciMLBasePythonCallExt.jl @@ -13,10 +13,18 @@ function SciMLBase.numargs(f::Py) pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2)) end -_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x) -_pyconvert(x::PyList) = [_pyconvert(x) for x in x] +_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? _promoting_collect(_pyconvert(x) for x in x) : pyconvert(Any, x) +_pyconvert(x::PyList) = _promoting_collect(_pyconvert(x) for x in x) _pyconvert(x) = x +# _promoting_collect might copy its input +_promoting_collect(x) = _promoting_collect(collect(x)) +function _promoting_collect(x::AbstractArray) + isconcretetype(eltype(x)) && return x + T = mapreduce(typeof, promote_type, x) + T == eltype(x) ? x : T.(x) +end + SciMLBase.prepare_initial_state(u0::Union{Py, PyList}) = _pyconvert(u0) SciMLBase.prepare_function(f::Py) = _pyconvert ∘ f From 8389c44d59da6b3531a9a55ed972e7c4f4eb7dc2 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Fri, 20 Oct 2023 11:05:58 -0500 Subject: [PATCH 3/4] fix tesy --- test/python/pythoncall.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/python/pythoncall.jl b/test/python/pythoncall.jl index f2e8ec427..8d4e71236 100644 --- a/test/python/pythoncall.jl +++ b/test/python/pythoncall.jl @@ -79,6 +79,6 @@ using DifferentialEquations, PythonCall end @testset "promotion" begin - _u0 = pyeval("""de.prepare_initial_state([1.0, 0, 0])""", @__MODULE__) + _u0 = pyeval("""de.SciMLBase.prepare_initial_state([1.0, 0, 0])""", @__MODULE__) @test _u0 isa Vector{Float64} end From 9cb74b4fc5e5a2eb32547a2e31c698638fdd1f67 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Fri, 20 Oct 2023 18:03:41 -0500 Subject: [PATCH 4/4] try again to fix test --- test/python/pythoncall.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/python/pythoncall.jl b/test/python/pythoncall.jl index 8d4e71236..7f0758be1 100644 --- a/test/python/pythoncall.jl +++ b/test/python/pythoncall.jl @@ -79,6 +79,6 @@ using DifferentialEquations, PythonCall end @testset "promotion" begin - _u0 = pyeval("""de.SciMLBase.prepare_initial_state([1.0, 0, 0])""", @__MODULE__) + _u0 = pyconvert(Any, pyeval("""de.SciMLBase.prepare_initial_state([1.0, 0, 0])""", @__MODULE__)) @test _u0 isa Vector{Float64} end