Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Sep 27, 2024
1 parent 9d2a4c8 commit a8c0a52
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 119 deletions.
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ function DiffEqBase.interp_summary(::Type{cacheType},
"1st order linear"
end

function DiffEqBase.interp_summary(::Type{cacheType},
function DiffEqBase.interp_summary(cache::Type{cacheType},
dense::Bool) where {
cacheType <:
Union{RosenbrockCombinedConstantCache,
RosenbrockCache}}
dense ? "specialized $(cache.interp_order) order \"free\" stiffness-aware interpolation" :
dense ? "specialized ? order \"free\" stiffness-aware interpolation" :
"1st order linear"
end
106 changes: 52 additions & 54 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,54 @@ function get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u)
(cache.fsalfirst, cache.fsallast)
end

mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
u::uType
uprev::uType
dense::Vector{rateType}
du::rateType
du1::rateType
du2::rateType
ks::Vector{rateType}
fsalfirst::rateType
fsallast::rateType
dT::rateType
J::JType
W::WType
tmp::rateType
atmp::uNoUnitsType
weight::uNoUnitsType
tab::TabType
tf::TFType
uf::UFType
linsolve_tmp::rateType
linsolve::F
jac_config::JCType
grad_config::GCType
reltol::RTolType
alg::A
algebraic_vars::AV
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
interp_order::Int
end

function full_cache(c::RosenbrockCache)
return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp]
end

struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
autodiff::AD
interp_order::Int
end

@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
TabType, TFType, UFType, F, JCType, GCType,
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
Expand Down Expand Up @@ -74,6 +122,10 @@ end
stage_limiter!::StageLimiter
end

function get_fsalfirstlast(cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, u)
(cache.fsalfirst, cache.fsallast)
end

function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
Expand Down Expand Up @@ -222,57 +274,6 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
alg_autodiff(alg))
end

################################################################################

# Shampine's Low-order Rosenbrocks
mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
u::uType
uprev::uType
dense::Vector{rateType}
du::rateType
du1::rateType
du2::rateType
ks::Vector{rateType}
fsalfirst::rateType
fsallast::rateType
dT::rateType
J::JType
W::WType
tmp::rateType
atmp::uNoUnitsType
weight::uNoUnitsType
tab::TabType
tf::TFType
uf::UFType
linsolve_tmp::rateType
linsolve::F
jac_config::JCType
grad_config::GCType
reltol::RTolType
alg::A
algebraic_vars::AV
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
interp_order::Int
end

function full_cache(c::RosenbrockCache)
return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp]
end

struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
tf::TF
uf::UF
tab::Tab
J::JType
W::WType
linsolve::F
autodiff::AD
interp_order::Int
end

@ROS2(:cache)

################################################################################
Expand All @@ -296,9 +297,6 @@ jac_cache(c::Rosenbrock4Cache) = (c.J, c.W)

###############################################################################

### Rodas methods
tabtype(::Rosenbrock23) = Rosenbrock23Tableau
tabtype(::Rosenbrock32) = Rosenbrock32Tableau
tabtype(::Rodas23W) = Rodas23WTableau
tabtype(::ROS3P) = ROS3PTableau
tabtype(::Rodas3) = Rodas3Tableau
Expand Down
32 changes: 23 additions & 9 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ end
@muladd function perform_step!(integrator, cache::RosenbrockCombinedConstantCache, repeat_step = false)
(;t, dt, uprev, u, f, p) = integrator
(;tf, uf) = cache
(;A, C, gamma, c, d, H) = cache.tab
(;A, C, b, btilde, gamma, c, d, H) = cache.tab

# Precalculations
dtC = C ./ dt
Expand Down Expand Up @@ -489,10 +489,17 @@ end
integrator.stats.nsolve += 1
end
#@show ks
u = u .+ ks[num_stages]
u = uprev
for i in 1:num_stages
u = @.. u + b[i] * ks[i]
end

if integrator.opts.adaptive
atmp = calculate_residuals(ks[num_stages], uprev, u, integrator.opts.abstol,
utilde = uprev
for i in 1:num_stages
utilde = @.. utilde + btilde[i] * ks[i]
end
atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
Expand Down Expand Up @@ -538,7 +545,7 @@ end
@muladd function perform_step!(integrator, cache::RosenbrockCache, repeat_step = false)
(; t, dt, uprev, u, f, p) = integrator
(; du, du1, du2, dT, J, W, uf, tf, ks, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter!) = cache
(; A, C, gamma, c, d, H) = cache.tab
(; A, C, b, btilde, gamma, c, d, H) = cache.tab

# Assignments
sizeu = size(u)
Expand All @@ -549,6 +556,7 @@ end
dtC = C .* inv(dt)
dtd = dt .* d
dtgamma = dt * gamma
utilde = du

f(cache.fsalfirst, uprev, p, t)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
Expand All @@ -572,8 +580,8 @@ end

@.. $(_vec(ks[1])) = -linres.u
integrator.stats.nsolve += 1

for stage in 2:length(ks)
num_stages = length(ks)
for stage in 2:num_stages
u .= uprev
for i in 1:(stage - 1)
@.. u += A[stage, i] * ks[i]
Expand Down Expand Up @@ -601,19 +609,25 @@ end
@.. $(_vec(ks[stage])) = -linres.u
integrator.stats.nsolve += 1
end
du .= ks[end]
u .+= ks[end]
u .= uprev
for i in 1:num_stages
@.. u += b[i] * ks[i]
end

step_limiter!(u, integrator, p, t + dt)

if integrator.opts.adaptive
utilde .= 0
for i in 1:num_stages
@.. utilde += btilde[i] * ks[i]
end
if (integrator.alg isa Rodas5Pe)
@.. du = 0.2606326497975715 * ks[1] - 0.005158627295444251 * ks[2] +
1.3038988631109731 * ks[3] + 1.235000722062074 * ks[4] +
-0.7931985603795049 * ks[5] - 1.005448461135913 * ks[6] -
0.18044626132120234 * ks[7] + 0.17051519239113755 * ks[8]
end
calculate_residuals!(atmp, ks[end], uprev, u, integrator.opts.abstol,
calculate_residuals!(atmp, utilde, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
Expand Down
Loading

0 comments on commit a8c0a52

Please sign in to comment.