Skip to content

Commit

Permalink
speed ups
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Nov 19, 2024
1 parent 9ec5720 commit 937641f
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 52 deletions.
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
new_W_γdt_cutoff::C2
controller::Symbol
step_limiter!::StepLimiter
min_stages::Int
max_stages::Int
min_order::Int
max_order::Int
end

function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
Expand Down
9 changes: 6 additions & 3 deletions lib/OrdinaryDiffEqFIRK/src/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
cache.step = step + 1
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
max_stages = (alg.max_order - 1) ÷ 4 * 2 + 1
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
if (step > 10)
if (hist_iter < 2.6 && num_stages < (alg.max_order + 1) ÷ 2)
if (hist_iter < 2.6 && num_stages <= max_stages)
cache.num_stages += 2
cache.step = 1
cache.hist_iter = iter
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > (alg.min_order + 1) ÷ 2)
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
Expand All @@ -44,8 +46,9 @@ function step_reject_controller!(integrator, controller::PredictiveController, a
cache.step = step + 1
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
if (step > 10)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > (alg.min_order + 1) ÷ 2)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
Expand Down
39 changes: 32 additions & 7 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
tmp4::uType
tmp5::uType
tmp6::uType
tmp7::uType
tmp8::uType
tmp9::uType
tmp10::uType
atmp::uNoUnitsType
jac_config::JC
linsolve1::F1
Expand Down Expand Up @@ -440,6 +444,10 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp4 = zero(u)
tmp5 = zero(u)
tmp6 = zero(u)
tmp7 = zero(u)
tmp8 = zero(u)
tmp9 = zero(u)
tmp10 = zero(u)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
Expand Down Expand Up @@ -469,7 +477,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
J, W1, W2, W3,
uf, tab, κ, one(uToltype), 10000,
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config,
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config,
linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
end
Expand Down Expand Up @@ -497,17 +505,26 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
max = (alg.max_order + 1) ÷ 2
num_stages = (alg.min_order + 1) ÷ 2

max_order = alg.max_order
min_order = alg.min_order
max = (max_order - 1) ÷ 4 * 2 + 1
min = (min_order - 1) ÷ 4 * 2 + 1
if (alg.min_order < 5)
error("min_order choice $min_order below 5 is not compatible with the algorithm")
elseif (max < min)
error("max_order $max_order is below min_order $min_order")
end
num_stages = min

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]

i = 9
while i <= max
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
i += 2
end
cont = Vector{typeof(u)}(undef, max)
for i in 1: max
for i in 1:max
cont[i] = zero(u)
end

Expand Down Expand Up @@ -570,8 +587,16 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)

