Skip to content

Commit

Permalink
Merge branch 'SciML:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
mcarmesin authored Nov 21, 2024
2 parents 2b1c402 + 79ff531 commit df1e6d5
Show file tree
Hide file tree
Showing 13 changed files with 127 additions and 183 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: lcov.info
fail_ci_if_error: true
fail_ci_if_error: false
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OrdinaryDiffEqCore"
uuid = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
authors = ["ParamThakkar123 <[email protected]>"]
version = "1.11.0"
version = "1.12.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -70,7 +70,7 @@ Random = "<0.0.1, 1"
RecursiveArrayTools = "2.36, 3"
Reexport = "1.0"
SafeTestsets = "0.1.0"
SciMLBase = "2.60"
SciMLBase = "2.62"
SciMLOperators = "0.3"
SciMLStructures = "1"
SimpleUnPack = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ using DiffEqBase: check_error!, @def, _vec, _reshape

using FastBroadcast: @.., True, False

using SciMLBase: NoInit, CheckInit, _unwrap_val
using SciMLBase: NoInit, CheckInit, OverrideInit, AbstractDEProblem, _unwrap_val

import SciMLBase: alg_order

Expand Down
123 changes: 17 additions & 106 deletions lib/OrdinaryDiffEqCore/src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,6 @@ function BrownFullBasicInit(; abstol = 1e-10, nlsolve = nothing)
end
BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing)

struct OverrideInit{T, F} <: DiffEqBase.DAEInitializationAlgorithm
abstol::T
nlsolve::F
end

function OverrideInit(; abstol = 1e-10, nlsolve = nothing)
OverrideInit(abstol, nlsolve)
end
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)

## Notes

#=
Expand Down Expand Up @@ -143,19 +133,15 @@ end

## NoInit

function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
function _initialize_dae!(integrator, prob::AbstractDEProblem,
alg::NoInit, x::Union{Val{true}, Val{false}})
end

## OverrideInit

function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
function _initialize_dae!(integrator, prob::AbstractDEProblem,
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
initializeprob = prob.f.initializeprob

if SciMLBase.has_update_initializeprob!(prob.f)
prob.f.update_initializeprob!(initializeprob, prob)
end
initializeprob = prob.f.initialization_data.initializeprob

# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
# Since then it's the case of not a DAE but has initializeprob
Expand All @@ -168,105 +154,30 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
true
end

alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
nlsol = solve(initializeprob, alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)
nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)

u0, p, success = SciMLBase.get_initial_values(prob, prob.f, integrator, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)

if isinplace === Val{true}()
integrator.u .= prob.f.initializeprobmap(nlsol)
integrator.u .= u0
elseif isinplace === Val{false}()
integrator.u = prob.f.initializeprobmap(nlsol)
integrator.u = u0
else
error("Unreachable reached. Report this error.")
end
if SciMLBase.has_initializeprobpmap(prob.f)
integrator.p = prob.f.initializeprobpmap(prob, nlsol)
sol = integrator.sol
@reset sol.prob.p = integrator.p
integrator.sol = sol
end
integrator.p = p
sol = integrator.sol
@reset sol.prob.p = integrator.p
integrator.sol = sol

if nlsol.retcode != ReturnCode.Success
if !success
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
ReturnCode.InitialFailure)
end
end

## CheckInit
struct CheckInitFailureError <: Exception
normresid::Any
abstol::Any
end

function Base.showerror(io::IO, e::CheckInitFailureError)
print(io,
"CheckInit specified but initialization not satisifed. normresid = $(e.normresid) > abstol = $(e.abstol)")
end

function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
isinplace::Val{true})
@unpack p, t, f = integrator
M = integrator.f.mass_matrix
tmp = first(get_tmp_cache(integrator))
u0 = integrator.u

algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
update_coefficients!(M, u0, p, t)
f(tmp, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))

normresid = integrator.opts.internalnorm(tmp, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
end

function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
isinplace::Val{false})
@unpack p, t, f = integrator
u0 = integrator.u
M = integrator.f.mass_matrix

algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
update_coefficients!(M, u0, p, t)
du = f(u0, p, t)
resid = _vec(du)[algebraic_eqs]

normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
end

function _initialize_dae!(integrator, prob::DAEProblem,
alg::CheckInit, isinplace::Val{true})
@unpack p, t, f = integrator
u0 = integrator.u
resid = get_tmp_cache(integrator)[2]

