Skip to content

Commit

Permalink
handle DAEFunction as well
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Feb 24, 2024
1 parent d363171 commit d571a82
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1524,7 +1524,7 @@ automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
SYS} <:
SYS, IProb, IProbMap} <:
AbstractDAEFunction{iip}
f::F
analytic::Ta
Expand All @@ -1540,6 +1540,8 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
observed::O
colorvec::TCV
sys::SYS
initializeprob::IProb
initializeprobmap::IProbMap
end

TruncatedStacktraces.@truncate_stacktrace DAEFunction 1 2
Expand Down Expand Up @@ -2279,8 +2281,8 @@ function ODEFunction{iip, specialize}(f;
DEFAULT_OBSERVED,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = nothing,
initializeprobmap = nothing
initializeprob = _has_initializeprob(f) ? f.sys : nothing,
initializeprobmap = _has_initializeprobmap(f) ? f.sys : nothing
) where {iip,
specialize,
}
Expand Down Expand Up @@ -2328,7 +2330,7 @@ function ODEFunction{iip, specialize}(f;

sys = something(sys, SymbolCache(syms, paramsyms, indepsym))

Check warning on line 2331 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L2331

Added line #L2331 was not covered by tests

@assert typeof(initializeprob) <: Union{NonlinearProblem, NonlinearLeastSquaresProblem}
@assert typeof(initializeprob) <: Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}

Check warning on line 2333 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L2333

Added line #L2333 was not covered by tests

if specialize === NoSpecialize
ODEFunction{iip, specialize,
Expand Down Expand Up @@ -3189,7 +3191,9 @@ function DAEFunction{iip, specialize}(f;
observed = __has_observed(f) ? f.observed :
DEFAULT_OBSERVED,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing) where {iip,
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = _has_initializeprob(f) ? f.sys : nothing,
initializeprobmap = _has_initializeprobmap(f) ? f.sys : nothing) where {iip,
specialize
}
if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
Expand Down Expand Up @@ -3221,14 +3225,16 @@ function DAEFunction{iip, specialize}(f;
_f = prepare_function(f)
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)

@assert typeof(initializeprob) <: Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}

Check warning on line 3228 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L3228

Added line #L3228 was not covered by tests

if specialize === NoSpecialize
DAEFunction{iip, specialize, Any, Any, Any,
Any, Any, Any, Any, Any,
Any, Any, Any,
Any, typeof(_colorvec), Any}(_f, analytic, tgrad, jac, jvp,
vjp, jac_prototype, sparsity,
Wfact, Wfact_t, paramjac, observed,
_colorvec, sys)
_colorvec, sys, initializeprob, initializeprobmap)
else
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
Expand All @@ -3238,7 +3244,7 @@ function DAEFunction{iip, specialize}(f;
typeof(sys)}(_f, analytic, tgrad, jac, jvp, vjp,
jac_prototype, sparsity, Wfact, Wfact_t,
paramjac, observed,
_colorvec, sys)
_colorvec, sys, initializeprob, initializeprobmap)
end
end

Expand Down Expand Up @@ -3957,6 +3963,8 @@ __has_colorvec(f) = isdefined(f, :colorvec)
__has_sys(f) = isdefined(f, :sys)
__has_analytic_full(f) = isdefined(f, :analytic_full)
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
__has_initializeprob(f) = isdefined(f, :initializeprob)
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)

Check warning on line 3967 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L3966-L3967

Added lines #L3966 - L3967 were not covered by tests

# compatibility
has_invW(f::AbstractSciMLFunction) = false
Expand All @@ -3969,6 +3977,9 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing
has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
has_initializeprob(f::AbstractSciMLFunction) = __has_initializeprob(f) && f.initializeprob !== nothing
has_initializeprobmap(f::AbstractSciMLFunction) = __has_initializeprobmap(f) && f.initializeprobmap !== nothing

Check warning on line 3981 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L3980-L3981

Added lines #L3980 - L3981 were not covered by tests

function has_syms(f::AbstractSciMLFunction)
if __has_syms(f)
f.syms !== nothing
Expand Down

0 comments on commit d571a82

Please sign in to comment.