From be6174a146999ca9b3f2285e9f1b1550dd7f0800 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Sat, 23 Sep 2023 19:17:19 -0500 Subject: [PATCH 1/7] Add initial draft TODO: add prepare_u0 to all problems (currently just ODE and SDE) TODO: add tests TODO: add package extension boilerplate (e.g. update Project.toml) --- ext/PythonCallExt.jl | 23 +++++++ src/problems/ode_problems.jl | 18 +++--- src/problems/sde_problems.jl | 5 +- src/scimlfunctions.jl | 117 ++++++++++++++++++++++------------- src/utils.jl | 33 ++++++++++ 5 files changed, 143 insertions(+), 53 deletions(-) create mode 100644 ext/PythonCallExt.jl diff --git a/ext/PythonCallExt.jl b/ext/PythonCallExt.jl new file mode 100644 index 000000000..e69096017 --- /dev/null +++ b/ext/PythonCallExt.jl @@ -0,0 +1,23 @@ +module PythonCallExt + +using PythonCall: Py, PyList, pyimport, hasproperty, pyconvert, pyisinstance, pybuiltins +using SciMLBase: SciMLBase + +# SciML uses a function's arity (number of arguments) to determine if it operates in place. +# PythonCall does not preserve arity, so we inspect Python functions to find their arity. +function SciMLBase.numargs(f::Py) + inspect = pyimport("inspect") + f2 = hasproperty(f, :py_func) ? f.py_func : f + # if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes + # `self` in the `args` list. So, we subtract 1 in that case: + pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2)) +end + +_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x) +_pyconvert(x::PyList) = [_pyconvert(x) for x in x] +_pyconvert(x) = x + +SciMLBase.prepare_initial_state(u0::Union{Py, PyList}) = _pyconvert(u0) +SciMLBase.prepare_function(f::Py) = _pyconvert ∘ f + +end diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index ce4425db3..75355d983 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -112,13 +112,14 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: u0, tspan, p = NullParameters(), problem_type = StandardODEProblem(); kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), + new{typeof(_u0), typeof(_tspan), isinplace(f), typeof(p), typeof(f), typeof(kwargs), typeof(problem_type)}(f, - u0, + _u0, _tspan, p, kwargs, @@ -133,9 +134,10 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: This is determined automatically, but not inferred. """ function ODEProblem{iip}(f, u0, tspan, p = NullParameters(); kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) - ODEProblem(_f, u0, _tspan, p; kwargs...) + ODEProblem(_f, _u0, _tspan, p; kwargs...) end @add_kwonly function ODEProblem{iip, recompile}(f, u0, tspan, p = NullParameters(); @@ -145,19 +147,20 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: function ODEProblem{iip, FunctionWrapperSpecialize}(f, u0, tspan, p = NullParameters(); kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) if !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper) if iip ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f, - (u0, u0, p, + (_u0, _u0, p, _tspan[1]))) else ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f, - (u0, p, + (_u0, p, _tspan[1]))) end end - ODEProblem{iip}(ff, u0, _tspan, p; kwargs...) + ODEProblem{iip}(ff, _u0, _tspan, p; kwargs...) end end TruncatedStacktraces.@truncate_stacktrace ODEProblem 3 1 2 @@ -183,9 +186,10 @@ end function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...) iip = isinplace(f, 4) + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) - ODEProblem(_f, u0, _tspan, p; kwargs...) + ODEProblem(_f, _u0, _tspan, p; kwargs...) end """ diff --git a/src/problems/sde_problems.jl b/src/problems/sde_problems.jl index b3840db48..445a9e8d2 100644 --- a/src/problems/sde_problems.jl +++ b/src/problems/sde_problems.jl @@ -99,13 +99,14 @@ struct SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND} <: noise_rate_prototype = nothing, noise = nothing, seed = UInt64(0), kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), + new{typeof(_u0), typeof(_tspan), isinplace(f), typeof(p), typeof(noise), typeof(f), typeof(f.g), typeof(kwargs), - typeof(noise_rate_prototype)}(f, f.g, u0, _tspan, p, + typeof(noise_rate_prototype)}(f, f.g, _u0, _tspan, p, noise, kwargs, noise_rate_prototype, seed) end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index c74808df2..a7b244217 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2515,6 +2515,8 @@ function ODEFunction{iip, specialize}(f; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + if specialize === NoSpecialize ODEFunction{iip, specialize, Any, Any, Any, Any, @@ -2522,31 +2524,31 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, typeof(syms), typeof(indepsym), typeof(paramsyms), Any, typeof(_colorvec), - typeof(sys)}(f, mass_matrix, analytic, tgrad, jac, + typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, - typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), - typeof(sys)}(f, mass_matrix, analytic, tgrad, jac, + typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys) else ODEFunction{iip, specialize, - typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), - typeof(sys)}(f, mass_matrix, analytic, tgrad, jac, + typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys) @@ -2875,15 +2877,16 @@ function DiscreteFunction{iip, specialize}(f; sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, } + _f = prepare_function(f) if specialize === NoSpecialize - DiscreteFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any}(f, analytic, + DiscreteFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any}(_f, analytic, syms, indepsym, parasmsyms, observed, sys) else - DiscreteFunction{iip, specialize, typeof(f), typeof(analytic), + DiscreteFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(syms), typeof(indepsym), typeof(paramsyms), - typeof(observed), typeof(sys)}(f, analytic, syms, indepsym, + typeof(observed), typeof(sys)}(_f, analytic, syms, indepsym, paramsyms, observed, sys) end end @@ -2931,8 +2934,9 @@ function ImplicitDiscreteFunction{iip, specialize}(f; iip, specialize, } + _f = prepare_function(f) if specialize === NoSpecialize - ImplicitDiscreteFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any}(f, + ImplicitDiscreteFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any}(_f, analytic, syms, indepsym, @@ -2940,9 +2944,9 @@ function ImplicitDiscreteFunction{iip, specialize}(f; observed, sys) else - ImplicitDiscreteFunction{iip, specialize, typeof(f), typeof(analytic), + ImplicitDiscreteFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(syms), typeof(indepsym), typeof(paramsyms), - typeof(observed), typeof(sys)}(f, analytic, syms, indepsym, + typeof(observed), typeof(sys)}(_f, analytic, syms, indepsym, paramsyms, observed, sys) end end @@ -3034,24 +3038,26 @@ function SDEFunction{iip, specialize}(f, g; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + _g = prepare_function(g) if specialize === NoSpecialize SDEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), Any, - typeof(_colorvec), typeof(sys)}(f, g, mass_matrix, analytic, + typeof(_colorvec), typeof(sys)}(_f, _g, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, ggprime, syms, indepsym, paramsyms, observed, _colorvec, sys) else - SDEFunction{iip, specialize, typeof(f), typeof(g), + SDEFunction{iip, specialize, typeof(_f), typeof(_g), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(ggprime), typeof(syms), typeof(indepsym), typeof(paramsyms), - typeof(observed), typeof(_colorvec), typeof(sys)}(f, g, mass_matrix, + typeof(observed), typeof(_colorvec), typeof(sys)}(_f, _g, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, @@ -3361,11 +3367,13 @@ function RODEFunction{iip, specialize}(f; end =# + _f = prepare_function(f) + if specialize === NoSpecialize RODEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), Any, - typeof(_colorvec), Any}(f, mass_matrix, analytic, + typeof(_colorvec), Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, @@ -3375,13 +3383,13 @@ function RODEFunction{iip, specialize}(f; _colorvec, sys, analytic_full) else - RODEFunction{iip, specialize, typeof(f), typeof(mass_matrix), + RODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), - typeof(sys)}(f, mass_matrix, analytic, tgrad, + typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys, analytic_full) @@ -3447,23 +3455,25 @@ function DAEFunction{iip, specialize}(f; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + if specialize === NoSpecialize DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), - Any, typeof(_colorvec), Any}(f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys) else - DAEFunction{iip, specialize, typeof(f), typeof(analytic), typeof(tgrad), + DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), - typeof(sys)}(f, analytic, tgrad, jac, jvp, vjp, + typeof(sys)}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys) @@ -3534,11 +3544,13 @@ function DDEFunction{iip, specialize}(f; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + if specialize === NoSpecialize DDEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), - Any, typeof(_colorvec), Any}(f, mass_matrix, + Any, typeof(_colorvec), Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, @@ -3550,13 +3562,13 @@ function DDEFunction{iip, specialize}(f; observed, _colorvec, sys) else - DDEFunction{iip, specialize, typeof(f), typeof(mass_matrix), typeof(analytic), + DDEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), - typeof(_colorvec), typeof(sys)}(f, mass_matrix, analytic, + typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, @@ -3706,11 +3718,14 @@ function SDDEFunction{iip, specialize}(f, g; _colorvec = colorvec end + _f = prepare_function(f) + _g = prepare_function(g) + if specialize === NoSpecialize SDDEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), - Any, typeof(_colorvec), Any}(f, g, mass_matrix, + Any, typeof(_colorvec), Any}(_f, _g, mass_matrix, analytic, tgrad, jac, jvp, @@ -3724,13 +3739,13 @@ function SDDEFunction{iip, specialize}(f, g; _colorvec, sys) else - SDDEFunction{iip, specialize, typeof(f), typeof(g), + SDDEFunction{iip, specialize, typeof(_f), typeof(_g), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(ggprime), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), - typeof(_colorvec), typeof(sys)}(f, g, mass_matrix, + typeof(_colorvec), typeof(sys)}(_f, _g, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, @@ -3810,12 +3825,14 @@ function NonlinearFunction{iip, specialize}(f; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + if specialize === NoSpecialize NonlinearFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(paramsyms), Any, - typeof(_colorvec), Any, Any}(f, mass_matrix, + typeof(_colorvec), Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, @@ -3825,13 +3842,13 @@ function NonlinearFunction{iip, specialize}(f; _colorvec, sys, resid_prototype) else NonlinearFunction{iip, specialize, - typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(paramsyms), typeof(observed), - typeof(_colorvec), typeof(sys), typeof(resid_prototype)}(f, mass_matrix, + typeof(_colorvec), typeof(sys), typeof(resid_prototype)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, @@ -3865,10 +3882,12 @@ function IntervalNonlinearFunction{iip, specialize}(f; iip, specialize, } + _f = prepare_function(f) + if specialize === NoSpecialize IntervalNonlinearFunction{iip, specialize, Any, Any, typeof(syms), typeof(paramsyms), Any, - typeof(_colorvec), Any}(f, mass_matrix, + typeof(_colorvec), Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, @@ -3878,10 +3897,10 @@ function IntervalNonlinearFunction{iip, specialize}(f; _colorvec, sys) else IntervalNonlinearFunction{iip, specialize, - typeof(f), typeof(analytic), typeof(syms), + typeof(_f), typeof(analytic), typeof(syms), typeof(paramsyms), typeof(observed), - typeof(sys)}(f, analytic, syms, + typeof(sys)}(_f, analytic, syms, paramsyms, observed, sys) end @@ -3922,8 +3941,9 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); lag_hess_colorvec = nothing, expr = nothing, cons_expr = nothing, sys = __has_sys(f) ? f.sys : nothing) where {iip} - isinplace(f, 2; has_two_dispatches = false, isoptimization = true) - OptimizationFunction{iip, typeof(adtype), typeof(f), typeof(grad), typeof(hess), + _f = prepare_function(f) + isinplace(_f, 2; has_two_dispatches = false, isoptimization = true) + OptimizationFunction{iip, typeof(adtype), typeof(_f), typeof(grad), typeof(hess), typeof(hv), typeof(cons), typeof(cons_j), typeof(cons_h), typeof(lag_h), typeof(hess_prototype), @@ -3933,7 +3953,7 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); typeof(hess_colorvec), typeof(cons_jac_colorvec), typeof(cons_hess_colorvec), typeof(lag_hess_colorvec), typeof(expr), typeof(cons_expr), - typeof(sys)}(f, adtype, grad, hess, + typeof(sys)}(_f, adtype, grad, hess, hv, cons, cons_j, cons_h, lag_h, hess_prototype, cons_jac_prototype, cons_hess_prototype, lag_hess_prototype, syms, @@ -4039,23 +4059,25 @@ function BVPFunction{iip, specialize, twopoint}(f, bc; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + if specialize === NoSpecialize BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), - Any, typeof(_colorvec), typeof(_bccolorvec), Any}(f, bc, mass_matrix, + Any, typeof(_colorvec), typeof(_bccolorvec), Any}(_f, bc, mass_matrix, analytic, tgrad, jac, bcjac, jvp, vjp, jac_prototype, bcjac_prototype, bcresid_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, _bccolorvec, sys) else - BVPFunction{iip, specialize, twopoint, typeof(f), typeof(bc), typeof(mass_matrix), + BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), - typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(f, bc, mass_matrix, analytic, + typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(_f, bc, mass_matrix, analytic, tgrad, jac, bcjac, jvp, vjp, jac_prototype, bcjac_prototype, bcresid_prototype, sparsity, Wfact, Wfact_t, paramjac, @@ -4074,12 +4096,13 @@ end BVPFunction(f::BVPFunction; kwargs...) = f function IntegralFunction{iip, specialize}(f, integrand_prototype) where {iip, specialize} - IntegralFunction{iip, specialize, typeof(f), typeof(integrand_prototype)}(f, + _f = prepare_function(f) + IntegralFunction{iip, specialize, typeof(_f), typeof(integrand_prototype)}(_f, integrand_prototype) end function IntegralFunction{iip}(f, integrand_prototype) where {iip} - return IntegralFunction{iip, FullSpecialize}(f, integrand_prototype) + IntegralFunction{iip, FullSpecialize}(f, integrand_prototype) end function IntegralFunction(f) calculated_iip = isinplace(f, 3, "integral", true) @@ -4098,12 +4121,13 @@ end function BatchIntegralFunction{iip, specialize}(f, integrand_prototype; max_batch::Integer = typemax(Int)) where {iip, specialize} + _f = prepare_function(f) BatchIntegralFunction{ iip, specialize, - typeof(f), + typeof(_f), typeof(integrand_prototype), - }(f, + }(_f, integrand_prototype, max_batch) end @@ -4221,11 +4245,16 @@ struct IncrementingODEFunction{iip, specialize, F} <: AbstractODEFunction{iip} f::F end +function IncrementingODEFunction{iip, specialize}(f) where {iip, specialize} + _f = prepare_function(f) + IncrementingODEFunction{iip, specialize, typeof(_f)}(_f) +end + function IncrementingODEFunction{iip}(f) where {iip} - IncrementingODEFunction{iip, FullSpecialize, typeof(f)}(f) + IncrementingODEFunction{iip, FullSpecialize}(f) end function IncrementingODEFunction(f) - IncrementingODEFunction{isinplace(f, 7), FullSpecialize, typeof(f)}(f) + IncrementingODEFunction{isinplace(f, 7), FullSpecialize}(f) end (f::IncrementingODEFunction)(args...; kwargs...) = f.f(args...; kwargs...) diff --git a/src/utils.jl b/src/utils.jl index 7f5f4b824..b1490f993 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -495,3 +495,36 @@ end _unwrap_val(::Val{B}) where {B} = B _unwrap_val(B) = B + +""" + prepare_initial_state(u0) = u0 + +Whenever an initial state is passed to the SciML ecosystem, is passed to +`prepare_initial_state` and the result is used instead. If you define a +type which cannot be used as a state but can be converted to something that +can be, then you may define `prepare_initial_state(x::YourType) = ...`. + +!!! warning + This function is experimental and may be removed in the future. + +See also: `prepare_function`. +""" +prepare_initial_state(u0) = u0 + +""" + prepare_function(f) = f + +Whenever a function is passed to the SciML ecosystem, is passed to +`prepare_function` and the result is used instead. If you define a type which +cannot be used as a function in the SciML ecosystem but can be converted to +something that can be, then you may define `prepare_function(x::YourType) = ...`. + +`prepare_function` may be called before or after +the arity of a function is computed with `numargs` + +!!! warning + This function is experimental and may be removed in the future. + +See also: `prepare_initial_state`. +""" +prepare_function(f) = f From a2fdd5e9c538e4f1e10e6327094225925392f40d Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Thu, 5 Oct 2023 16:49:36 -0500 Subject: [PATCH 2/7] add tests and extension boilerplate --- .gitignore | 3 ++ Project.toml | 5 ++- test/python/Project.toml | 4 +++ test/python/pythoncall.jl | 69 +++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 3 ++ 5 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 test/python/pythoncall.jl diff --git a/.gitignore b/.gitignore index e557c1593..9ae3c857d 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ Manifest.toml # vscode stuff .vscode .vscode/* + +# python extensions +.CondaPkg diff --git a/Project.toml b/Project.toml index 5a5d89a8e..1c74bdb16 100644 --- a/Project.toml +++ b/Project.toml @@ -34,10 +34,12 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] PyCallExt = "PyCall" +PythonCallExt = "PythonCall" ZygoteExt = "Zygote" [compat] @@ -72,6 +74,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" @@ -79,4 +82,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Pkg", "PyCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote"] +test = ["Pkg", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote"] diff --git a/test/python/Project.toml b/test/python/Project.toml index 8ec377a5f..121f41471 100644 --- a/test/python/Project.toml +++ b/test/python/Project.toml @@ -1,9 +1,13 @@ [deps] +DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" [compat] +DifferentialEquations = "7.11" OrdinaryDiffEq = "6.33" PyCall = "1.96" +PythonCall = "0.9.14" SciMLBase = "2" diff --git a/test/python/pythoncall.jl b/test/python/pythoncall.jl new file mode 100644 index 000000000..5136ffda2 --- /dev/null +++ b/test/python/pythoncall.jl @@ -0,0 +1,69 @@ +using DifferentialEquations, PythonCall + +@testset "Use of DifferentialEquations through PythonCall with user code written in Python" begin + pyexec(""" + from juliacall import Main + de = Main.seval("DifferentialEquations") + + def f(u,p,t): + return -u + + u0 = 0.5 + tspan = (0., 1.) + prob = de.ODEProblem(f, u0, tspan) + sol = de.solve(prob) + """, @__MODULE__) + @test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution + + pyexec(""" + def f(u,p,t): + x, y, z = u + sigma, rho, beta = p + return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z] + + u0 = [1.0,0.0,0.0] + tspan = (0., 100.) + p = [10.0,28.0,8/3] + prob = de.ODEProblem(f, u0, tspan, p) + sol = de.solve(prob,saveat=0.01) + """, @__MODULE__) + @test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution + + # TODO: test the types and shapes of sol.t and de.transpose(de.stack(sol.u)) but don't actually plot them in CI + # pyexec(""" + # import matplotlib.pyplot as plt + + # plt.plot(sol.t, de.transpose(de.stack(sol.u))) # :( fails without the conversion + # plt.show() + # """, @__MODULE__) + + @pyexec """ + jul_f = Main.seval(""\" + function f(du,u,p,t) + x, y, z = u + sigma, rho, beta = p + du[1] = sigma * (y - x) + du[2] = x * (rho - z) - y + du[3] = x * y - beta * z + end""\") + u0 = [1.0,0.0,0.0] + tspan = (0., 100.) + p = [10.0,28.0,2.66] + prob = de.ODEProblem(jul_f, u0, tspan, p) + sol = de.solve(prob) + """ + @test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution + + pyexec(""" + def f(u,p,t): + return 1.01*u + + def g(u,p,t): + return 0.87*u + + u0 = 0.5 + tspan = (0.0,1.0) + prob = de.SDEProblem(f,g,u0,tspan) + sol = de.solve(prob,reltol=1e-3,abstol=1e-3) + """, @__MODULE__) +end diff --git a/test/runtests.jl b/test/runtests.jl index c49d74480..464b596e4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -105,5 +105,8 @@ end @time @safetestset "PyCall" begin include("python/pycall.jl") end + @time @safetestset "PythonCall" begin + include("python/pythoncall.jl") + end end end From 16efdb21ebf047ddd2bb0a2f74808fc1e031e759 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Fri, 6 Oct 2023 08:46:32 -0500 Subject: [PATCH 3/7] load DifferentialEquations on the Python side --- test/python/pythoncall.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/python/pythoncall.jl b/test/python/pythoncall.jl index 5136ffda2..3ac009159 100644 --- a/test/python/pythoncall.jl +++ b/test/python/pythoncall.jl @@ -3,6 +3,7 @@ using DifferentialEquations, PythonCall @testset "Use of DifferentialEquations through PythonCall with user code written in Python" begin pyexec(""" from juliacall import Main + Main.seval("using DifferentialEquations") de = Main.seval("DifferentialEquations") def f(u,p,t): From 180c3ae90a18c46cf274ff8fbbd45cb1b440d638 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Fri, 6 Oct 2023 08:49:03 -0500 Subject: [PATCH 4/7] comment out PyCall tests to avoid spooky Python action at a distance breaking things --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 464b596e4..dfaafa950 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -103,7 +103,7 @@ end if !is_APPVEYOR && GROUP == "Python" activate_python_env() @time @safetestset "PyCall" begin - include("python/pycall.jl") + # include("python/pycall.jl") end @time @safetestset "PythonCall" begin include("python/pythoncall.jl") From a9ba6aee9b99fa9ee895675de5fd69d4f513f154 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Fri, 6 Oct 2023 09:18:26 -0500 Subject: [PATCH 5/7] configure PythonCall to use the same interpreter as PyCall to avoid the spooky action at a distance --- test/python/pythoncall.jl | 5 +++++ test/runtests.jl | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/python/pythoncall.jl b/test/python/pythoncall.jl index 3ac009159..a8bd9521f 100644 --- a/test/python/pythoncall.jl +++ b/test/python/pythoncall.jl @@ -1,3 +1,8 @@ +# PyCall and PythonCall must use the same Python interpreter. This environment variable +# tells PythonCall to use the same Python interpreter as PyCall. See +# https://github.com/JuliaPy/PythonCall.jl/blob/5f56a9b96b867a9f6742ab1d1e2361abd844e19f/docs/src/pycall.md#tips +ENV["JULIA_PYTHONCALL_EXE"]="@PyCall" + using DifferentialEquations, PythonCall @testset "Use of DifferentialEquations through PythonCall with user code written in Python" begin diff --git a/test/runtests.jl b/test/runtests.jl index dfaafa950..464b596e4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -103,7 +103,7 @@ end if !is_APPVEYOR && GROUP == "Python" activate_python_env() @time @safetestset "PyCall" begin - # include("python/pycall.jl") + include("python/pycall.jl") end @time @safetestset "PythonCall" begin include("python/pythoncall.jl") From 943c781b927e8e85d1189f801a820c89b24d3c77 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Fri, 6 Oct 2023 09:35:34 -0500 Subject: [PATCH 6/7] convert u0 everywhere that promote_tspan happens --- src/problems/analytical_problems.jl | 5 +++-- src/problems/bvp_problems.jl | 5 +++-- src/problems/dae_problems.jl | 12 +++++++----- src/problems/dde_problems.jl | 6 ++++-- src/problems/discrete_problems.jl | 5 +++-- src/problems/implicit_discrete_problems.jl | 5 +++-- src/problems/rode_problems.jl | 5 +++-- src/problems/sdde_problems.jl | 5 +++-- src/problems/steady_state_problems.jl | 3 ++- 9 files changed, 31 insertions(+), 20 deletions(-) diff --git a/src/problems/analytical_problems.jl b/src/problems/analytical_problems.jl index 7dabf9807..d6b744885 100644 --- a/src/problems/analytical_problems.jl +++ b/src/problems/analytical_problems.jl @@ -10,11 +10,12 @@ struct AnalyticalProblem{uType, tType, isinplace, P, F, K} <: kwargs::K @add_kwonly function AnalyticalProblem{iip}(f, u0, tspan, p = NullParameters(); kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), iip, typeof(p), + new{typeof(_u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(kwargs)}(f, - u0, + _u0, _tspan, p, kwargs) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 2e9f1feab..3baf71215 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -110,6 +110,7 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <: @add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, bc, u0, tspan, p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) prob_type = TP ? TwoPointBVProblem() : StandardBVProblem() @@ -119,8 +120,8 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <: else @assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use." end - return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc), - typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p, problem_type, + return new{typeof(_u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc), + typeof(problem_type), typeof(kwargs)}(f, bc, _u0, _tspan, p, problem_type, kwargs) end diff --git a/src/problems/dae_problems.jl b/src/problems/dae_problems.jl index 5d729cd8d..e5e532aac 100644 --- a/src/problems/dae_problems.jl +++ b/src/problems/dae_problems.jl @@ -80,21 +80,23 @@ struct DAEProblem{uType, duType, tType, isinplace, P, F, K, D} <: du0, u0, tspan, p = NullParameters(); differential_vars = nothing, kwargs...) where {iip} - if !isnothing(u0) + _u0 = prepare_initial_state(u0) + _du0 = prepare_initial_state(du0) + if !isnothing(_u0) # Defend against external solvers like Sundials breaking on non-uniform input dimensions. - size(du0) == size(u0) || + size(_du0) == size(_u0) || throw(ArgumentError("Sizes of u0 and du0 must be the same.")) if !isnothing(differential_vars) - size(u0) == size(differential_vars) || + size(_u0) == size(differential_vars) || throw(ArgumentError("Sizes of u0 and differential_vars must be the same.")) end end _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(du0), typeof(_tspan), + new{typeof(_u0), typeof(_du0), typeof(_tspan), isinplace(f), typeof(p), typeof(f), typeof(kwargs), - typeof(differential_vars)}(f, du0, u0, _tspan, p, + typeof(differential_vars)}(f, _du0, _u0, _tspan, p, kwargs, differential_vars) end diff --git a/src/problems/dde_problems.jl b/src/problems/dde_problems.jl index 3bf2df8f2..98dfc4a0c 100644 --- a/src/problems/dde_problems.jl +++ b/src/problems/dde_problems.jl @@ -224,11 +224,13 @@ struct DDEProblem{uType, tType, lType, lType2, isinplace, P, F, H, K, PT} <: order_discontinuity_t0 = 0, problem_type = StandardDDEProblem(), kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), typeof(constant_lags), typeof(dependent_lags), + new{typeof(_u0), typeof(_tspan), typeof(constant_lags), typeof(dependent_lags), isinplace(f), - typeof(p), typeof(f), typeof(h), typeof(kwargs), typeof(problem_type)}(f, u0, h, + typeof(p), typeof(f), typeof(h), typeof(kwargs), typeof(problem_type)}(f, _u0, + h, _tspan, p, constant_lags, diff --git a/src/problems/discrete_problems.jl b/src/problems/discrete_problems.jl index 9d6f643e9..b2f2a362c 100644 --- a/src/problems/discrete_problems.jl +++ b/src/problems/discrete_problems.jl @@ -90,12 +90,13 @@ struct DiscreteProblem{uType, tType, isinplace, P, F, K} <: @add_kwonly function DiscreteProblem{iip}(f::AbstractDiscreteFunction{iip}, u0, tspan::Tuple, p = NullParameters(); kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), isinplace(f, 4), + new{typeof(_u0), typeof(_tspan), isinplace(f, 4), typeof(p), typeof(f), typeof(kwargs)}(f, - u0, + _u0, _tspan, p, kwargs) diff --git a/src/problems/implicit_discrete_problems.jl b/src/problems/implicit_discrete_problems.jl index 063a96d2e..b6463fa6b 100644 --- a/src/problems/implicit_discrete_problems.jl +++ b/src/problems/implicit_discrete_problems.jl @@ -86,12 +86,13 @@ struct ImplicitDiscreteProblem{uType, tType, isinplace, P, F, K} <: u0, tspan::Tuple, p = NullParameters(); kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), isinplace(f, 6), + new{typeof(_u0), typeof(_tspan), isinplace(f, 6), typeof(p), typeof(f), typeof(kwargs)}(f, - u0, + _u0, _tspan, p, kwargs) diff --git a/src/problems/rode_problems.jl b/src/problems/rode_problems.jl index d5c04e21b..18c7a587e 100644 --- a/src/problems/rode_problems.jl +++ b/src/problems/rode_problems.jl @@ -69,12 +69,13 @@ mutable struct RODEProblem{uType, tType, isinplace, P, NP, F, K, ND} <: rand_prototype = nothing, noise = nothing, seed = UInt64(0), kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), + new{typeof(_u0), typeof(_tspan), isinplace(f), typeof(p), typeof(noise), typeof(f), typeof(kwargs), - typeof(rand_prototype)}(f, u0, _tspan, p, noise, kwargs, + typeof(rand_prototype)}(f, _u0, _tspan, p, noise, kwargs, rand_prototype, seed) end function RODEProblem{iip}(f, u0, tspan, p = NullParameters(); kwargs...) where {iip} diff --git a/src/problems/sdde_problems.jl b/src/problems/sdde_problems.jl index 12cafe9d7..76d4b8dc4 100644 --- a/src/problems/sdde_problems.jl +++ b/src/problems/sdde_problems.jl @@ -126,12 +126,13 @@ struct SDDEProblem{uType, tType, lType, lType2, isinplace, P, NP, F, G, H, K, ND det(f.mass_matrix) != 1, order_discontinuity_t0 = 0 // 1, kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), typeof(constant_lags), typeof(dependent_lags), + new{typeof(_u0), typeof(_tspan), typeof(constant_lags), typeof(dependent_lags), isinplace(f), typeof(p), typeof(noise), typeof(f), typeof(g), typeof(h), typeof(kwargs), - typeof(noise_rate_prototype)}(f, g, u0, h, _tspan, p, noise, constant_lags, + typeof(noise_rate_prototype)}(f, g, _u0, h, _tspan, p, noise, constant_lags, dependent_lags, kwargs, noise_rate_prototype, seed, neutral, order_discontinuity_t0) end diff --git a/src/problems/steady_state_problems.jl b/src/problems/steady_state_problems.jl index 5a01410b9..ed4fd3a05 100644 --- a/src/problems/steady_state_problems.jl +++ b/src/problems/steady_state_problems.jl @@ -83,8 +83,9 @@ struct SteadyStateProblem{uType, isinplace, P, F, K} <: @add_kwonly function SteadyStateProblem{iip}(f::AbstractODEFunction{iip}, u0, p = NullParameters(); kwargs...) where {iip} + _u0 = prepare_initial_state(u0) warn_paramtype(p) - new{typeof(u0), isinplace(f), typeof(p), typeof(f), typeof(kwargs)}(f, u0, p, + new{typeof(_u0), isinplace(f), typeof(p), typeof(f), typeof(kwargs)}(f, _u0, p, kwargs) end From 44e8f43b3c61aa297ee37b967ca6d3c3b182c970 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Fri, 6 Oct 2023 10:04:17 -0500 Subject: [PATCH 7/7] implement TODO in tests --- test/python/pythoncall.jl | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/test/python/pythoncall.jl b/test/python/pythoncall.jl index a8bd9521f..3b13cac0c 100644 --- a/test/python/pythoncall.jl +++ b/test/python/pythoncall.jl @@ -35,13 +35,17 @@ using DifferentialEquations, PythonCall """, @__MODULE__) @test pyconvert(Any, pyeval("sol", @__MODULE__)) isa ODESolution - # TODO: test the types and shapes of sol.t and de.transpose(de.stack(sol.u)) but don't actually plot them in CI - # pyexec(""" - # import matplotlib.pyplot as plt - - # plt.plot(sol.t, de.transpose(de.stack(sol.u))) # :( fails without the conversion - # plt.show() - # """, @__MODULE__) + # Test that the types and shapes of sol.t and de.transpose(de.stack(sol.u)) are + # compatible with matplotlib, but don't actually plot anything. + pyexec(""" + u2 = de.transpose(de.stack(sol.u)) + ok = sol.t.shape == (10001,) and \ + u2.shape == (10001, 3) and \ + sol.t[0] == 0 and \ + sol.t[-1] == 100 and \ + type(u2[4123, 2]) == float + """, @__MODULE__) + @test pyconvert(Any, pyeval("ok", @__MODULE__)) @pyexec """ jul_f = Main.seval(""\"