Skip to content

Commit

Permalink
feat: generalize CheckInit to DDEs
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 10, 2024
1 parent 089e31a commit 9ecdae6
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,27 @@ function _evaluate_f(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

"""
Utility function to evaluate the RHS, adding extra arguments (such as history function for
DDEs) wherever necessary.
"""
function evaluate_f(integrator::DEIntegrator, prob, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, p, t)
end

function evaluate_f(
integrator::DEIntegrator, prob::AbstractDAEProblem, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t)
end

function evaluate_f(integrator::AbstractDDEIntegrator, prob, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
end

function evaluate_f(integrator::AbstractSDDEIntegrator, prob, f, isinplace, u, p, t)
return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t)
end

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -147,7 +168,7 @@ function get_initial_values(
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
update_coefficients!(M, u0, p, t)
tmp = _evaluate_f(integrator, f, isinplace, u0, p, t)
tmp = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))

normresid = isdefined(integrator.opts, :internalnorm) ?
Expand All @@ -165,7 +186,7 @@ function get_initial_values(
p = parameter_values(integrator)
t = current_time(integrator)

resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t)
resid = evaluate_f(integrator, prob, f, isinplace, u0, p, t)
normresid = isdefined(integrator.opts, :internalnorm) ?
integrator.opts.internalnorm(resid, t) : norm(resid)

Expand Down

0 comments on commit 9ecdae6

Please sign in to comment.