From 4ee070cff0cc49b06901434cef26d736a74b5497 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 29 Nov 2024 13:28:37 +0530 Subject: [PATCH 1/5] feat: add initialization support --- src/integrators/type.jl | 3 ++- src/solve.jl | 21 +++++++++++++++-- test/integrators/initialization.jl | 38 ++++++++++++++++++++++++++++++ test/runtests.jl | 3 +++ 4 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 test/integrators/initialization.jl diff --git a/src/integrators/type.jl b/src/integrators/type.jl index 98f1bfb..06ecda2 100644 --- a/src/integrators/type.jl +++ b/src/integrators/type.jl @@ -33,7 +33,7 @@ mutable struct DDEIntegrator{algType, IIP, uType, tType, P, eigenType, tTypeNoUn ksEltype, SolType, F, CacheType, IType, FP, O, dAbsType, dRelType, H, tstopsType, discType, FSALType, EventErrorType, - CallbackCacheType, DV} <: + CallbackCacheType, DV, IA} <: AbstractDDEIntegrator{algType, IIP, uType, tType} sol::SolType u::uType @@ -95,6 +95,7 @@ mutable struct DDEIntegrator{algType, IIP, uType, tType, P, eigenType, tTypeNoUn integrator::IType fsalfirst::FSALType fsallast::FSALType + initializealg::IA end function (integrator::DDEIntegrator)(t, deriv::Type = Val{0}; idxs = nothing) diff --git a/src/solve.jl b/src/solve.jl index c8d7437..ea4f5d0 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -66,6 +66,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem, discontinuity_interp_points::Int = 10, discontinuity_abstol = eltype(prob.tspan)(1 // Int64(10)^12), discontinuity_reltol = 0, + initializealg = DDEDefaultInit(), kwargs...) if haskey(kwargs, :initial_order) @warn "initial_order has been deprecated. Please specify order_discontinuity_t0 in the DDEProblem instead." @@ -350,7 +351,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem, typeof(d_discontinuities_propagated), typeof(fsalfirst), typeof(last_event_error), typeof(callback_cache), - typeof(differential_vars)}(sol, u, k, + typeof(differential_vars), typeof(initializealg)}(sol, u, k, t0, tType(dt), f_with_history, @@ -402,10 +403,11 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDDEProblem, stats, history, differential_vars, - ode_integrator, fsalfirst, fsallast) + ode_integrator, fsalfirst, fsallast, initializealg) # initialize DDE integrator if initialize_integrator + DiffEqBase.initialize_dae!(integrator) initialize_solution!(integrator) OrdinaryDiffEqCore.initialize_callbacks!(integrator, initialize_save) OrdinaryDiffEqCore.initialize!(integrator) @@ -538,3 +540,18 @@ function initialize_tstops_d_discontinuities_propagated(::Type{T}, tstops, return tstops_propagated, d_discontinuities_propagated end + +struct DDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end + +function DiffEqBase.initialize_dae!(integrator::DDEIntegrator, initializealg = integrator.initializealg) + OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg, + Val(DiffEqBase.isinplace(integrator.sol.prob))) +end + +function OrdinaryDiffEqCore._initialize_dae!(integrator::DDEIntegrator, prob, ::DDEDefaultInit, isinplace) + if SciMLBase.has_initializeprob(prob.f) + OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) + else + OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace) + end +end diff --git a/test/integrators/initialization.jl b/test/integrators/initialization.jl new file mode 100644 index 0000000..3d94453 --- /dev/null +++ b/test/integrators/initialization.jl @@ -0,0 +1,38 @@ +using DelayDiffEq +using SciMLBase +using LinearAlgebra +using Test + +@testset "CheckInit" begin + u0_good = [0.99, 0.01, 0.0] + sir_history(p, t) = [1.0, 0.0, 0.0] + tspan = (0.0, 40.0) + p = (γ = 0.5, τ = 4.0) + + function sir_ddae!(du, u, h, p, t) + S, I, R = u + γ, τ = p + infection = γ * I * S + Sd, Id, _ = h(p, t - τ) + recovery = γ * Id * Sd + @inbounds begin + du[1] = -infection + du[2] = infection - recovery + du[3] = S + I + R - 1 + end + nothing + end + + prob_ddae = DDEProblem( + DDEFunction{true}(sir_ddae!; + mass_matrix = Diagonal([1.0, 1.0, 0.0])), + u0, + sir_history, + tspan, + p; + constant_lags = (p.τ,)) + alg = MethodOfSteps(Rosenbrock23()) + @test_nowarn init(prob_ddae, alg) + prob.u0[1] = 2.0 + @test_throws SciMLBase.CheckInitFailureError init(prob_ddae, alg) +end diff --git a/test/runtests.jl b/test/runtests.jl index ea629d2..afffc0e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -84,6 +84,9 @@ if GROUP == "All" || GROUP == "Integrators" @time @safetestset "Verner Tests" begin include("integrators/verner.jl") end + @time @safetestset "Initialization" begin + include("integrators/initialization.jl") + end end if GROUP == "All" || GROUP == "Regression" From ee896020cf66ee329e8fb73122b5a114fd8958cb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 10 Dec 2024 15:34:43 +0530 Subject: [PATCH 2/5] build: add SII as dependency --- Project.toml | 2 ++ src/DelayDiffEq.jl | 1 + 2 files changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index 3064a29..a79ffbb 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [compat] ArrayInterface = "7" @@ -37,4 +38,5 @@ Reexport = "0.2, 1.0" SciMLBase = "2.59.2" SimpleNonlinearSolve = "0.1, 1, 2" SimpleUnPack = "1" +SymbolicIndexingInterface = "0.3.36" julia = "1.9" diff --git a/src/DelayDiffEq.jl b/src/DelayDiffEq.jl index f6a7a39..9ecc323 100644 --- a/src/DelayDiffEq.jl +++ b/src/DelayDiffEq.jl @@ -13,6 +13,7 @@ using SimpleUnPack import ArrayInterface import SimpleNonlinearSolve +import SymbolicIndexingInterface as SII using DiffEqBase: AbstractDDEAlgorithm, AbstractDDEIntegrator, AbstractODEIntegrator, DEIntegrator, AbstractDDEProblem From 01eb277c2b45d7ef49631cd847df2c95f20f405f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 10 Dec 2024 15:34:57 +0530 Subject: [PATCH 3/5] feat: implement `SII.get_history_function` for `DDEIntegrator` --- src/integrators/type.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/integrators/type.jl b/src/integrators/type.jl index 06ecda2..bda3a8c 100644 --- a/src/integrators/type.jl +++ b/src/integrators/type.jl @@ -106,3 +106,7 @@ function (integrator::DDEIntegrator)(val::AbstractArray, t::Union{Number, Abstra deriv::Type = Val{0}; idxs = nothing) OrdinaryDiffEq.current_interpolant!(val, t, integrator, idxs, deriv) end + +function SII.get_history_function(integrator::DDEIntegrator) + return integrator.history +end From 2c3231d3997b0d3282b6033a58736468540dddf8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 10 Dec 2024 15:38:00 +0530 Subject: [PATCH 4/5] feat: add `initialization_data` to `ODEFunctionWrapper` --- src/functionwrapper.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/functionwrapper.jl b/src/functionwrapper.jl index 0e97e12..b86aaaf 100644 --- a/src/functionwrapper.jl +++ b/src/functionwrapper.jl @@ -24,7 +24,7 @@ macro wrap_h(signature) end |> esc end -struct ODEFunctionWrapper{iip, F, H, TMM, Ta, Tt, TJ, JP, SP, TW, TWt, TPJ, S, TCV} <: +struct ODEFunctionWrapper{iip, F, H, TMM, Ta, Tt, TJ, JP, SP, TW, TWt, TPJ, S, TCV, ID} <: DiffEqBase.AbstractODEFunction{iip} f::F h::H @@ -39,6 +39,7 @@ struct ODEFunctionWrapper{iip, F, H, TMM, Ta, Tt, TJ, JP, SP, TW, TWt, TPJ, S, T paramjac::TPJ sys::S colorvec::TCV + initialization_data::ID end function ODEFunctionWrapper(f::DiffEqBase.AbstractDDEFunction, h) @@ -51,7 +52,8 @@ function ODEFunctionWrapper(f::DiffEqBase.AbstractDDEFunction, h) typeof(f.analytic), typeof(f.tgrad), typeof(jac), typeof(f.jac_prototype), typeof(f.sparsity), typeof(Wfact), typeof(Wfact_t), - typeof(f.paramjac), typeof(f.sys), typeof(f.colorvec)}(f.f, h, + typeof(f.paramjac), typeof(f.sys), typeof(f.colorvec), + typeof(f.initialization_data)}(f.f, h, f.mass_matrix, f.analytic, f.tgrad, jac, @@ -61,7 +63,8 @@ function ODEFunctionWrapper(f::DiffEqBase.AbstractDDEFunction, h) Wfact_t, f.paramjac, f.sys, - f.colorvec) + f.colorvec, + f.initialization_data) end (f::ODEFunctionWrapper{true})(du, u, p, t) = f.f(du, u, f.h, p, t) From 12957d031852047c844ced521bd410e0ccf28e42 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 14 Dec 2024 11:11:24 +0530 Subject: [PATCH 5/5] build: bump SciMLBase compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a79ffbb..ddcb178 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ OrdinaryDiffEqNonlinearSolve = "1.2.2" OrdinaryDiffEqRosenbrock = "1.2.0" RecursiveArrayTools = "2, 3" Reexport = "0.2, 1.0" -SciMLBase = "2.59.2" +SciMLBase = "2.68" SimpleNonlinearSolve = "0.1, 1, 2" SimpleUnPack = "1" SymbolicIndexingInterface = "0.3.36"