Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Dec 2, 2024
1 parent 60ab391 commit afd385b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 12 deletions.
11 changes: 7 additions & 4 deletions lib/OrdinaryDiffEqFIRK/src/controllers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function step_accept_controller!(integrator, controller::PredictiveController, alg::AdaptiveRadau, q)
@unpack qmin, qmax, gamma, qsteady_min, qsteady_max = integrator.opts
@unpack cache = integrator
@unpack num_stages, step, iter, hist_iter = cache
@unpack num_stages, step, iter, hist_iter, index = cache

EEst = DiffEqBase.value(integrator.EEst)

Expand All @@ -25,12 +25,14 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
max_stages = (alg.max_order - 1) ÷ 4 * 2 + 1
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
if (step > 10)
if (hist_iter < 2.6 && num_stages <= max_stages)
if (hist_iter < 2.6 && num_stages < max_stages)
cache.num_stages += 2
cache.index += 1
cache.step = 1
cache.hist_iter = iter
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > min_stages)
cache.num_stages -= 2
cache.index -= 1
cache.step = 1
cache.hist_iter = iter
end
Expand All @@ -48,8 +50,9 @@ function step_reject_controller!(integrator, controller::PredictiveController, a
cache.hist_iter = hist_iter
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
if (step > 10)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > min_stages)
cache.num_stages -= 2
cache.index -= 1
cache.step = 1
cache.hist_iter = iter
end
Expand Down
38 changes: 34 additions & 4 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
num_stages::Int
step::Int
hist_iter::Float64
index::Int
end

function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -518,7 +519,11 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
num_stages = min

tabs = [RadauIIATableau5(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau9(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau13(uToltype, constvalue(tTypeNoUnits))]
i = 9
if (min == 3 || min == 5 || min == 7)
i = 9
else
i = min
end
while i <= max
push!(tabs, RadauIIATableau(uToltype, constvalue(tTypeNoUnits), i))
i += 2
Expand All @@ -528,10 +533,20 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
cont[i] = zero(u)
end

if (min == 3)
index = 1
elseif (min == 5)
index = 2
elseif (min == 7)
index = 3
else
index = 4
end

κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'
AdaptiveRadauConstantCache(uf, tabs, κ, one(uToltype), 10000, cont, dt, dt,
Convergence, J, num_stages, 1, 0.0)
Convergence, J, num_stages, 1, 0.0, index)
end

mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
Expand Down Expand Up @@ -578,6 +593,7 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
num_stages::Int
step::Int
hist_iter::Float64
index::Int
end

function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -599,12 +615,26 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
num_stages = min

tabs = [RadauIIATableau5(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau9(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau13(uToltype, constvalue(tTypeNoUnits))]
i = 9
if (min == 3 || min == 5 || min == 7)
i = 9
else
i = min
end
while i <= max
push!(tabs, RadauIIATableau(uToltype, constvalue(tTypeNoUnits), i))
i += 2
end

if (min == 3)
index = 1
elseif (min == 5)
index = 2
elseif (min == 7)
index = 3
else
index = 4
end

κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)

z = Vector{typeof(u)}(undef, max)
Expand Down Expand Up @@ -677,6 +707,6 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
uf, tabs, κ, one(uToltype), 10000, tmp,
atmp, jac_config,
linsolve1, linsolve2, rtol, atol, dt, dt,
Convergence, alg.step_limiter!, num_stages, 1, 0.0)
Convergence, alg.step_limiter!, num_stages, 1, 0.0, index)
end

8 changes: 4 additions & 4 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1354,8 +1354,8 @@ end
@muladd function perform_step!(integrator, cache::AdaptiveRadauConstantCache,
repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack tabs, num_stages = cache
tab = tabs[(num_stages - 1) ÷ 2]
@unpack tabs, num_stages, index = cache
tab = tabs[index]
@unpack T, TI, γ, α, β, c, e = tab
@unpack κ, cont = cache
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
Expand Down Expand Up @@ -1595,8 +1595,8 @@ end

@muladd function perform_step!(integrator, cache::AdaptiveRadauCache, repeat_step = false)
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
@unpack num_stages, tabs = cache
tab = tabs[(num_stages - 1) ÷ 2]
@unpack num_stages, tabs, index = cache
tab = tabs[index]
@unpack T, TI, γ, α, β, c, e = tab
@unpack κ, cont, derivatives, z, w, c_prime, αdt, βdt= cache
@unpack dw1, ubuff, dw2, cubuff, dw = cache
Expand Down

0 comments on commit afd385b

Please sign in to comment.