From 9ecdae6fc9fec2df92db2c6c5e8cf63ac7fd865a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 10 Dec 2024 15:34:12 +0530 Subject: [PATCH] feat: generalize `CheckInit` to DDEs --- src/initialization.jl | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/initialization.jl b/src/initialization.jl index 8b45bb6a6..48f2fe396 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -111,6 +111,27 @@ function _evaluate_f(integrator, f, isinplace::Val{false}, args...) return f(args...) end +""" +Utility function to evaluate the RHS, adding extra arguments (such as history function for +DDEs) wherever necessary. +""" +function evaluate_f(integrator::DEIntegrator, prob, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, u, p, t) +end + +function evaluate_f( + integrator::DEIntegrator, prob::AbstractDAEProblem, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t) +end + +function evaluate_f(integrator::AbstractDDEIntegrator, prob, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t) +end + +function evaluate_f(integrator::AbstractSDDEIntegrator, prob, f, isinplace, u, p, t) + return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t) +end + """ $(TYPEDSIGNATURES) @@ -147,7 +168,7 @@ function get_initial_values( algebraic_eqs = [all(iszero, x) for x in eachrow(M)] (iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true update_coefficients!(M, u0, p, t) - tmp = _evaluate_f(integrator, f, isinplace, u0, p, t) + tmp = evaluate_f(integrator, prob, f, isinplace, u0, p, t) tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) normresid = isdefined(integrator.opts, :internalnorm) ? @@ -165,7 +186,7 @@ function get_initial_values( p = parameter_values(integrator) t = current_time(integrator) - resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t) + resid = evaluate_f(integrator, prob, f, isinplace, u0, p, t) normresid = isdefined(integrator.opts, :internalnorm) ? integrator.opts.internalnorm(resid, t) : norm(resid)