Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PythonCall Extension #519

Merged
merged 8 commits into from
Oct 7, 2023

Conversation

LilithHafner
Copy link
Member

@LilithHafner LilithHafner commented Oct 3, 2023

The goal of this PR is to make PythonCall and the DifferentialEquations ecosystem fully compatible, making SciML/diffeqpy#118 trivial.

I think I've implemented the nontrivial design decisions that have to be made, so this is ready for review. If the design looks good and once #502 merges, I'll finish the details to get this to a mergeable state.

  • add prepare_u0 to all problems (currently just ODE and SDE)
  • add tests
  • add package extension boilerplate (e.g. update Project.toml)

@codecov
Copy link

codecov bot commented Oct 3, 2023

Codecov Report

Merging #519 (bfd023d) into master (5d0d7e0) will decrease coverage by 0.63%.
The diff coverage is 83.33%.

@@            Coverage Diff             @@
##           master     #519      +/-   ##
==========================================
- Coverage   54.25%   53.63%   -0.63%     
==========================================
  Files          51       52       +1     
  Lines        3854     3897      +43     
==========================================
- Hits         2091     2090       -1     
- Misses       1763     1807      +44     
Files Coverage Δ
ext/PythonCallExt.jl 100.00% <100.00%> (ø)
src/problems/analytical_problems.jl 100.00% <100.00%> (ø)
src/problems/bvp_problems.jl 33.33% <100.00%> (+1.90%) ⬆️
src/problems/dae_problems.jl 100.00% <100.00%> (ø)
src/problems/dde_problems.jl 26.19% <100.00%> (+1.80%) ⬆️
src/problems/discrete_problems.jl 72.00% <100.00%> (+1.16%) ⬆️
src/problems/rode_problems.jl 100.00% <100.00%> (ø)
src/problems/sdde_problems.jl 69.23% <100.00%> (+2.56%) ⬆️
src/problems/sde_problems.jl 58.97% <100.00%> (+1.07%) ⬆️
src/problems/steady_state_problems.jl 66.66% <100.00%> (-15.16%) ⬇️
... and 4 more

... and 5 files with indirect coverage changes

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@ChrisRackauckas
Copy link
Member

I think this makes sense. The other place that it could be would be the DiffEqBase solve.jl, which intercepts right before solve and makes conversions so that the problem is solvable by a given solver (and does a bunch of error throws according to the options). However, that is done at a later time because those conversions can be dependent on the solver that is chosen. Here, this would be conversions that should just always happen. So I guess this is needed here, though it is a bit tedious to add it everywhere

@LilithHafner
Copy link
Member Author

Here are my internal TDD notes for posterity's sake and to draw from when writing tests
# Patch 0
using SciMLBase
SciMLBase.numargs(f::ComposedFunction) = SciMLBase.numargs(f.inner) # https://github.com/SciML/SciMLBase.jl/pull/506

# Test 1

using DifferentialEquations, PythonCall
pyexec("""
from juliacall import Main
de = Main.seval("DifferentialEquations")

def f(u,p,t):
    return -u

u0 = 0.5
tspan = (0., 1.)
prob = de.ODEProblem(f, u0, tspan)
sol = de.solve(prob)
""", @__MODULE__)
@test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution

# Test 1 Failure 1: "Detected an in-place function with an initial condition of type Number or SArray."

# Patch 1

using PythonCall: Py, pyimport, hasproperty, pyconvert
using SciMLBase: SciMLBase

# SciML uses a function's arity (number of arguments) to determine if it operates in place.
# PythonCall does not preserve arity, so we inspect Python functions to find their arity.
function SciMLBase.numargs(f::Py)
    inspect = pyimport("inspect")
    f2 = hasproperty(f, :py_func) ? f.py_func : f
    # if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
    # `self` in the `args` list. So, we subtract 1 in that case:
    pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2))
end

# Test 1 Failure 2 "ERROR: Python: Julia: MethodError: Cannot `convert` an object of type Py to an object of type Float64"

# Patch 2

function SciMLBase.ODEProblem(f::Py, u0, tspan, args...)
    ODEProblem(Base.Fix1(pyconvert, Any)  f, pyconvert(Any, u0), pyconvert(Any, tspan), pyconvert.(Any, args)...)
