From 4ee070cff0cc49b06901434cef26d736a74b5497 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 29 Nov 2024 13:28:37 +0530 Subject: [PATCH] feat: add initialization support --- src/integrators/type.jl | 3 ++- src/solve.jl | 21 +++++++++++++++-- test/integrators/initialization.jl | 38 ++++++++++++++++++++++++++++++ test/runtests.jl | 3 +++ 4 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 test/integrators/initialization.jl diff --git a/src/integrators/type.jl b/src/integrators/type.jl index 98f1bfb..06ecda2 100644 --- a/src/integrators/type.jl +++ b/src/integrators/type.jl @@ -33,7 +33,7 @@ mutable struct DDEIntegrator{algType, IIP, uType, tType, P, eigenType, tTypeNoUn ksEltype, SolType, F, CacheType, IType, FP, O, dAbsType, dRelType, H, tstopsType, discType, FSALType, EventErrorType, - CallbackCacheType, DV} <: + CallbackCacheType, DV, IA} <: AbstractDDEIntegrator{algType, IIP, uType, tType} sol::SolType u::uType @@ -95,6 +95,7 @@ mutable struct DDEIntegrator{algType, IIP, uType, tType, P, eigenType, tTypeNoUn integrator::IType fsalfirst::FSALType fsallast::FSALType + initializealg::IA end function (integrator::DDEIntegrator)(t, deriv::Type = Val{0}; idxs = nothing) diff --git a/src/solve.jl b/src/solve.jl index c8d7437..ea4f5d0 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -66,6 +66,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem, discontinuity_interp_points::Int = 10, discontinuity_abstol = eltype(prob.tspan)(1 // Int64(10)^12), discontinuity_reltol = 0, + initializealg = DDEDefaultInit(), kwargs...) if haskey(kwargs, :initial_order) @warn "initial_order has been deprecated. Please specify order_discontinuity_t0 in the DDEProblem instead." @@ -350,7 +351,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem, typeof(d_discontinuities_propagated), typeof(fsalfirst), typeof(last_event_error), typeof(callback_cache), - typeof(differential_vars)}(sol, u, k, + typeof(differential_vars), typeof(initializealg)}(sol, u, k, t0, tType(dt), f_with_history, @@ -402,10 +403,11 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem, stats, history, differential_vars, - ode_integrator, fsalfirst, fsallast) + ode_integrator, fsalfirst, fsallast, initializealg) # initialize DDE integrator if initialize_integrator + DiffEqBase.initialize_dae!(integrator) initialize_solution!(integrator) OrdinaryDiffEqCore.initialize_callbacks!(integrator, initialize_save) OrdinaryDiffEqCore.initialize!(integrator) @@ -538,3 +540,18 @@ function initialize_tstops_d_discontinuities_propagated(::Type{T}, tstops, return tstops_propagated, d_discontinuities_propagated end + +struct DDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end + +function DiffEqBase.initialize_dae!(integrator::DDEIntegrator, initializealg = integrator.initializealg) + OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg, + Val(DiffEqBase.isinplace(integrator.sol.prob))) +end + +function OrdinaryDiffEqCore._initialize_dae!(integrator::DDEIntegrator, prob, ::DDEDefaultInit, isinplace) + if SciMLBase.has_initializeprob(prob.f) + OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) + else + OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace) + end +end diff --git a/test/integrators/initialization.jl b/test/integrators/initialization.jl new file mode 100644 index 0000000..3d94453 --- /dev/null +++ b/test/integrators/initialization.jl @@ -0,0 +1,38 @@ +using DelayDiffEq +using SciMLBase +using LinearAlgebra +using Test + +@testset "CheckInit" begin + u0_good = [0.99, 0.01, 0.0] + sir_history(p, t) = [1.0, 0.0, 0.0] + tspan = (0.0, 40.0) + p = (γ = 0.5, τ = 4.0) + + function sir_ddae!(du, u, h, p, t) + S, I, R = u + γ, τ = p + infection = γ * I * S + Sd, Id, _ = h(p, t - τ) + recovery = γ * Id * Sd + @inbounds begin + du[1] = -infection + du[2] = infection - recovery + du[3] = S + I + R - 1 + end + nothing + end + + prob_ddae = DDEProblem( + DDEFunction{true}(sir_ddae!; + mass_matrix = Diagonal([1.0, 1.0, 0.0])), + u0, + sir_history, + tspan, + p; + constant_lags = (p.τ,)) + alg = MethodOfSteps(Rosenbrock23()) + @test_nowarn init(prob_ddae, alg) + prob.u0[1] = 2.0 + @test_throws SciMLBase.CheckInitFailureError init(prob_ddae, alg) +end diff --git a/test/runtests.jl b/test/runtests.jl index ea629d2..afffc0e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -84,6 +84,9 @@ if GROUP == "All" || GROUP == "Integrators" @time @safetestset "Verner Tests" begin include("integrators/verner.jl") end + @time @safetestset "Initialization" begin + include("integrators/initialization.jl") + end end if GROUP == "All" || GROUP == "Regression"