Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Sep 18, 2024
1 parent 0054db2 commit 9af8cdd
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 228 deletions.
137 changes: 15 additions & 122 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
alg::A
step_limiter!::StepLimiter
stage_limiter!::StageLimiter
order::Int
interp_order::Int
end
function full_cache(c::RosenbrockCache)
return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
Expand All @@ -56,7 +56,7 @@ struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: Rose
W::WType
linsolve::F
autodiff::AD
order::Int
interp_order::Int
end

@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
Expand Down Expand Up @@ -718,8 +718,12 @@ tabtype(::Rodas4) = Rodas4Tableau
tabtype(::Rodas42) = Rodas42Tableau
tabtype(::Rodas4P) = Rodas4PTableau
tabtype(::Rodas4P2) = Rodas4P2Tableau
tabtype(::Rodas5) = Rodas5Tableau
tabtype(::Rodas5P) = Rodas5PTableau
tabtype(::Rodas5Pr) = Rodas5PTableau
tabtype(::Rodas5Pe) = Rodas5PTableau

function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
Expand All @@ -729,21 +733,22 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
RosenbrockCombinedConstantCache(tf, uf,
tabtype(alg)(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve,
alg_autodiff(alg), 4)
tab, J, W, linsolve,
alg_autodiff(alg), size(tab.H, 1))
end

function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}

tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
# Initialize vectors
dense = [zero(rate_prototype) for _ in 1:2]
ks = [zero(rate_prototype) for _ in 1:6]
dense = [zero(rate_prototype) for _ in 1:size(tab.H, 1)]
ks = [zero(rate_prototype) for _ in 1:size(tab.A, 1)]
du = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
Expand All @@ -762,7 +767,6 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
recursivefill!(atmp, false)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)
tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))

tf = TimeGradientWrapper(f, uprev, p)
uf = UJacobianWrapper(f, t, p)
Expand All @@ -785,120 +789,9 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg,
alg.step_limiter!, alg.stage_limiter!, 4)
end

################################################################################

### Rosenbrock5

function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
dense = [zero(rate_prototype) for _ in 1:3]
du = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
ks = [zero(rate_prototype) for _ in 1:7]
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)
tab = Rodas5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))

tf = TimeGradientWrapper(f, uprev, p)
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
RosenbrockCache(u, uprev, dense, du, du1, du2, ks,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!, 5)
end

function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tf = TimeDerivativeWrapper(f, u, p)
uf = UDerivativeWrapper(f, t, p)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
RosenbrockCombinedConstantCache(tf, uf,
Rodas5Tableau(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve, alg_autodiff(alg), 5)
alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1))
end

function alg_cache(
alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
dense = [zero(rate_prototype) for _ in 1:3]
du = zero(rate_prototype)
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)
ks = [zero(rate_prototype) for _ in 1:8]
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
dT = zero(rate_prototype)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
tmp = zero(rate_prototype)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)
tab = Rodas5PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))

tf = TimeGradientWrapper(f, uprev, p)
uf = UJacobianWrapper(f, t, p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
Pl, Pr = wrapprecs(
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
nothing)..., weight, tmp)
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
RosenbrockCache(u, uprev, dense, du, du1, du2, ks,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
alg.stage_limiter!, 5)
end

function alg_cache(
alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tf = TimeDerivativeWrapper(f, u, p)
uf = UDerivativeWrapper(f, t, p)
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
RosenbrockCombinedConstantCache(tf, uf,
Rodas5PTableau(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits)), J, W, linsolve, alg_autodiff(alg), 5)
end

