Skip to content

Commit

Permalink
Merge pull request #633 from SciML/initialize_prob
Browse files Browse the repository at this point in the history
Allow for tagging an initialization problem to ODEFunction/DAEFunction
  • Loading branch information
ChrisRackauckas authored Feb 25, 2024
2 parents 0d0eed9 + a5f2942 commit 7882284
Showing 1 changed file with 49 additions and 19 deletions.
68 changes: 49 additions & 19 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS} <: AbstractODEFunction{iip}
SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand All @@ -419,6 +419,8 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
observed::O
colorvec::TCV
sys::SYS
initializeprob::IProb
initializeprobmap::IProbMap
end

TruncatedStacktraces.@truncate_stacktrace ODEFunction 1 2
Expand Down Expand Up @@ -1522,7 +1524,7 @@ automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
SYS} <:
SYS, IProb, IProbMap} <:
AbstractDAEFunction{iip}
f::F
analytic::Ta
Expand All @@ -1538,6 +1540,8 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
observed::O
colorvec::TCV
sys::SYS
initializeprob::IProb
initializeprobmap::IProbMap
end

TruncatedStacktraces.@truncate_stacktrace DAEFunction 1 2
Expand Down Expand Up @@ -2276,7 +2280,10 @@ function ODEFunction{iip, specialize}(f;
observed = __has_observed(f) ? f.observed :
DEFAULT_OBSERVED,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing) where {iip,
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = __has_initializeprob(f) ? f.sys : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.sys : nothing
) where {iip,
specialize
}
if mass_matrix === I && f isa Tuple
Expand Down Expand Up @@ -2321,18 +2328,22 @@ function ODEFunction{iip, specialize}(f;

_f = prepare_function(f)

sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
sys = something(sys, SymbolCache(syms, paramsyms, indepsym))

@assert typeof(initializeprob) <:
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}

if specialize === NoSpecialize
ODEFunction{iip, specialize,
Any, Any, Any, Any,
Any, Any, Any, typeof(jac_prototype),
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys)
observed, _colorvec, sys, initializeprob, initializeprobmap)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2341,10 +2352,11 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), typeof(initializeprob),
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys)
observed, _colorvec, sys, initializeprob, initializeprobmap)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2353,10 +2365,11 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), typeof(initializeprob),
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys)
observed, _colorvec, sys, initializeprob, initializeprobmap)
end
end

Expand All @@ -2373,21 +2386,23 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys)
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
typeof(f.paramjac),
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.sys), typeof(f.initializeprob),
typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys)
f.observed, f.colorvec, f.sys, f.initializeprob,
f.initializeprobmap)
end
end

Expand Down Expand Up @@ -3177,7 +3192,9 @@ function DAEFunction{iip, specialize}(f;
observed = __has_observed(f) ? f.observed :
DEFAULT_OBSERVED,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing) where {iip,
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = __has_initializeprob(f) ? f.sys : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.sys : nothing) where {iip,
specialize
}
if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
Expand Down Expand Up @@ -3209,24 +3226,28 @@ function DAEFunction{iip, specialize}(f;
_f = prepare_function(f)
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)

@assert typeof(initializeprob) <:
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}

if specialize === NoSpecialize
DAEFunction{iip, specialize, Any, Any, Any,
Any, Any, Any, Any, Any,
Any, Any, Any,
Any, typeof(_colorvec), Any}(_f, analytic, tgrad, jac, jvp,
Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
vjp, jac_prototype, sparsity,
Wfact, Wfact_t, paramjac, observed,
_colorvec, sys)
_colorvec, sys, initializeprob, initializeprobmap)
else
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
typeof(paramjac),
typeof(observed), typeof(_colorvec),
typeof(sys)}(_f, analytic, tgrad, jac, jvp, vjp,
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(
_f, analytic, tgrad, jac, jvp, vjp,
jac_prototype, sparsity, Wfact, Wfact_t,
paramjac, observed,
_colorvec, sys)
_colorvec, sys, initializeprob, initializeprobmap)
end
end

Expand Down Expand Up @@ -3945,6 +3966,8 @@ __has_colorvec(f) = isdefined(f, :colorvec)
__has_sys(f) = isdefined(f, :sys)
__has_analytic_full(f) = isdefined(f, :analytic_full)
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
__has_initializeprob(f) = isdefined(f, :initializeprob)
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)

# compatibility
has_invW(f::AbstractSciMLFunction) = false
Expand All @@ -3957,6 +3980,13 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing
has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
function has_initializeprob(f::AbstractSciMLFunction)
__has_initializeprob(f) && f.initializeprob !== nothing
end
function has_initializeprobmap(f::AbstractSciMLFunction)
__has_initializeprobmap(f) && f.initializeprobmap !== nothing
end

function has_syms(f::AbstractSciMLFunction)
if __has_syms(f)
f.syms !== nothing
Expand Down

0 comments on commit 7882284

Please sign in to comment.