Skip to content

Commit

Permalink
Merge pull request #2544 from AayushSabharwal/as/scc-init
Browse files Browse the repository at this point in the history
feat: allow using `SCCNonlinearProblem` for initialization
  • Loading branch information
ChrisRackauckas authored Nov 29, 2024
2 parents 14e6cb6 + 074b1ca commit b37a5cc
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 12 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ jobs:
- {user: SciML, repo: SciMLSensitivity.jl, group: Core3}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
- {user: SciML, repo: ModelingToolkit.jl, group: All}
- {user: SciML, repo: ModelingToolkit.jl, group: InterfaceI}
- {user: SciML, repo: ModelingToolkit.jl, group: InterfaceII}
- {user: SciML, repo: ModelingToolkit.jl, group: Initialization}
- {user: SciML, repo: ModelingToolkit.jl, group: SymbolicIndexingInterface}
- {user: SciML, repo: DiffEqDevTools.jl, group: Core}
- {user: nathanaelbosch, repo: ProbNumDiffEq.jl, group: Downstream}
- {user: SKopecz, repo: PositiveIntegrators.jl, group: Downstream}
Expand Down
5 changes: 3 additions & 2 deletions lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ using FastBroadcast: @.., True, False

using SciMLBase: NoInit, CheckInit, OverrideInit, AbstractDEProblem, _unwrap_val

import SciMLBase: alg_order
import SciMLBase: AbstractNonlinearProblem, alg_order

import DiffEqBase: calculate_residuals,
calculate_residuals!, unwrap_cache,
Expand All @@ -76,7 +76,8 @@ import Accessors: @reset

using SciMLStructures: canonicalize, Tunable, isscimlstructure

using SymbolicIndexingInterface: parameter_values, is_variable, variable_index, symbolic_type, NotSymbolic
using SymbolicIndexingInterface: state_values, parameter_values, is_variable, variable_index,
symbolic_type, NotSymbolic

const CompiledFloats = Union{Float32, Float64}
import Preferences
Expand Down
11 changes: 6 additions & 5 deletions lib/OrdinaryDiffEqCore/src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ default_nlsolve(alg, isinplace, u, initprob, autodiff = false) = alg

## If the initialization is trivial just use nothing alg
function default_nlsolve(
::Nothing, isinplace::Val{true}, u::Nothing, ::NonlinearProblem, autodiff = false)
::Nothing, isinplace::Val{true}, u::Nothing, ::AbstractNonlinearProblem, autodiff = false)
nothing
end

Expand All @@ -111,7 +111,7 @@ function default_nlsolve(
end

function default_nlsolve(
::Nothing, isinplace::Val{false}, u::Nothing, ::NonlinearProblem, autodiff = false)
::Nothing, isinplace::Val{false}, u::Nothing, ::AbstractNonlinearProblem, autodiff = false)
nothing
end

Expand All @@ -122,7 +122,7 @@ function default_nlsolve(
end

function OrdinaryDiffEqCore.default_nlsolve(
::Nothing, isinplace, u, ::NonlinearProblem, autodiff = false)
::Nothing, isinplace, u, ::AbstractNonlinearProblem, autodiff = false)
error("This ODE requires a DAE initialization and thus a nonlinear solve but no nonlinear solve has been loaded. To solve this problem, do `using OrdinaryDiffEqNonlinearSolve` or pass a custom `nlsolve` choice into the `initializealg`.")
end

Expand All @@ -146,15 +146,16 @@ function _initialize_dae!(integrator, prob::AbstractDEProblem,
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
# Since then it's the case of not a DAE but has initializeprob
# In which case, it should be differentiable
isAD = if initializeprob.u0 === nothing
iu0 = state_values(initializeprob)
isAD = if iu0 === nothing
AutoForwardDiff
elseif has_autodiff(integrator.alg)
alg_autodiff(integrator.alg) isa AutoForwardDiff
else
true
end

nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, iu0, initializeprob, isAD)

u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import SciMLBase
import SciMLBase: init, solve, solve!, remake
using SciMLBase: DAEFunction, DEIntegrator, NonlinearFunction, NonlinearProblem,
NonlinearLeastSquaresProblem, LinearProblem, ODEProblem, DAEProblem,
update_coefficients!, get_tmp_cache, AbstractSciMLOperator, ReturnCode
update_coefficients!, get_tmp_cache, AbstractSciMLOperator, ReturnCode,
AbstractNonlinearProblem
import DiffEqBase
import PreallocationTools
using SimpleNonlinearSolve: SimpleTrustRegion, SimpleGaussNewton
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function default_nlsolve(
::Nothing, isinplace::Val{true}, u, ::NonlinearProblem, autodiff = false)
::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false)
FastShortcutNonlinearPolyalg(;
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
Expand All @@ -8,7 +8,7 @@ function default_nlsolve(
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
function default_nlsolve(
::Nothing, isinplace::Val{false}, u, ::NonlinearProblem, autodiff = false)
::Nothing, isinplace::Val{false}, u, ::AbstractNonlinearProblem, autodiff = false)
FastShortcutNonlinearPolyalg(;
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
Expand All @@ -17,7 +17,7 @@ function default_nlsolve(
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
::NonlinearProblem, autodiff = false)
::AbstractNonlinearProblem, autodiff = false)
SimpleTrustRegion(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
Expand Down

0 comments on commit b37a5cc

Please sign in to comment.