function get_fsalfirstlast(
cache::Union{Rosenbrock23Cache, Rosenbrock32Cache, Rosenbrock33Cache,
Expand Down
26 changes: 13 additions & 13 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ From MATLAB ODE Suite by Shampine
Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache, RosenbrockCache, Rodas23WCache, Rodas3PCache},
idxs::Nothing, T::Type{Val{0}}, differential_vars)
Θ1 = 1 - Θ
if !isdefined(cache, :order) || cache.order == 4
if !hasproperty(cache, :interp_order) || cache.interp_order == 2
@.. Θ1 * y₀+Θ * (y₁ + Θ1 * (k[1] + Θ * k[2]))
else
@.. Θ1 * y₀+Θ * (y₁ + Θ1 * (k[1] + Θ * (k[2] + Θ * k[3])))
Expand All @@ -143,7 +143,7 @@ end
Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs, T::Type{Val{0}}, differential_vars)
Θ1 = 1 - Θ
if !isdefined(cache, :order) || cache.order == 4
if !hasproperty(cache, :interp_order) || cache.interp_order == 2
@views @.. Θ1 * y₀[idxs]+Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * k[2][idxs]))
else
@views @.. Θ1 * y₀[idxs]+Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * (k[2][idxs] + Θ * k[3][idxs])))
Expand All @@ -155,7 +155,7 @@ end
Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs::Nothing, T::Type{Val{0}}, differential_vars)
Θ1 = 1 - Θ
if !isdefined(cache, :order) || cache.order == 4
if !hasproperty(cache, :interp_order) || cache.interp_order == 2
@.. out=Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * k[2]))
else
@.. out=Θ1 * y₀ + Θ * (y₁ + Θ1 * (k[1] + Θ * (k[2] + Θ * k[3])))
Expand All @@ -168,7 +168,7 @@ end
Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs, T::Type{Val{0}}, differential_vars)
Θ1 = 1 - Θ
if !isdefined(cache, :order) || cache.order == 4
if !hasproperty(cache, :interp_order) || cache.interp_order == 2
@views @.. out=Θ1 * y₀[idxs] + Θ * (y₁[idxs] + Θ1 * (k[1][idxs] + Θ * k[2][idxs]))
else
@views @.. Θ1 * y₀[idxs]+Θ * (y₁[idxs] +
Expand All @@ -181,32 +181,32 @@ end
@muladd function _ode_interpolant(
Θ, dt, y₀, y₁, k, cache::Union{RosenbrockCache, Rodas23WCache, Rodas3PCache, RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache},
idxs::Nothing, T::Type{Val{1}}, differential_vars)
if !isdefined(cache, :order) || cache.order == 4
if !hasproperty(cache, :interp_order) || cache.interp_order == 2
@.. (k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁)/dt
else
@.. (k[1][idxs] + Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] +
Θ * (-3 * k[2][idxs] + 3 * k[3][idxs] - 4 * Θ * k[3][idxs])) -
y₀[idxs] + y₁[idxs])/dt
@.. (k[1] + Θ * (-2 * k[1] + 2 * k[2] +
Θ * (-3 * k[2] + 3 * k[3] - 4 * Θ * k[3])) -
y₀ + y₁)/dt
end
end
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache,
Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs, T::Type{Val{1}}, differential_vars)
if !isdefined(cache, :order) || cache.order == 4
if !hasproperty(cache, :interp_order) || cache.interp_order == 2
@views @.. (k[1][idxs] + Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] - 3 * k[2][idxs] * Θ) -
y₀[idxs] + y₁[idxs])/dt
else
@views @.. (k[1] + Θ * (-2 * k[1] + 2 * k[2] +
Θ * (-3 * k[2] + 3 * k[3] - 4 * Θ * k[3])) - y₀ + y₁)/dt
@views @.. (k[1][idxs] + Θ * (-2 * k[1][idxs] + 2 * k[2][idxs] +
Θ * (-3 * k[2][idxs] + 3 * k[3][idxs] - 4 * Θ * k[3][idxs])) - y₀[idxs] + y₁[idxs])/dt
end
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache,
Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs::Nothing, T::Type{Val{1}}, differential_vars)
if !isdefined(cache, :order) || cache.order == 4
if !hasproperty(cache, :interp_order) || cache.interp_order == 2
@.. out=(k[1] + Θ * (-2 * k[1] + 2 * k[2] - 3 * k[2] * Θ) - y₀ + y₁) / dt
else
@.. out=(k[1] + Θ * (-2 * k[1] + 2 * k[2] +
Expand All @@ -219,7 +219,7 @@ end
cache::Union{RosenbrockCombinedConstantCache, RosenbrockCache, Rodas23WConstantCache,
Rodas23WCache, Rodas3PConstantCache, Rodas3PCache},
idxs, T::Type{Val{1}}, differential_vars)
if !isdefined(cache, :order) || cache.order == 4
if !hasproperty(cache, :interp_order) || cache.interp_order == 2
@views @.. out=(k[1][idxs] +
Θ *
(-2 * k[1][idxs] + 2 * k[2][idxs] -
Expand Down
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1203,7 +1203,7 @@ end
#### Rodas4 type method

function initialize!(integrator, cache::RosenbrockCombinedConstantCache)
integrator.kshortsize = cache.order == 5 ? 3 : 2
integrator.kshortsize = size(cache.tab.H, 1)
integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
# Avoid undefined entries if k is an array of arrays
for i in 1:integrator.kshortsize
Expand Down Expand Up @@ -1307,7 +1307,7 @@ end
end

function initialize!(integrator, cache::RosenbrockCache)
integrator.kshortsize = cache.order == 5 ? 3 : 2
integrator.kshortsize = size(cache.tab.H, 1)
resize!(integrator.k, integrator.kshortsize)
for i in 1:integrator.kshortsize
integrator.k[i] = cache.dense[i]
Expand Down
32 changes: 16 additions & 16 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_tableaus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,13 @@ function Rodas5Tableau(T, T2)
-14.09640773051259 6.925207756232704 -41.47510893210728 2.343771018586405 24.13215229196062 1 1 0
]
C = T[
0 0 0 0 0 0 0
-10.31323885133993 0 0 0 0 0 0
-21.04823117650003 -7.234992135176716 0 0 0 0 0
32.22751541853323 -4.943732386540191 19.44922031041879 0 0 0 0
-20.69865579590063 -8.816374604402768 1.260436877740897 -0.7495647613787146 0 0 0
-46.22004352711257 -17.49534862857472 -289.6389582892057 93.60855400400906 318.3822534212147 0 0
34.20013733472935 -14.15535402717690 57.82335640988400 25.83362985412365 1.408950972071624 -6.551835421242162 0
0 0 0 0 0 0 0
-10.31323885133993 0 0 0 0 0 0
-21.04823117650003 -7.234992135176716 0 0 0 0 0
32.22751541853323 -4.943732386540191 19.44922031041879 0 0 0 0
-20.69865579590063 -8.816374604402768 1.260436877740897 -0.7495647613787146 0 0 0
-46.22004352711257 -17.49534862857472 -289.6389582892057 93.60855400400906 318.3822534212147 0 0
34.20013733472935 -14.15535402717690 57.82335640988400 25.83362985412365 1.408950972071624 -6.551835421242162 0
42.57076742291101 -13.80770672017997 93.98938432427124 18.77919633714503 -31.58359187223370 -6.685968952921985 -5.810979938412932
]
c = T2[0, 0.38, 0.3878509998321533, 0.4839718937873840, 0.4570477008819580, 1, 1, 1]
Expand Down Expand Up @@ -380,7 +380,7 @@ function Rodas5Tableau(T, T2)
b7 = convert(T,1)
b8 = convert(T,1)
=#
RodasTableau(A, C, gamma, d, c, H)
RodasTableau(A, C, gamma, c, d, H)
end

function Rodas5PTableau(T, T2)
Expand All @@ -396,13 +396,13 @@ function Rodas5PTableau(T, T2)
-7.502846399306121 2.561846144803919 -11.627539656261098 -0.18268767659942256 0.030198172008377946 1 1 0
]
C = T[
0 0 0 0 0 0 0
-14.155112264123755 0 0 0 0 0 0
-17.97296035885952 -2.859693295451294 0 0 0 0 0
147.12150275711716 -1.41221402718213 71.68940251302358 0 0 0 0
165.43517024871676 -0.4592823456491126 42.90938336958603 -5.961986721573306 0 0 0
24.854864614690072 -3.0009227002832186 47.4931110020768 5.5814197821558125 -0.6610691825249471 0 0
30.91273214028599 -3.1208243349937974 77.79954646070892 34.28646028294783 -19.097331116725623 -28.087943162872662 0
0 0 0 0 0 0 0
-14.155112264123755 0 0 0 0 0 0
-17.97296035885952 -2.859693295451294 0 0 0 0 0
147.12150275711716 -1.41221402718213 71.68940251302358 0 0 0 0
165.43517024871676 -0.4592823456491126 42.90938336958603 -5.961986721573306 0 0 0
24.854864614690072 -3.0009227002832186 47.4931110020768 5.5814197821558125 -0.6610691825249471 0 0
30.91273214028599 -3.1208243349937974 77.79954646070892 34.28646028294783 -19.097331116725623 -28.087943162872662 0
37.80277123390563 -3.2571969029072276 112.26918849496327 66.9347231244047 -40.06618937091002 -54.66780262877968 -9.48861652309627
]
c = T2[0, 0.6358126895828704, 0.4095798393397535, 0.9769306725060716, 0.4288403609558664, 1, 1, 1]
Expand All @@ -412,7 +412,7 @@ function Rodas5PTableau(T, T2)
-9.91568850695171 -0.9689944594115154 3.0438037242978453 -24.495224566215796 20.176138334709044 15.98066361424651 -6.789040303419874 -6.710236069923372
11.419903575922262 2.8879645146136994 72.92137995996029 80.12511834622643 -52.072871366152654 -59.78993625266729 -0.15582684282751913 4.883087185713722
]
RodasTableau(A, C, gamma, d, c, H)
RodasTableau(A, C, gamma, c, d, H)
end

@RosenbrockW6S4OS(:tableau)
Expand Down
Loading

0 comments on commit 9af8cdd

Please sign in to comment.