end

# Test 1 Pass.

# Test 2

pyexec("""
def f(u,p,t):
    x, y, z = u
    sigma, rho, beta = p
    return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,8/3]
prob = de.ODEProblem(f, u0, tspan, p)
sol = de.solve(prob,saveat=0.01)
""", @__MODULE__)
@test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution

# Patch 3 (replaces patch 2)

using PythonCall: pyisinstance

_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x

function SciMLBase.ODEProblem(f::Py, u0, tspan, args...)
    ODEProblem(_pyconvert  f, _pyconvert(u0), _pyconvert(tspan), pyconvert.(Any, args)...)
end

# Test 2 passes

# Test 2 continued

pyexec("""
import matplotlib.pyplot as plt

plt.plot(sol.t, de.transpose(de.stack(sol.u))) # :( fails without the conversion
plt.show()
""", @__MODULE__)

# Test 3

@pyexec """
jul_f = Main.seval(""\"
function f(du,u,p,t)
    x, y, z = u
    sigma, rho, beta = p
    du[1] = sigma * (y - x)
    du[2] = x * (rho - z) - y
    du[3] = x * y - beta * z
end""\")
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,2.66]
prob = de.ODEProblem(jul_f, u0, tspan, p)
sol = de.solve(prob)
"""
@test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution

# Test 3 failure 1: "ERROR: Python: Julia: MethodError: no method matching oneunit(::Type{Any})"

# Patch 4 (replaces patch 3)

using PythonCall: pyisinstance, Py, PyList, pybuiltins, pyconvert

_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x

SciMLBase.prepare_u0(u0::Union{Py, PyList}) = _pyconvert(u0)
SciMLBase.prepare_f(f::Py) = _pyconvert  f

# upstreamed
@eval SciMLBase begin
prepare_u0(u0) = u0
prepare_f(f) = f

function ODEProblem(f::AbstractODEFunction, u0, tspan, args...; kwargs...)
    ODEProblem{isinplace(f)}(prepare_f(f), prepare_u0(u0), tspan, args...; kwargs...)
end

function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
    _f = prepare_f(f)
    iip = isinplace(_f, 4)
    _u0 = prepare_u0(u0)
    _tspan = promote_tspan(tspan)
    __f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(_f)
    ODEProblem{isinplace(__f)}(__f, _u0, _tspan, p; kwargs...)
end

end

# Test 3 passes

# Test 4

pyexec("""
def f(u,p,t):
  return 1.01*u

def g(u,p,t):
  return 0.87*u

u0 = 0.5
tspan = (0.0,1.0)
prob = de.SDEProblem(f,g,u0,tspan)
sol = de.solve(prob,reltol=1e-3,abstol=1e-3)
""", @__MODULE__)

# Test 4 failure 1: "ERROR: Python: TypeError: 'float' object is not iterable"

# Patch 5 (upstreamed)

@eval SciMLBase begin
    function SDEProblem(f::AbstractSDEFunction, g, u0, tspan, p = NullParameters(); kwargs...)
        SDEProblem{isinplace(f)}(prepare_f(f), prepare_f(g), prepare_u0(u0), tspan, p; kwargs...)
    end

    function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...)
        _g = prepare_f(g)
        SDEProblem(SDEFunction(prepare_f(f), _g), _g, prepare_u0(u0), tspan, p; kwargs...)
    end
end

# Test 4 Pass.

# Patch Summary

# SciMLBase
numargs(f::ComposedFunction) = numargs(f.inner) # https://github.com/SciML/SciMLBase.jl/pull/506

"""
    prepare_initial_state(u0) = u0

Whenever an initial state is passed to the SciML ecosystem, is passed to
`prepare_initial_state` and the result is used instead. If you define a
type which cannot be used as a state but can be converted to something that
can be, then you may define `prepare_initial_state(x::YourType) = ...`.

!!! warning
    This function is experimental and may be removed in the future.

See also: `prepare_function`.
"""
prepare_initial_state(u0) = u0

