From 937641f9cbf10234152a12300f82c1338ecbe1c1 Mon Sep 17 00:00:00 2001 From: Shreyas Ekanathan Date: Mon, 18 Nov 2024 21:31:51 -0500 Subject: [PATCH] speed ups --- lib/OrdinaryDiffEqFIRK/src/algorithms.jl | 4 +- lib/OrdinaryDiffEqFIRK/src/controllers.jl | 9 ++- lib/OrdinaryDiffEqFIRK/src/firk_caches.jl | 39 ++++++++-- .../src/firk_perform_step.jl | 75 +++++++++---------- 4 files changed, 75 insertions(+), 52 deletions(-) diff --git a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl index ba45841ac6..e98389ec7c 100644 --- a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl @@ -163,8 +163,8 @@ struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <: new_W_γdt_cutoff::C2 controller::Symbol step_limiter!::StepLimiter - min_stages::Int - max_stages::Int + min_order::Int + max_order::Int end function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(), diff --git a/lib/OrdinaryDiffEqFIRK/src/controllers.jl b/lib/OrdinaryDiffEqFIRK/src/controllers.jl index 95816c8b32..3849d8114c 100644 --- a/lib/OrdinaryDiffEqFIRK/src/controllers.jl +++ b/lib/OrdinaryDiffEqFIRK/src/controllers.jl @@ -22,12 +22,14 @@ function step_accept_controller!(integrator, controller::PredictiveController, a cache.step = step + 1 hist_iter = hist_iter * 0.8 + iter * 0.2 cache.hist_iter = hist_iter + max_stages = (alg.max_order - 1) ÷ 4 * 2 + 1 + min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1 if (step > 10) - if (hist_iter < 2.6 && num_stages < (alg.max_order + 1) ÷ 2) + if (hist_iter < 2.6 && num_stages <= max_stages) cache.num_stages += 2 cache.step = 1 cache.hist_iter = iter - elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > (alg.min_order + 1) ÷ 2) + elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages) cache.num_stages -= 2 cache.step = 1 cache.hist_iter = iter @@ -44,8 +46,9 @@ function step_reject_controller!(integrator, controller::PredictiveController, a cache.step = step + 1 hist_iter = hist_iter * 0.8 + iter * 0.2 cache.hist_iter = hist_iter + min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1 if (step > 10) - if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > (alg.min_order + 1) ÷ 2) + if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages) cache.num_stages -= 2 cache.step = 1 cache.hist_iter = iter diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl b/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl index e1dbe08ad8..b971b3cc4f 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl @@ -362,6 +362,10 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty tmp4::uType tmp5::uType tmp6::uType + tmp7::uType + tmp8::uType + tmp9::uType + tmp10::uType atmp::uNoUnitsType jac_config::JC linsolve1::F1 @@ -440,6 +444,10 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits}, tmp4 = zero(u) tmp5 = zero(u) tmp6 = zero(u) + tmp7 = zero(u) + tmp8 = zero(u) + tmp9 = zero(u) + tmp10 = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1) @@ -469,7 +477,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits}, du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5, J, W1, W2, W3, uf, tab, κ, one(uToltype), 10000, - tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, + tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt, Convergence, alg.step_limiter!) end @@ -497,17 +505,26 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits} ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} uf = UDerivativeWrapper(f, t, p) uToltype = constvalue(uBottomEltypeNoUnits) - max = (alg.max_order + 1) ÷ 2 - num_stages = (alg.min_order + 1) ÷ 2 + + max_order = alg.max_order + min_order = alg.min_order + max = (max_order - 1) ÷ 4 * 2 + 1 + min = (min_order - 1) ÷ 4 * 2 + 1 + if (alg.min_order < 5) + error("min_order choice $min_order below 5 is not compatible with the algorithm") + elseif (max < min) + error("max_order $max_order is below min_order $min_order") + end + num_stages = min + tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))] - i = 9 while i <= max push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i)) i += 2 end cont = Vector{typeof(u)}(undef, max) - for i in 1: max + for i in 1:max cont[i] = zero(u) end @@ -570,8 +587,16 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits} uf = UJacobianWrapper(f, t, p) uToltype = constvalue(uBottomEltypeNoUnits) - max = (alg.max_order + 1) ÷ 2 - num_stages = (alg.min_order + 1) ÷ 2 + max_order = alg.max_order + min_order = alg.min_order + max = (max_order - 1) ÷ 4 * 2 + 1 + min = (min_order - 1) ÷ 4 * 2 + 1 + if (alg.min_order < 5) + error("min_order choice $min_order below 5 is not compatible with the algorithm") + elseif (max < min) + error("max_order $max_order is below min_order $min_order") + end + num_stages = min tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))] i = 9 diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl index 874d968cb6..2058c4fb15 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl @@ -1032,7 +1032,7 @@ end @unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache @unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache @unpack J, W1, W2, W3 = cache - @unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache + @unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache @unpack internalnorm, abstol, reltol, adaptive = integrator.opts alg = unwrap_alg(integrator, true) @unpack maxiters = alg @@ -1087,30 +1087,30 @@ end c2′ = c2 * c5′ c3′ = c3 * c5′ c4′ = c4 * c5′ - z1 = @.. c1′ * (cont1 + + @.. z1 = c1′ * (cont1 + (c1′-c4m1) * (cont2 + (c1′ - c3m1) * (cont3 + (c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5)))) - z2 = @.. c2′ * (cont1 + + @.. z2 = c2′ * (cont1 + (c2′-c4m1) * (cont2 + (c2′ - c3m1) * (cont3 + (c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5)))) - z3 = @.. c3′ * (cont1 + + @.. z3 = c3′ * (cont1 + (c3′-c4m1) * (cont2 + (c3′ - c3m1) * (cont3 + (c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5)))) - z4 = @.. c4′ * (cont1 + + @.. z4 = c4′ * (cont1 + (c4′-c4m1) * (cont2 + (c4′ - c3m1) * (cont3 + (c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5)))) - z5 = @.. c5′ * (cont1 + + @.. z5 = c5′ * (cont1 + (c5′-c4m1) * (cont2 + (c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5)))) - w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5 - w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5 - w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5 - w4 = @.. broadcast=false TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5 - w5 = @.. broadcast=false TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5 + @.. w1 = TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5 + @.. w2 = TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5 + @.. w3 = TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5 + @.. w4 = TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5 + @.. w5 = TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5 end # Newton iteration @@ -1328,21 +1328,21 @@ end if integrator.EEst <= oneunit(integrator.EEst) cache.dtprev = dt if alg.extrapolant != :constant - cache.cont1 = @.. (z4 - z5) / c4m1 # first derivative on [c4, 1] - tmp1 = @.. (z3 - z4) / c3mc4 # first derivative on [c3, c4] - cache.cont2 = @.. (tmp1 - cache.cont1) / c3m1 # second derivative on [c3, 1] - tmp2 = @.. (z2 - z3) / c2mc3 # first derivative on [c2, c3] - tmp3 = @.. (tmp2 - tmp1) / c2mc4 # second derivative on [c2, c4] - cache.cont3 = @.. (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1] - tmp4 = @.. (z1 - z2) / c1mc2 # first derivative on [c1, c2] - tmp5 = @.. (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3] - tmp6 = @.. (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4] - cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1] - tmp7 = @.. z1 / c1 #first derivative on [0, c1] - tmp8 = @.. (tmp4 - tmp7) / c2 #second derivative on [0, c2] - tmp9 = @.. (tmp5 - tmp8) / c3 #third derivative on [0, c3] - tmp10 = @.. (tmp6 - tmp9) / c4 #fourth derivative on [0,c4] - cache.cont5 = @.. cache.cont4 - tmp10 #fifth derivative on [0,1] + @.. cache.cont1 = (z4 - z5) / c4m1 # first derivative on [c4, 1] + @.. tmp = (z3 - z4) / c3mc4 # first derivative on [c3, c4] + @.. cache.cont2 = (tmp - cache.cont1) / c3m1 # second derivative on [c3, 1] + @.. tmp2 = (z2 - z3) / c2mc3 # first derivative on [c2, c3] + @.. tmp3 = (tmp2 - tmp) / c2mc4 # second derivative on [c2, c4] + @.. cache.cont3 = (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1] + @.. tmp4 = (z1 - z2) / c1mc2 # first derivative on [c1, c2] + @.. tmp5 = (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3] + @.. tmp6 = (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4] + @.. cache.cont4 = (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1] + @.. tmp7 = z1 / c1 #first derivative on [0, c1] + @.. tmp8 = (tmp4 - tmp7) / c2 #second derivative on [0, c2] + @.. tmp9 = (tmp5 - tmp8) / c3 #third derivative on [0, c3] + @.. tmp10 = (tmp6 - tmp9) / c4 #fourth derivative on [0,c4] + @.. cache.cont5 = cache.cont4 - tmp10 #fifth derivative on [0,1] end end @@ -1437,7 +1437,7 @@ end for i in 1 : num_stages z[i] = f(uprev + z[i], p, t + c[i] * dt) end - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages) #fw = TI * ff fw = Vector{typeof(u)}(undef, num_stages) @@ -1619,7 +1619,7 @@ end if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt)) @inbounds for II in CartesianIndices(J) W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II] - for i in 1 :(num_stages - 1) ÷ 2 + for i in 1 : (num_stages - 1) ÷ 2 W2[i][II] = -(αdt[i] + βdt[i] * im) * mass_matrix[Tuple(II)...] + J[II] end end @@ -1673,7 +1673,7 @@ end @.. tmp = uprev + z[i] f(ks[i], tmp, p, t + c[i] * dt) end - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages) #mul!(fw, TI, ks) for i in 1:num_stages @@ -1700,15 +1700,12 @@ end @.. ubuff = fw[1] - γdt * Mw[1] needfactor = iter == 1 && new_W - linsolve1 = cache.linsolve1 if needfactor - linres = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1)) + cache.linsolve1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1)).cache else - linres = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1)) + cache.linsolve1 = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1)).cache end - cache.linsolve1 = linres.cache - for i in 1 :(num_stages - 1) ÷ 2 @.. cubuff[i]=complex( fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1], fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1]) @@ -1801,9 +1798,8 @@ end @.. broadcast=false ubuff=integrator.fsalfirst + tmp if alg.smooth_est - linres = dolinsolve(integrator, linres.cache; b = _vec(ubuff), - linu = _vec(utilde)) - cache.linsolve1 = linres.cache + cache.linsolve1 = dolinsolve(integrator, linsolve1; b = _vec(ubuff), + linu = _vec(utilde)).cache integrator.stats.nsolve += 1 end @@ -1821,9 +1817,8 @@ end @.. broadcast=false ubuff=fsallast + tmp if alg.smooth_est - linres = dolinsolve(integrator, linres.cache; b = _vec(ubuff), - linu = _vec(utilde)) - cache.linsolve1 = linres.cache + cache.linsolve1 = dolinsolve(integrator, linsolve1; b = _vec(ubuff), + linu = _vec(utilde)).cache integrator.stats.nsolve += 1 end