From d571a8298d80a42377f61959b439ce423b23f3f6 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 24 Feb 2024 13:23:40 -0500 Subject: [PATCH] handle DAEFunction as well --- src/scimlfunctions.jl | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index a495e9dff..cf57ad0b8 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -1524,7 +1524,7 @@ automatically symbolically generating the Jacobian and more from the numerically-defined functions. """ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV, - SYS} <: + SYS, IProb, IProbMap} <: AbstractDAEFunction{iip} f::F analytic::Ta @@ -1540,6 +1540,8 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP observed::O colorvec::TCV sys::SYS + initializeprob::IProb + initializeprobmap::IProbMap end TruncatedStacktraces.@truncate_stacktrace DAEFunction 1 2 @@ -2279,8 +2281,8 @@ function ODEFunction{iip, specialize}(f; DEFAULT_OBSERVED, colorvec = __has_colorvec(f) ? f.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing, - initializeprob = nothing, - initializeprobmap = nothing + initializeprob = _has_initializeprob(f) ? f.sys : nothing, + initializeprobmap = _has_initializeprobmap(f) ? f.sys : nothing ) where {iip, specialize, } @@ -2328,7 +2330,7 @@ function ODEFunction{iip, specialize}(f; sys = something(sys, SymbolCache(syms, paramsyms, indepsym)) - @assert typeof(initializeprob) <: Union{NonlinearProblem, NonlinearLeastSquaresProblem} + @assert typeof(initializeprob) <: Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} if specialize === NoSpecialize ODEFunction{iip, specialize, @@ -3189,7 +3191,9 @@ function DAEFunction{iip, specialize}(f; observed = __has_observed(f) ? f.observed : DEFAULT_OBSERVED, colorvec = __has_colorvec(f) ? f.colorvec : nothing, - sys = __has_sys(f) ? f.sys : nothing) where {iip, + sys = __has_sys(f) ? f.sys : nothing, + initializeprob = _has_initializeprob(f) ? f.sys : nothing, + initializeprobmap = _has_initializeprobmap(f) ? f.sys : nothing) where {iip, specialize } if jac === nothing && isa(jac_prototype, AbstractSciMLOperator) @@ -3221,6 +3225,8 @@ function DAEFunction{iip, specialize}(f; _f = prepare_function(f) sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym) + @assert typeof(initializeprob) <: Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem} + if specialize === NoSpecialize DAEFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, @@ -3228,7 +3234,7 @@ function DAEFunction{iip, specialize}(f; Any, typeof(_colorvec), Any}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys) + _colorvec, sys, initializeprob, initializeprobmap) else DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype), @@ -3238,7 +3244,7 @@ function DAEFunction{iip, specialize}(f; typeof(sys)}(_f, analytic, tgrad, jac, jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, - _colorvec, sys) + _colorvec, sys, initializeprob, initializeprobmap) end end @@ -3957,6 +3963,8 @@ __has_colorvec(f) = isdefined(f, :colorvec) __has_sys(f) = isdefined(f, :sys) __has_analytic_full(f) = isdefined(f, :analytic_full) __has_resid_prototype(f) = isdefined(f, :resid_prototype) +__has_initializeprob(f) = isdefined(f, :initializeprob) +__has_initializeprobmap(f) = isdefined(f, :initializeprobmap) # compatibility has_invW(f::AbstractSciMLFunction) = false @@ -3969,6 +3977,9 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing +has_initializeprob(f::AbstractSciMLFunction) = __has_initializeprob(f) && f.initializeprob !== nothing +has_initializeprobmap(f::AbstractSciMLFunction) = __has_initializeprobmap(f) && f.initializeprobmap !== nothing + function has_syms(f::AbstractSciMLFunction) if __has_syms(f) f.syms !== nothing