"""
    prepare_function(f) = f

Whenever a function is passed to the SciML ecosystem, is passed to
`prepare_function` and the result is used instead. If you define a type which
cannot be used as a function in the SciML ecosystem but can be converted to
something that can be, then you may define `prepare_function(x::YourType) = ...`.

!!! warning
    This function is experimental and may be removed in the future.

See also: `prepare_initial_state`.
"""
prepare_function(f) = f

# begin approx

function ODEProblem(f::AbstractODEFunction, u0, tspan, args...; kwargs...)
    ODEProblem{isinplace(f)}(f, prepare_initial_state(u0), tspan, args...; kwargs...)
end

function ODEFunction(f; kwargs...)
    _f = prepare_function(f)
    ODEFunction{isinplace(_f, 4), FullSpecialize}(_f; kwargs...)
end

function SDEProblem(f::AbstractSDEFunction, g, u0, tspan, p = NullParameters(); kwargs...)
    SDEProblem{isinplace(f)}(f, g, prepare_initial_state(u0), tspan, p; kwargs...)
end

function SDEProblem(f, g, u0, tspan, p = NullParameters(); kwargs...)
    _f = prepare_function(f)
    _g = prepare_function(g)
    SDEProblem(SDEFunction(_f, _g), _g, u0, tspan, p; kwargs...)
end

...

# end approx

# SciMLBase / PythonCall extension

using PythonCall: Py, PyList, pyimport, hasproperty, pyconvert, pyisinstance, pybuiltins
using SciMLBase: SciMLBase

# SciML uses a function's arity (number of arguments) to determine if it operates in place.
# PythonCall does not preserve arity, so we inspect Python functions to find their arity.
function SciMLBase.numargs(f::Py)
    inspect = pyimport("inspect")
    f2 = hasproperty(f, :py_func) ? f.py_func : f
    # if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes
    # `self` in the `args` list. So, we subtract 1 in that case:
    pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2))
end

_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x)
_pyconvert(x::PyList) = [_pyconvert(x) for x in x]
_pyconvert(x) = x

SciMLBase.prepare_initial_state(u0::Union{Py, PyList}) = _pyconvert(u0)
SciMLBase.prepare_function(f::Py) = _pyconvert  f

TODO: add prepare_u0 to all problems (currently just ODE and SDE)
TODO: add tests
TODO: add package extension boilerplate (e.g. update Project.toml)
@LilithHafner LilithHafner force-pushed the lh/PythonCall-extension branch from ff257cf to be6174a Compare October 5, 2023 20:32
@LilithHafner
Copy link
Member Author

Force push was a clean rebase onto master

@LilithHafner LilithHafner marked this pull request as draft October 5, 2023 20:36
@LilithHafner
Copy link
Member Author

However, that is done at a later time because those conversions can be dependent on the solver that is chosen

I agree that delaying these conversions is, unfortunately, probably not a great idea. For example, I think it would be reasonable when choosing a solver to perform a query on u0 or f that would fail if we hadn't converted from Python yet (e.g. any(isnan, u0) fails on Python floats)

Also, looking at DiffEqBase/src/solve.jl, it seems that it would still be a bit messy to extract u0 and all user functions and convert them.

I'll proceed with adding these conversions to all entrypoints I can find.

@ChrisRackauckas
Copy link
Member

Makes sense

@LilithHafner LilithHafner marked this pull request as ready for review October 6, 2023 17:37
@LilithHafner
Copy link
Member Author

CodeCov claims this has pretty high patch coverage, but that is sort of a lie. In theory, this PR enables full usage of all of DifferentialEquations via PythonCall. To test that claim would require rewriting all downstream tests in Python. That's probably not worth doing, but I want to be clear that if I failed to insert a call to convert_initial_state somewhere, or someone else removes some of those calls later, that will not be caught by CI. I can add more tests if you think they are necessary.

@ChrisRackauckas
Copy link
Member

@ChrisRackauckas ChrisRackauckas merged commit a8b39e9 into SciML:master Oct 7, 2023
61 of 70 checks passed
@ErikQQY
Copy link
Member

ErikQQY commented Oct 7, 2023

The failing tests are about TwoPointBVPFunction remake, we still need a dispatch for remake, maybe continue #517?

@ChrisRackauckas
Copy link
Member

That PR is stale though since the function form was updated and the bc parts were already removed from the problem, but yes something like that PR but for the updated function form is required.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants