From a8c0a52f8ea0a8413d13c42b23853be1b7719c45 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 20 Sep 2024 18:21:43 -0400 Subject: [PATCH] fixes --- .../src/interp_func.jl | 4 +- .../src/rosenbrock_caches.jl | 106 +++++++++--------- .../src/rosenbrock_perform_step.jl | 32 ++++-- .../src/rosenbrock_tableaus.jl | 87 ++++++-------- 4 files changed, 110 insertions(+), 119 deletions(-) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl index 43f66149fc..71b9435af6 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl @@ -9,11 +9,11 @@ function DiffEqBase.interp_summary(::Type{cacheType}, "1st order linear" end -function DiffEqBase.interp_summary(::Type{cacheType}, +function DiffEqBase.interp_summary(cache::Type{cacheType}, dense::Bool) where { cacheType <: Union{RosenbrockCombinedConstantCache, RosenbrockCache}} - dense ? "specialized $(cache.interp_order) order \"free\" stiffness-aware interpolation" : + dense ? "specialized ? order \"free\" stiffness-aware interpolation" : "1st order linear" end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 3079601ed9..4349415e2c 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -8,6 +8,54 @@ function get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u) (cache.fsalfirst, cache.fsallast) end +mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType, + TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache + u::uType + uprev::uType + dense::Vector{rateType} + du::rateType + du1::rateType + du2::rateType + ks::Vector{rateType} + fsalfirst::rateType + fsallast::rateType + dT::rateType + J::JType + W::WType + tmp::rateType + atmp::uNoUnitsType + weight::uNoUnitsType + tab::TabType + tf::TFType + uf::UFType + linsolve_tmp::rateType + linsolve::F + jac_config::JCType + grad_config::GCType + reltol::RTolType + alg::A + algebraic_vars::AV + step_limiter!::StepLimiter + stage_limiter!::StageLimiter + interp_order::Int +end + +function full_cache(c::RosenbrockCache) + return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2, + c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp] +end + +struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache + tf::TF + uf::UF + tab::Tab + J::JType + W::WType + linsolve::F + autodiff::AD + interp_order::Int +end + @cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType, TabType, TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache @@ -74,6 +122,10 @@ end stage_limiter!::StageLimiter end +function get_fsalfirstlast(cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, u) + (cache.fsalfirst, cache.fsallast) +end + function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, @@ -222,57 +274,6 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, alg_autodiff(alg)) end -################################################################################ - -# Shampine's Low-order Rosenbrocks -mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType, - TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache - u::uType - uprev::uType - dense::Vector{rateType} - du::rateType - du1::rateType - du2::rateType - ks::Vector{rateType} - fsalfirst::rateType - fsallast::rateType - dT::rateType - J::JType - W::WType - tmp::rateType - atmp::uNoUnitsType - weight::uNoUnitsType - tab::TabType - tf::TFType - uf::UFType - linsolve_tmp::rateType - linsolve::F - jac_config::JCType - grad_config::GCType - reltol::RTolType - alg::A - algebraic_vars::AV - step_limiter!::StepLimiter - stage_limiter!::StageLimiter - interp_order::Int -end - -function full_cache(c::RosenbrockCache) - return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2, - c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp] -end - -struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache - tf::TF - uf::UF - tab::Tab - J::JType - W::WType - linsolve::F - autodiff::AD - interp_order::Int -end - @ROS2(:cache) ################################################################################ @@ -296,9 +297,6 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W) ############################################################################### -### Rodas methods -tabtype(::Rosenbrock23) = Rosenbrock23Tableau -tabtype(::Rosenbrock32) = Rosenbrock32Tableau tabtype(::Rodas23W) = Rodas23WTableau tabtype(::ROS3P) = ROS3PTableau tabtype(::Rodas3) = Rodas3Tableau diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 438be408e2..9b2d292e10 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -434,7 +434,7 @@ end @muladd function perform_step!(integrator, cache::RosenbrockCombinedConstantCache, repeat_step = false) (;t, dt, uprev, u, f, p) = integrator (;tf, uf) = cache - (;A, C, gamma, c, d, H) = cache.tab + (;A, C, b, btilde, gamma, c, d, H) = cache.tab # Precalculations dtC = C ./ dt @@ -489,10 +489,17 @@ end integrator.stats.nsolve += 1 end #@show ks - u = u .+ ks[num_stages] + u = uprev + for i in 1:num_stages + u = @.. u + b[i] * ks[i] + end if integrator.opts.adaptive - atmp = calculate_residuals(ks[num_stages], uprev, u, integrator.opts.abstol, + utilde = uprev + for i in 1:num_stages + utilde = @.. utilde + btilde[i] * ks[i] + end + atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) integrator.EEst = integrator.opts.internalnorm(atmp, t) end @@ -538,7 +545,7 @@ end @muladd function perform_step!(integrator, cache::RosenbrockCache, repeat_step = false) (; t, dt, uprev, u, f, p) = integrator (; du, du1, du2, dT, J, W, uf, tf, ks, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter!) = cache - (; A, C, gamma, c, d, H) = cache.tab + (; A, C, b, btilde, gamma, c, d, H) = cache.tab # Assignments sizeu = size(u) @@ -549,6 +556,7 @@ end dtC = C .* inv(dt) dtd = dt .* d dtgamma = dt * gamma + utilde = du f(cache.fsalfirst, uprev, p, t) OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) @@ -572,8 +580,8 @@ end @.. $(_vec(ks[1])) = -linres.u integrator.stats.nsolve += 1 - - for stage in 2:length(ks) + num_stages = length(ks) + for stage in 2:num_stages u .= uprev for i in 1:(stage - 1) @.. u += A[stage, i] * ks[i] @@ -601,19 +609,25 @@ end @.. $(_vec(ks[stage])) = -linres.u integrator.stats.nsolve += 1 end - du .= ks[end] - u .+= ks[end] + u .= uprev + for i in 1:num_stages + @.. u += b[i] * ks[i] + end step_limiter!(u, integrator, p, t + dt) if integrator.opts.adaptive + utilde .= 0 + for i in 1:num_stages + @.. utilde += btilde[i] * ks[i] + end if (integrator.alg isa Rodas5Pe) @.. du = 0.2606326497975715 * ks[1] - 0.005158627295444251 * ks[2] + 1.3038988631109731 * ks[3] + 1.235000722062074 * ks[4] + -0.7931985603795049 * ks[5] - 1.005448461135913 * ks[6] - 0.18044626132120234 * ks[7] + 0.17051519239113755 * ks[8] end - calculate_residuals!(atmp, ks[end], uprev, u, integrator.opts.abstol, + calculate_residuals!(atmp, utilde, uprev, u, integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) integrator.EEst = integrator.opts.internalnorm(atmp, t) end diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl index 7ae9b138e8..08710eaed1 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl @@ -1,6 +1,8 @@ struct RodasTableau{T, T2} A::Matrix{T} C::Matrix{T} + b::Vector{T} + btilde::Vector{T} gamma::T c::Vector{T2} d::Vector{T} @@ -40,29 +42,17 @@ function ROS3PTableau(T, T2) ] tmp = -igamma * (convert(T, 2) - convert(T, 1 / 2) * igamma) C = T[ - 0 0 0 + 0 0 0 -igamma^2 0 0 -igamma * (1 - tmp) tmp 0 ] tmp = igamma * (convert(T, 2 // 3) - convert(T, 1 // 6) * igamma) - b = [(igamma * (convert(T, 1) + tmp)), (tmp), (convert(T, 1 // 3) * igamma)] - # btilde1 = convert(T,2.113248654051871) - # btilde2 = convert(T,1.000000000000000) - # btilde3 = convert(T,0.4226497308103742) + b = T[igamma * (1 + tmp), tmp, igamma / 3] btilde = T[2.113248654051871, 1, 0.4226497308103742] - c = T2[1, 1] + c = T2[0, 1, 1] d = T[0.7886751345948129, -0.2113248654051871, -1.077350269189626] - RodasTableau(A, C, b, btilde, gamma, c, d) -end - -struct Rodas3Tableau{T, T2} - A::Matrix{T} - C::Matrix{T} - b::Vector{T} - btilde::Vector{T} - gamma::T2 - c::Vector{T2} - d::Vector{T} + H = zeros(T, 3, 3) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas3Tableau(T, T2) @@ -80,10 +70,11 @@ function Rodas3Tableau(T, T2) 1 -1 -8 // 3 ] b = T[2, 0, 1, 1] - btilde = T[0.0, 0.0, 0.0, 1.0] - c = T[0.0, 1.0, 1.0] + btilde = T[0, 0, 0, 1] + c = T[0, 1, 1] d = T[1 // 2, 3 // 2, 0, 0] - RodasTableau(A, C, b, btilde, gamma, c, d) + H = zeros(T, 3, 3) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas3PTableau(T, T2) @@ -91,7 +82,7 @@ function Rodas3PTableau(T, T2) A = T[ 0 0 0 0 4 // 3 0 0 0 - 0 0 0 0 + 4 // 3 0 0 0 2.90625 3.375 0.40625 0 ] C = T[ @@ -101,15 +92,16 @@ function Rodas3PTableau(T, T2) 1.21875 5.0625 1.96875 0 4.03125 15.1875 4.03125 6.0 ] - c = T2[4 // 9, 0] + b = A[end, :] + btilde = T[0, 0, 0, 1] + c = T2[0, 4 // 9, 1] d = T[1 // 3, 1 // 9, 1] H = T[ - 0 0 0 0 0 1.78125 6.75 0.15625 6 1 4.21875 15.1875 3.09375 9 0 ] h2_2 = T[4.21875, 2.025, 1.63125, 1.7, 0.1] - RodasTableau(A, C, gamma, c, d, H, h2_2) + RodasTableau(A, C, b, btilde, gamma, c, d, H)#, h2_2) end @ROS2(:tableau) @@ -141,11 +133,13 @@ function Rodas4Tableau(T, T2) 7.496443313967647 -10.24680431464352 -33.99990352819905 11.70890893206160 0 8.083246795921522 -7.981132988064893 -31.52159432874371 16.31930543123136 -6.058818238834054 ] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 1] c = T2[0, 0.386, 0.21, 0.63, 1, 1] d = T[0.25, -0.1043, 0.1035, -0.0362, 0, 0] H = T[10.12623508344586 -7.487995877610167 -34.80091861555747 -7.992771707568823 1.025137723295662 0 -0.6762803392801253 6.087714651680015 16.43084320892478 24.76722511418386 -6.594389125716872 0] - RodasTableau(A, C, gamma, c, d, H) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas42Tableau(T, T2) @@ -162,11 +156,13 @@ function Rodas42Tableau(T, T2) -32.64449927841361 -99.35311008728094 49.99119122405989 0 0 -76.46023087151691 -278.5942120829058 153.9294840910643 10.97101866258358 0 -76.29701586804983 -294.2795630511232 162.0029695867566 23.65166903095270 -7.652977706771382] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 1] c = T2[0, 0.3507221, 0.2557041, 0.681779, 1, 1] d = T[0.25, -0.0690221, -0.0009672, -0.087979, 0, 0] H = T[-38.71940424117216 -135.8025833007622 64.51068857505875 -4.192663174613162 -2.531932050335060 0 -14.99268484949843 -76.30242396627033 58.65928432851416 16.61359034616402 -0.6758691794084156 0] - RodasTableau(A, C, gamma, c, d, H) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas4PTableau(T, T2) @@ -186,11 +182,13 @@ function Rodas4PTableau(T, T2) 10.81793056857153 6.780270611428266 19.53485944642410 0 0 34.19095006749676 15.49671153725963 54.74760875964130 14.16005392148534 0 34.62605830930532 15.30084976114473 56.99955578662667 18.40807009793095 -5.714285714285717] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 1] c = T2[0, 0.75, 0.21, 0.63, 1, 1] d = T[0.25, -0.5, -0.023504, -0.0362, 0, 0] H = T[25.09876703708589 11.62013104361867 28.49148307714626 -5.664021568594133 0 0 1.638054557396973 -0.7373619806678748 8.477918219238990 15.99253148779520 -1.882352941176471 0] - RodasTableau(A, C, gamma, c, d, H) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas4P2Tableau(T, T2) @@ -207,6 +205,8 @@ function Rodas4P2Tableau(T, T2) -8.575016317114033 -7.606483992117508 12.224997650124820 0 0 -5.888975457523102 -8.157396617841821 24.805546872612922 12.790401512796979 0 -4.408651676063871 -6.692003137674639 24.625568527593117 16.627521966636085 -5.714285714285718] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 1] c = T2[0, 0.75, 0.321448134013046, 0.519745732277726, 1, 1] d = T[0.25, -0.5, -0.189532918363016, 0.085612108792769, 0, 0] H = [-5.323528268423303 -10.042123754867493 17.175254928256965 -5.079931171878093 -0.016185991706112 0 @@ -236,6 +236,8 @@ function Rodas5Tableau(T, T2) 34.20013733472935 -14.15535402717690 57.82335640988400 25.83362985412365 1.408950972071624 -6.551835421242162 0 42.57076742291101 -13.80770672017997 93.98938432427124 18.77919633714503 -31.58359187223370 -6.685968952921985 -5.810979938412932 ] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 0, 0, 1] c = T2[0, 0.38, 0.3878509998321533, 0.4839718937873840, 0.4570477008819580, 1, 1, 1] d = T[gamma, -0.1823079225333714636, -0.319231832186874912, 0.3449828624725343, -0.377417564392089818, 0, 0, 0] @@ -244,32 +246,7 @@ function Rodas5Tableau(T, T2) 44.19024239501722 1.3677947663381929e-13 202.93261852171622 -35.5669339789154 -181.91095152160645 3.4116351403665033 2.5793540257308067 2.2435122582734066 -44.0988150021747 -5.755396159656812e-13 -181.26175034586677 56.99302194811676 183.21182741427398 -7.480257918273637 -5.792426076169686 -5.32503859794143 ] - # println("---Rodas5---") - - #= - a71 = -14.09640773051259 - a72 = 6.925207756232704 - a73 = -41.47510893210728 - a74 = 2.343771018586405 - a75 = 24.13215229196062 - a76 = convert(T,1) - a81 = -14.09640773051259 - a82 = 6.925207756232704 - a83 = -41.47510893210728 - a84 = 2.343771018586405 - a85 = 24.13215229196062 - a86 = convert(T,1) - a87 = convert(T,1) - b1 = -14.09640773051259 - b2 = 6.925207756232704 - b3 = -41.47510893210728 - b4 = 2.343771018586405 - b5 = 24.13215229196062 - b6 = convert(T,1) - b7 = convert(T,1) - b8 = convert(T,1) - =# - RodasTableau(A, C, gamma, d, c, H) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end function Rodas5PTableau(T, T2) @@ -294,6 +271,8 @@ function Rodas5PTableau(T, T2) 30.91273214028599 -3.1208243349937974 77.79954646070892 34.28646028294783 -19.097331116725623 -28.087943162872662 0 37.80277123390563 -3.2571969029072276 112.26918849496327 66.9347231244047 -40.06618937091002 -54.66780262877968 -9.48861652309627 ] + b = A[end, :] + btilde = T[0, 0, 0, 0, 0, 0, 0, 1] c = T2[0, 0.6358126895828704, 0.4095798393397535, 0.9769306725060716, 0.4288403609558664, 1, 1, 1] d = T[0.21193756319429014, -0.42387512638858027, -0.3384627126235924, 1.8046452872882734, 2.325825639765069, 0, 0, 0] H = T[ @@ -301,7 +280,7 @@ function Rodas5PTableau(T, T2) -9.91568850695171 -0.9689944594115154 3.0438037242978453 -24.495224566215796 20.176138334709044 15.98066361424651 -6.789040303419874 -6.710236069923372 11.419903575922262 2.8879645146136994 72.92137995996029 80.12511834622643 -52.072871366152654 -59.78993625266729 -0.15582684282751913 4.883087185713722 ] - RodasTableau(A, C, gamma, d, c, H) + RodasTableau(A, C, b, btilde, gamma, c, d, H) end @RosenbrockW6S4OS(:tableau)