diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index a5e9b3c3c..b2f609c11 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -402,7 +402,7 @@ numerically-defined functions. """ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, O, TCV, - SYS} <: AbstractODEFunction{iip} + SYS, IProb, IProbMap} <: AbstractODEFunction{iip} f::F mass_matrix::TMM analytic::Ta @@ -419,6 +419,8 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW observed::O colorvec::TCV sys::SYS + initializeprob::IProb + initializeprobmap::IProbMap end TruncatedStacktraces.@truncate_stacktrace ODEFunction 1 2 @@ -1522,7 +1524,7 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV, - SYS} <: + SYS, IProb, IProbMap} <: AbstractDAEFunction{iip} f::F analytic::Ta @@ -1538,6 +1540,8 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP observed::O colorvec::TCV sys::SYS + initializeprob::IProb + initializeprobmap::IProbMap end TruncatedStacktraces.@truncate_stacktrace DAEFunction 1 2 @@ -2276,7 +2280,10 @@ function ODEFunction{iip, specialize}(f; observed = __has_observed(f) ? f.observed : DEFAULT_OBSERVED, colorvec = __has_colorvec(f) ? f.colorvec : nothing, - sys = __has_sys(f) ? f.sys : nothing) where {iip, + sys = __has_sys(f) ? f.sys : nothing, + initializeprob = __has_initializeprob(f) ? f.sys : nothing, + initializeprobmap = __has_initializeprobmap(f) ? f.sys : nothing +) where {iip, specialize } if mass_matrix === I && f isa Tuple @@ -2321,7 +2328,11 @@ function ODEFunction{iip, specialize}(f; _f = prepare_function(f) - sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) + sys = something(sys, SymbolCache(syms, paramsyms, indepsym)) + + @assert typeof(initializeprob) <: + Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} + if specialize === NoSpecialize ODEFunction{iip, specialize, Any, Any, Any, Any, @@ -2329,10 +2340,10 @@ function ODEFunction{iip, specialize}(f; typeof(sparsity), Any, Any, typeof(W_prototype), Any, Any, typeof(_colorvec), - typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys) + observed, _colorvec, sys, initializeprob, initializeprobmap) elseif specialize === false ODEFunction{iip, FunctionWrapperSpecialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2341,10 +2352,11 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), typeof(initializeprob), + typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys) + observed, _colorvec, sys, initializeprob, initializeprobmap) else ODEFunction{iip, specialize, typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), @@ -2353,10 +2365,11 @@ function ODEFunction{iip, specialize}(f; typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac, + typeof(sys), typeof(initializeprob), + typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, W_prototype, paramjac, - observed, _colorvec, sys) + observed, _colorvec, sys, initializeprob, initializeprobmap) end end @@ -2373,10 +2386,10 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) Any, Any, Any, Any, typeof(f.jac_prototype), typeof(f.sparsity), Any, Any, Any, Any, typeof(f.colorvec), - typeof(f.sys)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys) + f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap) else ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix), typeof(f.analytic), typeof(f.tgrad), @@ -2384,10 +2397,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f)) typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype), typeof(f.paramjac), typeof(f.observed), typeof(f.colorvec), - typeof(f.sys)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, + typeof(f.sys), typeof(f.initializeprob), + typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac, f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.W_prototype, f.paramjac, - f.observed, f.colorvec, f.sys) + f.observed, f.colorvec, f.sys, f.initializeprob, + f.initializeprobmap) end end @@ -3177,7 +3192,9 @@ function DAEFunction{iip, specialize}(f; observed = __has_observed(f) ? f.observed : DEFAULT_OBSERVED, colorvec = __has_colorvec(f) ? f.colorvec : nothing, - sys = __has_sys(f) ? f.sys : nothing) where {iip, + sys = __has_sys(f) ? f.sys : nothing, + initializeprob = __has_initializeprob(f) ? f.sys : nothing, + initializeprobmap = __has_initializeprobmap(f) ? f.sys : nothing) where {iip, specialize } if jac === nothing && isa(jac_prototype, AbstractSciMLOperator) @@ -3209,24 +3226,28 @@ function DAEFunction{iip, specialize}(f; _f = prepare_function(f) sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) + @assert typeof(initializeprob) <: + Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} + if specialize === NoSpecialize DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, typeof(_colorvec), Any}(_f, analytic, tgrad, jac, jvp, + Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys) + _colorvec, sys, initializeprob, initializeprobmap) else DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(_colorvec), - typeof(sys)}(_f, analytic, tgrad, jac, jvp, vjp, + typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}( + _f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys) + _colorvec, sys, initializeprob, initializeprobmap) end end @@ -3945,6 +3966,8 @@ __has_colorvec(f) = isdefined(f, :colorvec) __has_sys(f) = isdefined(f, :sys) __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) +__has_initializeprob(f) = isdefined(f, :initializeprob) +__has_initializeprobmap(f) = isdefined(f, :initializeprobmap) # compatibility has_invW(f::AbstractSciMLFunction) = false @@ -3957,6 +3980,13 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing +function has_initializeprob(f::AbstractSciMLFunction) + __has_initializeprob(f) && f.initializeprob !== nothing +end +function has_initializeprobmap(f::AbstractSciMLFunction) + __has_initializeprobmap(f) && f.initializeprobmap !== nothing +end + function has_syms(f::AbstractSciMLFunction) if __has_syms(f) f.syms !== nothing