From 182c9fff2b42a970afd0d1658a92aa599a114137 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 26 Nov 2024 16:55:54 +0530 Subject: [PATCH] feat: allow using `SCCNonlinearProblem` for initialization --- lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl | 5 +++-- lib/OrdinaryDiffEqCore/src/initialize_dae.jl | 11 ++++++----- .../src/OrdinaryDiffEqNonlinearSolve.jl | 3 ++- .../src/initialize_dae.jl | 6 +++--- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index ac2e671b21..fbfcd66dd3 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -62,7 +62,7 @@ using FastBroadcast: @.., True, False using SciMLBase: NoInit, CheckInit, OverrideInit, AbstractDEProblem, _unwrap_val -import SciMLBase: alg_order +import SciMLBase: AbstractNonlinearProblem, alg_order import DiffEqBase: calculate_residuals, calculate_residuals!, unwrap_cache, @@ -76,7 +76,8 @@ import Accessors: @reset using SciMLStructures: canonicalize, Tunable, isscimlstructure -using SymbolicIndexingInterface: parameter_values, is_variable, variable_index, symbolic_type, NotSymbolic +using SymbolicIndexingInterface: state_values, parameter_values, is_variable, variable_index, + symbolic_type, NotSymbolic const CompiledFloats = Union{Float32, Float64} import Preferences diff --git a/lib/OrdinaryDiffEqCore/src/initialize_dae.jl b/lib/OrdinaryDiffEqCore/src/initialize_dae.jl index d769122a84..0034e55d9a 100644 --- a/lib/OrdinaryDiffEqCore/src/initialize_dae.jl +++ b/lib/OrdinaryDiffEqCore/src/initialize_dae.jl @@ -101,7 +101,7 @@ default_nlsolve(alg, isinplace, u, initprob, autodiff = false) = alg ## If the initialization is trivial just use nothing alg function default_nlsolve( - ::Nothing, isinplace::Val{true}, u::Nothing, ::NonlinearProblem, autodiff = false) + ::Nothing, isinplace::Val{true}, u::Nothing, ::AbstractNonlinearProblem, autodiff = false) nothing end @@ -111,7 +111,7 @@ function default_nlsolve( end function default_nlsolve( - ::Nothing, isinplace::Val{false}, u::Nothing, ::NonlinearProblem, autodiff = false) + ::Nothing, isinplace::Val{false}, u::Nothing, ::AbstractNonlinearProblem, autodiff = false) nothing end @@ -122,7 +122,7 @@ function default_nlsolve( end function OrdinaryDiffEqCore.default_nlsolve( - ::Nothing, isinplace, u, ::NonlinearProblem, autodiff = false) + ::Nothing, isinplace, u, ::AbstractNonlinearProblem, autodiff = false) error("This ODE requires a DAE initialization and thus a nonlinear solve but no nonlinear solve has been loaded. To solve this problem, do `using OrdinaryDiffEqNonlinearSolve` or pass a custom `nlsolve` choice into the `initializealg`.") end @@ -146,7 +146,8 @@ function _initialize_dae!(integrator, prob::AbstractDEProblem, # If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit # Since then it's the case of not a DAE but has initializeprob # In which case, it should be differentiable - isAD = if initializeprob.u0 === nothing + iu0 = state_values(initializeprob) + isAD = if iu0 === nothing AutoForwardDiff elseif has_autodiff(integrator.alg) alg_autodiff(integrator.alg) isa AutoForwardDiff @@ -154,7 +155,7 @@ function _initialize_dae!(integrator, prob::AbstractDEProblem, true end - nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD) + nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, iu0, initializeprob, isAD) u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol) diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl index 3453c6fd73..593f27ee77 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl @@ -6,7 +6,8 @@ import SciMLBase import SciMLBase: init, solve, solve!, remake using SciMLBase: DAEFunction, DEIntegrator, NonlinearFunction, NonlinearProblem, NonlinearLeastSquaresProblem, LinearProblem, ODEProblem, DAEProblem, - update_coefficients!, get_tmp_cache, AbstractSciMLOperator, ReturnCode + update_coefficients!, get_tmp_cache, AbstractSciMLOperator, ReturnCode, + AbstractNonlinearProblem import DiffEqBase import PreallocationTools using SimpleNonlinearSolve: SimpleTrustRegion, SimpleGaussNewton diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl index 67c3fb177e..b5cffde762 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl @@ -1,5 +1,5 @@ function default_nlsolve( - ::Nothing, isinplace::Val{true}, u, ::NonlinearProblem, autodiff = false) + ::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false) FastShortcutNonlinearPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff()) end @@ -8,7 +8,7 @@ function default_nlsolve( FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff()) end function default_nlsolve( - ::Nothing, isinplace::Val{false}, u, ::NonlinearProblem, autodiff = false) + ::Nothing, isinplace::Val{false}, u, ::AbstractNonlinearProblem, autodiff = false) FastShortcutNonlinearPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff()) end @@ -17,7 +17,7 @@ function default_nlsolve( FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff()) end function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray, - ::NonlinearProblem, autodiff = false) + ::AbstractNonlinearProblem, autodiff = false) SimpleTrustRegion(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff()) end function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,