From ff257cf88ba1e7c33b2a41d8a25c9cce5c80e02b Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Sat, 23 Sep 2023 19:17:19 -0500 Subject: [PATCH] Add initial draft TODO: add prepare_u0 to all problems (currently just ODE and SDE) TODO: add tests TODO: add package extension boilerplate (e.g. update Project.toml) --- ext/PythonCallExt.jl | 23 +++++++ src/problems/ode_problems.jl | 18 +++--- src/problems/sde_problems.jl | 5 +- src/scimlfunctions.jl | 117 ++++++++++++++++++++++------------- src/utils.jl | 33 ++++++++++ 5 files changed, 143 insertions(+), 53 deletions(-) create mode 100644 ext/PythonCallExt.jl diff --git a/ext/PythonCallExt.jl b/ext/PythonCallExt.jl new file mode 100644 index 0000000000..e69096017c --- /dev/null +++ b/ext/PythonCallExt.jl @@ -0,0 +1,23 @@ +module PythonCallExt + +using PythonCall: Py, PyList, pyimport, hasproperty, pyconvert, pyisinstance, pybuiltins +using SciMLBase: SciMLBase + +# SciML uses a function's arity (number of arguments) to determine if it operates in place. +# PythonCall does not preserve arity, so we inspect Python functions to find their arity. +function SciMLBase.numargs(f::Py) + inspect = pyimport("inspect") + f2 = hasproperty(f, :py_func) ? f.py_func : f + # if `f` is a bound method (i.e., `self.f`), `getfullargspec` includes + # `self` in the `args` list. So, we subtract 1 in that case: + pyconvert(Int, length(first(inspect.getfullargspec(f2))) - inspect.ismethod(f2)) +end + +_pyconvert(x::Py) = pyisinstance(x, pybuiltins.list) ? [_pyconvert(x) for x in x] : pyconvert(Any, x) +_pyconvert(x::PyList) = [_pyconvert(x) for x in x] +_pyconvert(x) = x + +SciMLBase.prepare_initial_state(u0::Union{Py, PyList}) = _pyconvert(u0) +SciMLBase.prepare_function(f::Py) = _pyconvert ∘ f + +end diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index ce4425db3e..75355d983f 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -112,13 +112,14 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: u0, tspan, p = NullParameters(), problem_type = StandardODEProblem(); kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), + new{typeof(_u0), typeof(_tspan), isinplace(f), typeof(p), typeof(f), typeof(kwargs), typeof(problem_type)}(f, - u0, + _u0, _tspan, p, kwargs, @@ -133,9 +134,10 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: This is determined automatically, but not inferred. """ function ODEProblem{iip}(f, u0, tspan, p = NullParameters(); kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) - ODEProblem(_f, u0, _tspan, p; kwargs...) + ODEProblem(_f, _u0, _tspan, p; kwargs...) end @add_kwonly function ODEProblem{iip, recompile}(f, u0, tspan, p = NullParameters(); @@ -145,19 +147,20 @@ mutable struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: function ODEProblem{iip, FunctionWrapperSpecialize}(f, u0, tspan, p = NullParameters(); kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) if !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper) if iip ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f, - (u0, u0, p, + (_u0, _u0, p, _tspan[1]))) else ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f, - (u0, p, + (_u0, p, _tspan[1]))) end end - ODEProblem{iip}(ff, u0, _tspan, p; kwargs...) + ODEProblem{iip}(ff, _u0, _tspan, p; kwargs...) end end TruncatedStacktraces.@truncate_stacktrace ODEProblem 3 1 2 @@ -183,9 +186,10 @@ end function ODEProblem(f, u0, tspan, p = NullParameters(); kwargs...) iip = isinplace(f, 4) + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) - ODEProblem(_f, u0, _tspan, p; kwargs...) + ODEProblem(_f, _u0, _tspan, p; kwargs...) end """ diff --git a/src/problems/sde_problems.jl b/src/problems/sde_problems.jl index b3840db48c..445a9e8d2d 100644 --- a/src/problems/sde_problems.jl +++ b/src/problems/sde_problems.jl @@ -99,13 +99,14 @@ struct SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND} <: noise_rate_prototype = nothing, noise = nothing, seed = UInt64(0), kwargs...) where {iip} + _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) - new{typeof(u0), typeof(_tspan), + new{typeof(_u0), typeof(_tspan), isinplace(f), typeof(p), typeof(noise), typeof(f), typeof(f.g), typeof(kwargs), - typeof(noise_rate_prototype)}(f, f.g, u0, _tspan, p, + typeof(noise_rate_prototype)}(f, f.g, _u0, _tspan, p, noise, kwargs, noise_rate_prototype, seed) end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index c74808df26..a7b2442176 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2515,6 +2515,8 @@ function ODEFunction{iip, specialize}(f; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + if specialize === NoSpecialize ODEFunction{iip, specialize, Any, Any, Any, Any, @@ -2522,31 +2524,31 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, typeof(syms), typeof(indepsym), typeof(paramsyms), Any, typeof(_colorvec), - typeof(sys)}(f, mass_matrix, analytic, tgrad, jac, + typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, - typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), - typeof(sys)}(f, mass_matrix, analytic, tgrad, jac, + typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys) else ODEFunction{iip, specialize, - typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), - typeof(sys)}(f, mass_matrix, analytic, tgrad, jac, + typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys) @@ -2875,15 +2877,16 @@ function DiscreteFunction{iip, specialize}(f; sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, } + _f = prepare_function(f) if specialize === NoSpecialize - DiscreteFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any}(f, analytic, + DiscreteFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any}(_f, analytic, syms, indepsym, parasmsyms, observed, sys) else - DiscreteFunction{iip, specialize, typeof(f), typeof(analytic), + DiscreteFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(syms), typeof(indepsym), typeof(paramsyms), - typeof(observed), typeof(sys)}(f, analytic, syms, indepsym, + typeof(observed), typeof(sys)}(_f, analytic, syms, indepsym, paramsyms, observed, sys) end end @@ -2931,8 +2934,9 @@ function ImplicitDiscreteFunction{iip, specialize}(f; iip, specialize, } + _f = prepare_function(f) if specialize === NoSpecialize - ImplicitDiscreteFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any}(f, + ImplicitDiscreteFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any}(_f, analytic, syms, indepsym, @@ -2940,9 +2944,9 @@ function ImplicitDiscreteFunction{iip, specialize}(f; observed, sys) else - ImplicitDiscreteFunction{iip, specialize, typeof(f), typeof(analytic), + ImplicitDiscreteFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(syms), typeof(indepsym), typeof(paramsyms), - typeof(observed), typeof(sys)}(f, analytic, syms, indepsym, + typeof(observed), typeof(sys)}(_f, analytic, syms, indepsym, paramsyms, observed, sys) end end @@ -3034,24 +3038,26 @@ function SDEFunction{iip, specialize}(f, g; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + _g = prepare_function(g) if specialize === NoSpecialize SDEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), Any, - typeof(_colorvec), typeof(sys)}(f, g, mass_matrix, analytic, + typeof(_colorvec), typeof(sys)}(_f, _g, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, ggprime, syms, indepsym, paramsyms, observed, _colorvec, sys) else - SDEFunction{iip, specialize, typeof(f), typeof(g), + SDEFunction{iip, specialize, typeof(_f), typeof(_g), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(ggprime), typeof(syms), typeof(indepsym), typeof(paramsyms), - typeof(observed), typeof(_colorvec), typeof(sys)}(f, g, mass_matrix, + typeof(observed), typeof(_colorvec), typeof(sys)}(_f, _g, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, @@ -3361,11 +3367,13 @@ function RODEFunction{iip, specialize}(f; end =# + _f = prepare_function(f) + if specialize === NoSpecialize RODEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), Any, - typeof(_colorvec), Any}(f, mass_matrix, analytic, + typeof(_colorvec), Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, @@ -3375,13 +3383,13 @@ function RODEFunction{iip, specialize}(f; _colorvec, sys, analytic_full) else - RODEFunction{iip, specialize, typeof(f), typeof(mass_matrix), + RODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), - typeof(sys)}(f, mass_matrix, analytic, tgrad, + typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys, analytic_full) @@ -3447,23 +3455,25 @@ function DAEFunction{iip, specialize}(f; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + if specialize === NoSpecialize DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), - Any, typeof(_colorvec), Any}(f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys) else - DAEFunction{iip, specialize, typeof(f), typeof(analytic), typeof(tgrad), + DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), typeof(_colorvec), - typeof(sys)}(f, analytic, tgrad, jac, jvp, vjp, + typeof(sys)}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, sys) @@ -3534,11 +3544,13 @@ function DDEFunction{iip, specialize}(f; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + if specialize === NoSpecialize DDEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), - Any, typeof(_colorvec), Any}(f, mass_matrix, + Any, typeof(_colorvec), Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, @@ -3550,13 +3562,13 @@ function DDEFunction{iip, specialize}(f; observed, _colorvec, sys) else - DDEFunction{iip, specialize, typeof(f), typeof(mass_matrix), typeof(analytic), + DDEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), - typeof(_colorvec), typeof(sys)}(f, mass_matrix, analytic, + typeof(_colorvec), typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, @@ -3706,11 +3718,14 @@ function SDDEFunction{iip, specialize}(f, g; _colorvec = colorvec end + _f = prepare_function(f) + _g = prepare_function(g) + if specialize === NoSpecialize SDDEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), - Any, typeof(_colorvec), Any}(f, g, mass_matrix, + Any, typeof(_colorvec), Any}(_f, _g, mass_matrix, analytic, tgrad, jac, jvp, @@ -3724,13 +3739,13 @@ function SDDEFunction{iip, specialize}(f, g; _colorvec, sys) else - SDDEFunction{iip, specialize, typeof(f), typeof(g), + SDDEFunction{iip, specialize, typeof(_f), typeof(_g), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(ggprime), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), - typeof(_colorvec), typeof(sys)}(f, g, mass_matrix, + typeof(_colorvec), typeof(sys)}(_f, _g, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, @@ -3810,12 +3825,14 @@ function NonlinearFunction{iip, specialize}(f; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + if specialize === NoSpecialize NonlinearFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(paramsyms), Any, - typeof(_colorvec), Any, Any}(f, mass_matrix, + typeof(_colorvec), Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, @@ -3825,13 +3842,13 @@ function NonlinearFunction{iip, specialize}(f; _colorvec, sys, resid_prototype) else NonlinearFunction{iip, specialize, - typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), + typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(paramsyms), typeof(observed), - typeof(_colorvec), typeof(sys), typeof(resid_prototype)}(f, mass_matrix, + typeof(_colorvec), typeof(sys), typeof(resid_prototype)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, @@ -3865,10 +3882,12 @@ function IntervalNonlinearFunction{iip, specialize}(f; iip, specialize, } + _f = prepare_function(f) + if specialize === NoSpecialize IntervalNonlinearFunction{iip, specialize, Any, Any, typeof(syms), typeof(paramsyms), Any, - typeof(_colorvec), Any}(f, mass_matrix, + typeof(_colorvec), Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, @@ -3878,10 +3897,10 @@ function IntervalNonlinearFunction{iip, specialize}(f; _colorvec, sys) else IntervalNonlinearFunction{iip, specialize, - typeof(f), typeof(analytic), typeof(syms), + typeof(_f), typeof(analytic), typeof(syms), typeof(paramsyms), typeof(observed), - typeof(sys)}(f, analytic, syms, + typeof(sys)}(_f, analytic, syms, paramsyms, observed, sys) end @@ -3922,8 +3941,9 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); lag_hess_colorvec = nothing, expr = nothing, cons_expr = nothing, sys = __has_sys(f) ? f.sys : nothing) where {iip} - isinplace(f, 2; has_two_dispatches = false, isoptimization = true) - OptimizationFunction{iip, typeof(adtype), typeof(f), typeof(grad), typeof(hess), + _f = prepare_function(f) + isinplace(_f, 2; has_two_dispatches = false, isoptimization = true) + OptimizationFunction{iip, typeof(adtype), typeof(_f), typeof(grad), typeof(hess), typeof(hv), typeof(cons), typeof(cons_j), typeof(cons_h), typeof(lag_h), typeof(hess_prototype), @@ -3933,7 +3953,7 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); typeof(hess_colorvec), typeof(cons_jac_colorvec), typeof(cons_hess_colorvec), typeof(lag_hess_colorvec), typeof(expr), typeof(cons_expr), - typeof(sys)}(f, adtype, grad, hess, + typeof(sys)}(_f, adtype, grad, hess, hv, cons, cons_j, cons_h, lag_h, hess_prototype, cons_jac_prototype, cons_hess_prototype, lag_hess_prototype, syms, @@ -4039,23 +4059,25 @@ function BVPFunction{iip, specialize, twopoint}(f, bc; throw(NonconformingFunctionsError(functions)) end + _f = prepare_function(f) + if specialize === NoSpecialize BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, typeof(syms), typeof(indepsym), typeof(paramsyms), - Any, typeof(_colorvec), typeof(_bccolorvec), Any}(f, bc, mass_matrix, + Any, typeof(_colorvec), typeof(_bccolorvec), Any}(_f, bc, mass_matrix, analytic, tgrad, jac, bcjac, jvp, vjp, jac_prototype, bcjac_prototype, bcresid_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, _colorvec, _bccolorvec, sys) else - BVPFunction{iip, specialize, twopoint, typeof(f), typeof(bc), typeof(mass_matrix), + BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac), typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed), - typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(f, bc, mass_matrix, analytic, + typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(_f, bc, mass_matrix, analytic, tgrad, jac, bcjac, jvp, vjp, jac_prototype, bcjac_prototype, bcresid_prototype, sparsity, Wfact, Wfact_t, paramjac, @@ -4074,12 +4096,13 @@ end BVPFunction(f::BVPFunction; kwargs...) = f function IntegralFunction{iip, specialize}(f, integrand_prototype) where {iip, specialize} - IntegralFunction{iip, specialize, typeof(f), typeof(integrand_prototype)}(f, + _f = prepare_function(f) + IntegralFunction{iip, specialize, typeof(_f), typeof(integrand_prototype)}(_f, integrand_prototype) end function IntegralFunction{iip}(f, integrand_prototype) where {iip} - return IntegralFunction{iip, FullSpecialize}(f, integrand_prototype) + IntegralFunction{iip, FullSpecialize}(f, integrand_prototype) end function IntegralFunction(f) calculated_iip = isinplace(f, 3, "integral", true) @@ -4098,12 +4121,13 @@ end function BatchIntegralFunction{iip, specialize}(f, integrand_prototype; max_batch::Integer = typemax(Int)) where {iip, specialize} + _f = prepare_function(f) BatchIntegralFunction{ iip, specialize, - typeof(f), + typeof(_f), typeof(integrand_prototype), - }(f, + }(_f, integrand_prototype, max_batch) end @@ -4221,11 +4245,16 @@ struct IncrementingODEFunction{iip, specialize, F} <: AbstractODEFunction{iip} f::F end +function IncrementingODEFunction{iip, specialize}(f) where {iip, specialize} + _f = prepare_function(f) + IncrementingODEFunction{iip, specialize, typeof(_f)}(_f) +end + function IncrementingODEFunction{iip}(f) where {iip} - IncrementingODEFunction{iip, FullSpecialize, typeof(f)}(f) + IncrementingODEFunction{iip, FullSpecialize}(f) end function IncrementingODEFunction(f) - IncrementingODEFunction{isinplace(f, 7), FullSpecialize, typeof(f)}(f) + IncrementingODEFunction{isinplace(f, 7), FullSpecialize}(f) end (f::IncrementingODEFunction)(args...; kwargs...) = f.f(args...; kwargs...) diff --git a/src/utils.jl b/src/utils.jl index 7f5f4b8246..b1490f9938 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -495,3 +495,36 @@ end _unwrap_val(::Val{B}) where {B} = B _unwrap_val(B) = B + +""" + prepare_initial_state(u0) = u0 + +Whenever an initial state is passed to the SciML ecosystem, is passed to +`prepare_initial_state` and the result is used instead. If you define a +type which cannot be used as a state but can be converted to something that +can be, then you may define `prepare_initial_state(x::YourType) = ...`. + +!!! warning + This function is experimental and may be removed in the future. + +See also: `prepare_function`. +""" +prepare_initial_state(u0) = u0 + +""" + prepare_function(f) = f + +Whenever a function is passed to the SciML ecosystem, is passed to +`prepare_function` and the result is used instead. If you define a type which +cannot be used as a function in the SciML ecosystem but can be converted to +something that can be, then you may define `prepare_function(x::YourType) = ...`. + +`prepare_function` may be called before or after +the arity of a function is computed with `numargs` + +!!! warning + This function is experimental and may be removed in the future. + +See also: `prepare_initial_state`. +""" +prepare_function(f) = f