f(resid, integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
end

function _initialize_dae!(integrator, prob::DAEProblem,
alg::CheckInit, isinplace::Val{false})
@unpack p, t, f = integrator
u0 = integrator.u

nlequation_oop = u -> begin
f((u - u0) / dt, u, p, t)
end

nlequation = (u, _) -> nlequation_oop(u)

resid = f(integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
function _initialize_dae!(integrator, prob::AbstractDEProblem, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}})
SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; abstol = integrator.opts.abstol)
end
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OrdinaryDiffEqFIRK"
uuid = "5960d6e9-dd7a-4743-88e7-cf307b64f125"
authors = ["ParamThakkar123 <[email protected]>"]
version = "1.4.0"
version = "1.5.0"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down
8 changes: 4 additions & 4 deletions lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,13 @@ struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
new_W_γdt_cutoff::C2
controller::Symbol
step_limiter!::StepLimiter
min_stages::Int
max_stages::Int
min_order::Int
max_order::Int
end

function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, min_stages = 3, max_stages = 7,
diff_type = Val{:forward}, min_order = 5, max_order = 13,
linsolve = nothing, precs = DEFAULT_PRECS,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
Expand All @@ -187,6 +187,6 @@ function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
fast_convergence_cutoff,
new_W_γdt_cutoff,
controller,
step_limiter!, min_stages, max_stages)
step_limiter!, min_order, max_order)
end

9 changes: 6 additions & 3 deletions lib/OrdinaryDiffEqFIRK/src/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
cache.step = step + 1
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
max_stages = (alg.max_order - 1) ÷ 4 * 2 + 1
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
if (step > 10)
if (hist_iter < 2.6 && num_stages < alg.max_stages)
if (hist_iter < 2.6 && num_stages <= max_stages)
cache.num_stages += 2
cache.step = 1
cache.hist_iter = iter
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
Expand All @@ -44,8 +46,9 @@ function step_reject_controller!(integrator, controller::PredictiveController, a
cache.step = step + 1
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
if (step > 10)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
Expand Down
50 changes: 40 additions & 10 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
tmp4::uType
tmp5::uType
tmp6::uType
tmp7::uType
tmp8::uType
tmp9::uType
tmp10::uType
atmp::uNoUnitsType
jac_config::JC
linsolve1::F1
Expand Down Expand Up @@ -440,6 +444,10 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp4 = zero(u)
tmp5 = zero(u)
tmp6 = zero(u)
tmp7 = zero(u)
tmp8 = zero(u)
tmp9 = zero(u)
tmp10 = zero(u)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
Expand Down Expand Up @@ -469,7 +477,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
J, W1, W2, W3,
uf, tab, κ, one(uToltype), 10000,
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config,
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config,
linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
end
Expand Down Expand Up @@ -497,17 +505,26 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.min_stages
max = alg.max_stages

max_order = alg.max_order
min_order = alg.min_order
max = (max_order - 1) ÷ 4 * 2 + 1
min = (min_order - 1) ÷ 4 * 2 + 1
if (alg.min_order < 5)
error("min_order choice $min_order below 5 is not compatible with the algorithm")
elseif (max < min)
error("max_order $max_order is below min_order $min_order")
end
num_stages = min

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]

i = 9
while i <= alg.max_stages
while i <= max
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
i += 2
end
cont = Vector{typeof(u)}(undef, max)
for i in 1: max
for i in 1:max
cont[i] = zero(u)
end

Expand All @@ -525,6 +542,8 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
z::Vector{uType}
w::Vector{uType}
c_prime::Vector{tType}
αdt::Vector{tType}
βdt::Vector{tType}
dw1::uType
ubuff::uType
dw2::Vector{cuType}
Expand Down Expand Up @@ -568,8 +587,16 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)

max = alg.max_stages
num_stages = alg.min_stages
max_order = alg.max_order
min_order = alg.min_order
max = (max_order - 1) ÷ 4 * 2 + 1
min = (min_order - 1) ÷ 4 * 2 + 1
if (alg.min_order < 5)
error("min_order choice $min_order below 5 is not compatible with the algorithm")
elseif (max < min)
error("max_order $max_order is below min_order $min_order")
end
num_stages = min

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
i = 9
Expand All @@ -583,9 +610,12 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
z = Vector{typeof(u)}(undef, max)
w = Vector{typeof(u)}(undef, max)
for i in 1 : max
z[i] = w[i] = zero(u)
z[i] = zero(u)
w[i] = zero(u)
end

αdt = [zero(t) for i in 1:max]
βdt = [zero(t) for i in 1:max]
c_prime = Vector{typeof(t)}(undef, max) #time stepping
for i in 1 : max
c_prime[i] = zero(t)
Expand Down Expand Up @@ -641,7 +671,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
atol = reltol isa Number ? reltol : zero(reltol)

AdaptiveRadauCache(u, uprev,
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
z, w, c_prime, αdt, βdt, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
du1, fsalfirst, ks, k, fw,
J, W1, W2,
uf, tabs, κ, one(uToltype), 10000, tmp,
Expand Down
Loading

0 comments on commit df1e6d5

Please sign in to comment.