Skip to content

Commit

Permalink
rename things and fix broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Nov 13, 2024
1 parent 6f18857 commit e50bf9a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 16 deletions.
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ end

function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, min_stages = 3, max_stages = 7,
diff_type = Val{:forward}, min_order = 5, max_order = 13,
linsolve = nothing, precs = DEFAULT_PRECS,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
Expand All @@ -187,6 +187,6 @@ function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
fast_convergence_cutoff,
new_W_γdt_cutoff,
controller,
step_limiter!, min_stages, max_stages)
step_limiter!, min_order, max_order)
end

6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqFIRK/src/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
if (step > 10)
if (hist_iter < 2.6 && num_stages < alg.max_stages)
if (hist_iter < 2.6 && num_stages < (alg.max_order + 1) ÷ 2)
cache.num_stages += 2
cache.step = 1
cache.hist_iter = iter
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > (alg.min_order + 1) ÷ 2)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
Expand All @@ -45,7 +45,7 @@ function step_reject_controller!(integrator, controller::PredictiveController, a
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
if (step > 10)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > (alg.min_order + 1) ÷ 2)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
Expand Down
19 changes: 12 additions & 7 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,12 +497,12 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.min_stages
max = alg.max_stages
max = (alg.max_order + 1) ÷ 2
num_stages = (alg.min_order + 1) ÷ 2
tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]

i = 9
while i <= alg.max_stages
while i <= max
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
i += 2
end
Expand All @@ -525,6 +525,8 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
z::Vector{uType}
w::Vector{uType}
c_prime::Vector{tType}
αdt::Vector{tType}
βdt::Vector{tType}
dw1::uType
ubuff::uType
dw2::Vector{cuType}
Expand Down Expand Up @@ -568,8 +570,8 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)

max = alg.max_stages
num_stages = alg.min_stages
max = (alg.max_order + 1) ÷ 2
num_stages = (alg.min_order + 1) ÷ 2

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
i = 9
Expand All @@ -583,9 +585,12 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
z = Vector{typeof(u)}(undef, max)
w = Vector{typeof(u)}(undef, max)
for i in 1 : max
z[i] = w[i] = zero(u)
z[i] = zero(u)
w[i] = zero(u)
end

αdt = [zero(t) for i in 1:max]
βdt = [zero(t) for i in 1:max]
c_prime = Vector{typeof(t)}(undef, max) #time stepping
for i in 1 : max
c_prime[i] = zero(t)
Expand Down Expand Up @@ -641,7 +646,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
atol = reltol isa Number ? reltol : zero(reltol)

AdaptiveRadauCache(u, uprev,
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
z, w, c_prime, αdt, βdt, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
du1, fsalfirst, ks, k, fw,
J, W1, W2,
uf, tabs, κ, one(uToltype), 10000, tmp,
Expand Down
13 changes: 9 additions & 4 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1598,7 +1598,7 @@ end
@unpack num_stages, tabs = cache
tab = tabs[(num_stages - 1) ÷ 2]
@unpack T, TI, γ, α, β, c, e = tab
@unpack κ, cont, derivatives, z, w, c_prime = cache
@unpack κ, cont, derivatives, z, w, c_prime, αdt, βdt= cache
@unpack dw1, ubuff, dw2, cubuff, dw = cache
@unpack ks, k, fw, J, W1, W2 = cache
@unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
Expand All @@ -1608,7 +1608,12 @@ end
mass_matrix = integrator.f.mass_matrix

# precalculations
γdt, αdt, βdt = γ / dt, α ./ dt, β ./ dt
γdt = γ / dt
for i in 1 : (num_stages - 1) ÷ 2
αdt[i] = α[i]/dt
βdt[i] = β[i]/dt
end

(new_jac = do_newJ(integrator, alg, cache, repeat_step)) &&
(calc_J!(J, integrator, cache); cache.W_γdt = dt)
if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
Expand Down Expand Up @@ -1750,12 +1755,12 @@ end
# transform `w` to `z`
#mul!(z, T, w)
for i in 1:num_stages - 1
z[i] = zero(u)
@.. z[i] = zero(u)
for j in 1:num_stages
@.. z[i] += T[i,j] * w[j]
end
end
z[num_stages] = T[num_stages, 1] * w[1]
@.. z[num_stages] = T[num_stages, 1] * w[1]
i = 2
while i < num_stages
@.. z[num_stages] += w[i]
Expand Down

0 comments on commit e50bf9a

Please sign in to comment.