max = (alg.max_order + 1) ÷ 2
num_stages = (alg.min_order + 1) ÷ 2
max_order = alg.max_order
min_order = alg.min_order
max = (max_order - 1) ÷ 4 * 2 + 1
min = (min_order - 1) ÷ 4 * 2 + 1
if (alg.min_order < 5)
error("min_order choice $min_order below 5 is not compatible with the algorithm")
elseif (max < min)
error("max_order $max_order is below min_order $min_order")
end
num_stages = min

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
i = 9
Expand Down
75 changes: 35 additions & 40 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ end
@unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
@unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
@unpack J, W1, W2, W3 = cache
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
alg = unwrap_alg(integrator, true)
@unpack maxiters = alg
Expand Down Expand Up @@ -1087,30 +1087,30 @@ end
c2′ = c2 * c5′
c3′ = c3 * c5′
c4′ = c4 * c5′
z1 = @.. c1′ * (cont1 +
@.. z1 = c1′ * (cont1 +
(c1′-c4m1) * (cont2 +
(c1′ - c3m1) * (cont3 +
(c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5))))
z2 = @.. c2′ * (cont1 +
@.. z2 = c2′ * (cont1 +
(c2′-c4m1) * (cont2 +
(c2′ - c3m1) * (cont3 +
(c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5))))
z3 = @.. c3′ * (cont1 +
@.. z3 = c3′ * (cont1 +
(c3′-c4m1) * (cont2 +
(c3′ - c3m1) * (cont3 +
(c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5))))
z4 = @.. c4′ * (cont1 +
@.. z4 = c4′ * (cont1 +
(c4′-c4m1) * (cont2 +
(c4′ - c3m1) * (cont3 +
(c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5))))
z5 = @.. c5′ * (cont1 +
@.. z5 = c5′ * (cont1 +
(c5′-c4m1) * (cont2 +
(c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5))))
w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
w4 = @.. broadcast=false TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
w5 = @.. broadcast=false TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
@.. w1 = TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
@.. w2 = TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
@.. w3 = TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
@.. w4 = TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
@.. w5 = TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
end

# Newton iteration
Expand Down Expand Up @@ -1328,21 +1328,21 @@ end
if integrator.EEst <= oneunit(integrator.EEst)
cache.dtprev = dt
if alg.extrapolant != :constant
cache.cont1 = @.. (z4 - z5) / c4m1 # first derivative on [c4, 1]
tmp1 = @.. (z3 - z4) / c3mc4 # first derivative on [c3, c4]
cache.cont2 = @.. (tmp1 - cache.cont1) / c3m1 # second derivative on [c3, 1]
tmp2 = @.. (z2 - z3) / c2mc3 # first derivative on [c2, c3]
tmp3 = @.. (tmp2 - tmp1) / c2mc4 # second derivative on [c2, c4]
cache.cont3 = @.. (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
tmp4 = @.. (z1 - z2) / c1mc2 # first derivative on [c1, c2]
tmp5 = @.. (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
tmp6 = @.. (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
tmp7 = @.. z1 / c1 #first derivative on [0, c1]
tmp8 = @.. (tmp4 - tmp7) / c2 #second derivative on [0, c2]
tmp9 = @.. (tmp5 - tmp8) / c3 #third derivative on [0, c3]
tmp10 = @.. (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
cache.cont5 = @.. cache.cont4 - tmp10 #fifth derivative on [0,1]
@.. cache.cont1 = (z4 - z5) / c4m1 # first derivative on [c4, 1]
@.. tmp = (z3 - z4) / c3mc4 # first derivative on [c3, c4]
@.. cache.cont2 = (tmp - cache.cont1) / c3m1 # second derivative on [c3, 1]
@.. tmp2 = (z2 - z3) / c2mc3 # first derivative on [c2, c3]
@.. tmp3 = (tmp2 - tmp) / c2mc4 # second derivative on [c2, c4]
@.. cache.cont3 = (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
@.. tmp4 = (z1 - z2) / c1mc2 # first derivative on [c1, c2]
@.. tmp5 = (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
@.. tmp6 = (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
@.. cache.cont4 = (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
@.. tmp7 = z1 / c1 #first derivative on [0, c1]
@.. tmp8 = (tmp4 - tmp7) / c2 #second derivative on [0, c2]
@.. tmp9 = (tmp5 - tmp8) / c3 #third derivative on [0, c3]
@.. tmp10 = (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
@.. cache.cont5 = cache.cont4 - tmp10 #fifth derivative on [0,1]
end
end

Expand Down Expand Up @@ -1437,7 +1437,7 @@ end
for i in 1 : num_stages
z[i] = f(uprev + z[i], p, t + c[i] * dt)
end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages)

#fw = TI * ff
fw = Vector{typeof(u)}(undef, num_stages)
Expand Down Expand Up @@ -1619,7 +1619,7 @@ end
if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
@inbounds for II in CartesianIndices(J)
W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II]
for i in 1 :(num_stages - 1) ÷ 2
for i in 1 : (num_stages - 1) ÷ 2
W2[i][II] = -(αdt[i] + βdt[i] * im) * mass_matrix[Tuple(II)...] + J[II]
end
end
Expand Down Expand Up @@ -1673,7 +1673,7 @@ end
@.. tmp = uprev + z[i]
f(ks[i], tmp, p, t + c[i] * dt)
end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages)

#mul!(fw, TI, ks)
for i in 1:num_stages
Expand All @@ -1700,15 +1700,12 @@ end
@.. ubuff = fw[1] - γdt * Mw[1]
needfactor = iter == 1 && new_W

linsolve1 = cache.linsolve1
if needfactor
linres = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1))
cache.linsolve1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1)).cache
else
linres = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1))
cache.linsolve1 = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1)).cache
end

cache.linsolve1 = linres.cache

for i in 1 :(num_stages - 1) ÷ 2
@.. cubuff[i]=complex(
fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1], fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1])
Expand Down Expand Up @@ -1801,9 +1798,8 @@ end
@.. broadcast=false ubuff=integrator.fsalfirst + tmp

if alg.smooth_est
linres = dolinsolve(integrator, linres.cache; b = _vec(ubuff),
linu = _vec(utilde))
cache.linsolve1 = linres.cache
cache.linsolve1 = dolinsolve(integrator, linsolve1; b = _vec(ubuff),
linu = _vec(utilde)).cache
integrator.stats.nsolve += 1
end

Expand All @@ -1821,9 +1817,8 @@ end
@.. broadcast=false ubuff=fsallast + tmp

if alg.smooth_est
linres = dolinsolve(integrator, linres.cache; b = _vec(ubuff),
linu = _vec(utilde))
cache.linsolve1 = linres.cache
cache.linsolve1 = dolinsolve(integrator, linsolve1; b = _vec(ubuff),
linu = _vec(utilde)).cache
integrator.stats.nsolve += 1
end

Expand Down

0 comments on commit 937641f

Please sign in to comment.