Skip to content

Commit

Permalink
feat: add initialization support
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 29, 2024
1 parent aaefbb7 commit 4ee070c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mutable struct DDEIntegrator{algType, IIP, uType, tType, P, eigenType, tTypeNoUn
ksEltype, SolType, F, CacheType, IType, FP, O, dAbsType,
dRelType, H,
tstopsType, discType, FSALType, EventErrorType,
CallbackCacheType, DV} <:
CallbackCacheType, DV, IA} <:
AbstractDDEIntegrator{algType, IIP, uType, tType}
sol::SolType
u::uType
Expand Down Expand Up @@ -95,6 +95,7 @@ mutable struct DDEIntegrator{algType, IIP, uType, tType, P, eigenType, tTypeNoUn
integrator::IType
fsalfirst::FSALType
fsallast::FSALType
initializealg::IA
end

function (integrator::DDEIntegrator)(t, deriv::Type = Val{0}; idxs = nothing)
Expand Down
21 changes: 19 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem,
discontinuity_interp_points::Int = 10,
discontinuity_abstol = eltype(prob.tspan)(1 // Int64(10)^12),
discontinuity_reltol = 0,
initializealg = DDEDefaultInit(),
kwargs...)
if haskey(kwargs, :initial_order)
@warn "initial_order has been deprecated. Please specify order_discontinuity_t0 in the DDEProblem instead."
Expand Down Expand Up @@ -350,7 +351,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem,
typeof(d_discontinuities_propagated),
typeof(fsalfirst),
typeof(last_event_error), typeof(callback_cache),
typeof(differential_vars)}(sol, u, k,
typeof(differential_vars), typeof(initializealg)}(sol, u, k,
t0,
tType(dt),
f_with_history,
Expand Down Expand Up @@ -402,10 +403,11 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem,
stats,
history,
differential_vars,
ode_integrator, fsalfirst, fsallast)
ode_integrator, fsalfirst, fsallast, initializealg)

# initialize DDE integrator
if initialize_integrator
DiffEqBase.initialize_dae!(integrator)
initialize_solution!(integrator)
OrdinaryDiffEqCore.initialize_callbacks!(integrator, initialize_save)
OrdinaryDiffEqCore.initialize!(integrator)
Expand Down Expand Up @@ -538,3 +540,18 @@ function initialize_tstops_d_discontinuities_propagated(::Type{T}, tstops,

return tstops_propagated, d_discontinuities_propagated
end

struct DDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end

function DiffEqBase.initialize_dae!(integrator::DDEIntegrator, initializealg = integrator.initializealg)
OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg,
Val(DiffEqBase.isinplace(integrator.sol.prob)))
end

function OrdinaryDiffEqCore._initialize_dae!(integrator::DDEIntegrator, prob, ::DDEDefaultInit, isinplace)
if SciMLBase.has_initializeprob(prob.f)
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace)
else
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace)
end
end
38 changes: 38 additions & 0 deletions test/integrators/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using DelayDiffEq
using SciMLBase
using LinearAlgebra
using Test

@testset "CheckInit" begin
u0_good = [0.99, 0.01, 0.0]
sir_history(p, t) = [1.0, 0.0, 0.0]
tspan = (0.0, 40.0)
p == 0.5, τ = 4.0)

function sir_ddae!(du, u, h, p, t)
S, I, R = u
γ, τ = p
infection = γ * I * S
Sd, Id, _ = h(p, t - τ)
recovery = γ * Id * Sd
@inbounds begin
du[1] = -infection
du[2] = infection - recovery
du[3] = S + I + R - 1
end
nothing
end

prob_ddae = DDEProblem(
DDEFunction{true}(sir_ddae!;
mass_matrix = Diagonal([1.0, 1.0, 0.0])),
u0,
sir_history,
tspan,
p;
constant_lags = (p.τ,))
alg = MethodOfSteps(Rosenbrock23())
@test_nowarn init(prob_ddae, alg)
prob.u0[1] = 2.0
@test_throws SciMLBase.CheckInitFailureError init(prob_ddae, alg)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ if GROUP == "All" || GROUP == "Integrators"
@time @safetestset "Verner Tests" begin
include("integrators/verner.jl")
end
@time @safetestset "Initialization" begin
include("integrators/initialization.jl")
end
end

if GROUP == "All" || GROUP == "Regression"
Expand Down

0 comments on commit 4ee070c

Please sign in to comment.