From f86d0ea385aefad0839a10d522c46aced9218a51 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Nov 2024 23:09:46 +0530 Subject: [PATCH 1/4] refactor: generalize `_initialize_dae!` to use SciMLBase implementations --- .../src/OrdinaryDiffEqCore.jl | 2 +- lib/OrdinaryDiffEqCore/src/initialize_dae.jl | 123 +++--------------- 2 files changed, 18 insertions(+), 107 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index a7cbd95167..ac2e671b21 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -60,7 +60,7 @@ using DiffEqBase: check_error!, @def, _vec, _reshape using FastBroadcast: @.., True, False -using SciMLBase: NoInit, CheckInit, _unwrap_val +using SciMLBase: NoInit, CheckInit, OverrideInit, AbstractDEProblem, _unwrap_val import SciMLBase: alg_order diff --git a/lib/OrdinaryDiffEqCore/src/initialize_dae.jl b/lib/OrdinaryDiffEqCore/src/initialize_dae.jl index ebefa3d91c..8ad99e7f1e 100644 --- a/lib/OrdinaryDiffEqCore/src/initialize_dae.jl +++ b/lib/OrdinaryDiffEqCore/src/initialize_dae.jl @@ -20,16 +20,6 @@ function BrownFullBasicInit(; abstol = 1e-10, nlsolve = nothing) end BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing) -struct OverrideInit{T, F} <: DiffEqBase.DAEInitializationAlgorithm - abstol::T - nlsolve::F -end - -function OverrideInit(; abstol = 1e-10, nlsolve = nothing) - OverrideInit(abstol, nlsolve) -end -OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing) - ## Notes #= @@ -143,19 +133,15 @@ end ## NoInit -function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, +function _initialize_dae!(integrator, prob::AbstractDEProblem, alg::NoInit, x::Union{Val{true}, Val{false}}) end ## OverrideInit -function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, +function _initialize_dae!(integrator, prob::AbstractDEProblem, alg::OverrideInit, isinplace::Union{Val{true}, Val{false}}) - initializeprob = prob.f.initializeprob - - if SciMLBase.has_update_initializeprob!(prob.f) - prob.f.update_initializeprob!(initializeprob, prob) - end + initializeprob = prob.f.initialization_data.initializeprob # If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit # Since then it's the case of not a DAE but has initializeprob @@ -168,105 +154,30 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, true end - alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD) - nlsol = solve(initializeprob, alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol) + nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD) + + u0, p, success = SciMLBase.get_initial_values(prob, prob.f, integrator, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol) + if isinplace === Val{true}() - integrator.u .= prob.f.initializeprobmap(nlsol) + integrator.u .= u0 elseif isinplace === Val{false}() - integrator.u = prob.f.initializeprobmap(nlsol) + integrator.u = u0 else error("Unreachable reached. Report this error.") end - if SciMLBase.has_initializeprobpmap(prob.f) - integrator.p = prob.f.initializeprobpmap(prob, nlsol) - sol = integrator.sol - @reset sol.prob.p = integrator.p - integrator.sol = sol - end + integrator.p = p + sol = integrator.sol + @reset sol.prob.p = integrator.p + integrator.sol = sol - if nlsol.retcode != ReturnCode.Success + if !success integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, ReturnCode.InitialFailure) end end ## CheckInit -struct CheckInitFailureError <: Exception - normresid::Any - abstol::Any -end - -function Base.showerror(io::IO, e::CheckInitFailureError) - print(io, - "CheckInit specified but initialization not satisifed. normresid = $(e.normresid) > abstol = $(e.abstol)") -end - -function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit, - isinplace::Val{true}) - @unpack p, t, f = integrator - M = integrator.f.mass_matrix - tmp = first(get_tmp_cache(integrator)) - u0 = integrator.u - - algebraic_vars = [all(iszero, x) for x in eachcol(M)] - algebraic_eqs = [all(iszero, x) for x in eachrow(M)] - (iszero(algebraic_vars) || iszero(algebraic_eqs)) && return - update_coefficients!(M, u0, p, t) - f(tmp, u0, p, t) - tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) - - normresid = integrator.opts.internalnorm(tmp, t) - if normresid > integrator.opts.abstol - throw(CheckInitFailureError(normresid, integrator.opts.abstol)) - end -end - -function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit, - isinplace::Val{false}) - @unpack p, t, f = integrator - u0 = integrator.u - M = integrator.f.mass_matrix - - algebraic_vars = [all(iszero, x) for x in eachcol(M)] - algebraic_eqs = [all(iszero, x) for x in eachrow(M)] - (iszero(algebraic_vars) || iszero(algebraic_eqs)) && return - update_coefficients!(M, u0, p, t) - du = f(u0, p, t) - resid = _vec(du)[algebraic_eqs] - - normresid = integrator.opts.internalnorm(resid, t) - if normresid > integrator.opts.abstol - throw(CheckInitFailureError(normresid, integrator.opts.abstol)) - end -end - -function _initialize_dae!(integrator, prob::DAEProblem, - alg::CheckInit, isinplace::Val{true}) - @unpack p, t, f = integrator - u0 = integrator.u - resid = get_tmp_cache(integrator)[2] - - f(resid, integrator.du, u0, p, t) - normresid = integrator.opts.internalnorm(resid, t) - if normresid > integrator.opts.abstol - throw(CheckInitFailureError(normresid, integrator.opts.abstol)) - end -end - -function _initialize_dae!(integrator, prob::DAEProblem, - alg::CheckInit, isinplace::Val{false}) - @unpack p, t, f = integrator - u0 = integrator.u - - nlequation_oop = u -> begin - f((u - u0) / dt, u, p, t) - end - - nlequation = (u, _) -> nlequation_oop(u) - - resid = f(integrator.du, u0, p, t) - normresid = integrator.opts.internalnorm(resid, t) - if normresid > integrator.opts.abstol - throw(CheckInitFailureError(normresid, integrator.opts.abstol)) - end +function _initialize_dae!(integrator, prob::AbstractDEProblem, alg::CheckInit, + isinplace::Union{Val{true}, Val{false}}) + SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; abstol = integrator.opts.abstol) end From 8f9a1b248d298fe7b5357374eaec9e22ede9599e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 19 Nov 2024 11:23:15 +0530 Subject: [PATCH 2/4] build: bump SciMLBase compat --- lib/OrdinaryDiffEqCore/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqCore/Project.toml b/lib/OrdinaryDiffEqCore/Project.toml index ea0eed4466..1c1476d5e1 100644 --- a/lib/OrdinaryDiffEqCore/Project.toml +++ b/lib/OrdinaryDiffEqCore/Project.toml @@ -70,7 +70,7 @@ Random = "<0.0.1, 1" RecursiveArrayTools = "2.36, 3" Reexport = "1.0" SafeTestsets = "0.1.0" -SciMLBase = "2.60" +SciMLBase = "2.62" SciMLOperators = "0.3" SciMLStructures = "1" SimpleUnPack = "1" From 349bfcdd5f7d7ded70a7504d84b5e78aa8d8c611 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 20 Nov 2024 13:13:10 +0530 Subject: [PATCH 3/4] test: fix `CheckInit` tests --- test/interface/checkinit_tests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/interface/checkinit_tests.jl b/test/interface/checkinit_tests.jl index 8320449151..e71b412e2c 100644 --- a/test/interface/checkinit_tests.jl +++ b/test/interface/checkinit_tests.jl @@ -24,9 +24,9 @@ roberf_oop = ODEFunction{false}(rober, mass_matrix = M) prob_mm = ODEProblem(roberf, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4)) prob_mm_oop = ODEProblem(roberf_oop, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4)) -@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve( +@test_throws SciMLBase.CheckInitFailureError solve( prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit()) -@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve( +@test_throws SciMLBase.CheckInitFailureError solve( prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit()) @@ -49,7 +49,7 @@ tspan = (0.0, 100000.0) differential_vars = [true, true, false] prob = DAEProblem(f, du₀, u₀, tspan, differential_vars = differential_vars) prob_oop = DAEProblem(f_oop, du₀, u₀, tspan, differential_vars = differential_vars) -@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve( +@test_throws SciMLBase.CheckInitFailureError solve( prob, DFBDF(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit()) -@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve( +@test_throws SciMLBase.CheckInitFailureError solve( prob_oop, DFBDF(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit()) From c7b28ff3d56941539e1b7324e93c741595faef1f Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 20 Nov 2024 06:10:39 -0500 Subject: [PATCH 4/4] Update CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 7d11cf320e..1c05dbf029 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -95,4 +95,4 @@ jobs: with: token: ${{ secrets.CODECOV_TOKEN }} file: lcov.info - fail_ci_if_error: true + fail_ci_if_error: false