Skip to content

Commit

Permalink
Merge pull request #1 from Shreyas-Ekanathan/master
Browse files Browse the repository at this point in the history
update
  • Loading branch information
Shreyas-Ekanathan authored Jun 10, 2024
2 parents e686a5c + d0b3b06 commit 4e974bb
Show file tree
Hide file tree
Showing 7 changed files with 888 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ export ORK256, CarpenterKennedy2N54, SHLDDRK64, HSLDDRK64, DGLDDRK73_C, DGLDDRK8
RDPK3Sp35, RDPK3SpFSAL35, RDPK3Sp49, RDPK3SpFSAL49, RDPK3Sp510, RDPK3SpFSAL510,
KYK2014DGSSPRK_3S2

export RadauIIA3, RadauIIA5
export RadauIIA3, RadauIIA5, RadauIIA7

export ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22,
Kvaerno3, KenCarp3, Cash4, Hairer4, Hairer42, SSPSDIRK2, Kvaerno4,
Expand Down
3 changes: 2 additions & 1 deletion src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ qmin_default(alg::DP8) = 1 // 3
qmax_default(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = 10
qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs))
qmax_default(alg::DP8) = 6
qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8
qmax_default(alg::Union{RadauIIA3, RadauIIA5, RadauIIA7}) = 8

function has_chunksize(alg::OrdinaryDiffEqAlgorithm)
return alg isa Union{OrdinaryDiffEqExponentialAlgorithm,
Expand Down Expand Up @@ -575,6 +575,7 @@ alg_order(alg::TanYam7) = 7
alg_order(alg::TsitPap8) = 8
alg_order(alg::RadauIIA3) = 3
alg_order(alg::RadauIIA5) = 5
alg_order(alg::RadauIIA7) = 7
alg_order(alg::ImplicitEuler) = 1
alg_order(alg::RKMK2) = 2
alg_order(alg::RKMK4) = 4
Expand Down
50 changes: 50 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1709,6 +1709,56 @@ function RadauIIA5(; chunk_size = Val{0}(), autodiff = Val{true}(),
end
TruncatedStacktraces.@truncate_stacktrace RadauIIA5

"""
@article{hairer1999stiff,
title={Stiff differential equations solved by Radau methods},
author={Hairer, Ernst and Wanner, Gerhard},
journal={Journal of Computational and Applied Mathematics},
volume={111},
number={1-2},
pages={93--111},
year={1999},
publisher={Elsevier}
}
RadauIIA7: Fully-Implicit Runge-Kutta Method
An A-B-L stable fully implicit Runge-Kutta method with internal tableau complex basis transform for efficiency.
"""
struct RadauIIA7{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2} <:
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
smooth_est::Bool
extrapolant::Symbol
κ::Tol
maxiters::Int
fast_convergence_cutoff::C1
new_W_γdt_cutoff::C2
controller::Symbol
end

function RadauIIA7(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
controller = :Predictive, κ = nothing, maxiters = 10, smooth_est = true)
RadauIIA7{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(κ), typeof(fast_convergence_cutoff), typeof(new_W_γdt_cutoff)}(linsolve,
precs,
smooth_est,
extrapolant,
κ,
maxiters,
fast_convergence_cutoff,
new_W_γdt_cutoff,
controller)
end
TruncatedStacktraces.@truncate_stacktrace RadauIIA7


################################################################################

# SDIRK Methods
Expand Down
181 changes: 181 additions & 0 deletions src/caches/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,184 @@ function alg_cache(alg::RadauIIA5, u, rate_prototype, ::Type{uEltypeNoUnits},
tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
end

mutable struct RadauIIA7ConstantCache{F, Tab, Tol, Dt, U, JType} <:
OrdinaryDiffEqConstantCache
uf::F
tab::Tab
κ::Tol
ηold::Tol
iter::Int
cont1::U
cont2::U
cont3::U
dtprev::Dt
W_γdt::Dt
status::NLStatus
J::JType
end

function alg_cache(alg::RadauIIA7, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
tab = RadauIIA7Tableau(uToltype, constvalue(tTypeNoUnits))

κ = convert(uToltype, 1 // 100)
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'

RadauIIA7ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, dt, dt,
Convergence, J)
end

mutable struct RadauIIA7Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Type, W2Type,
UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol} <:
OrdinaryDiffEqMutableCache
u::uType
uprev::uType
z1::uType
z2::uType
z3::uType
z4::uType
z5::uType
w1::uType
w2::uType
w3::uType
w4::uType
w5::uType
dw1::uType
ubuff::uType
dw23::cuType
dw45::cuType
cubuff::cuType
cont1::uType
cont2::uType
cont3::uType
cont4::uType
du1::rateType #why is this here?
fsalfirst::rateType
k::rateType
k2::rateType
k3::rateType
k4::rateType
k5::rateType
fw1::rateType
fw2::rateType
fw3::rateType
fw4::rateType
fw5::rateType
J::JType
W1::W1Type
W2::W2Type # complex
W3::W3Type #CHECK THIS TYPE
uf::UF
tab::Tab
κ::Tol
ηold::Tol
iter::Int
tmp::uType
atmp::uNoUnitsType
jac_config::JC
linsolve1::F1
linsolve2::F2
linsolve3::F2 #CHECK THIS TYPE
rtol::rTol
atol::aTol
dtprev::Dt
W_γdt::Dt
status::NLStatus
end
TruncatedStacktraces.@truncate_stacktrace RadauIIA7Cache 1

function alg_cache(alg::RadauIIA7, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
tab = RadauIIA7Tableau(uToltype, constvalue(tTypeNoUnits))

κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)

z1 = zero(u)
z2 = zero(u)
z3 = zero(u)
z4 = zero(u)
z5 = zero(u)
w1 = zero(u)
w2 = zero(u)
w3 = zero(u)
w4 = zero(u)
w5 = zero(u)
dw1 = zero(u)
ubuff = zero(u)
dw23 = similar(u, Complex{eltype(u)})
dw45 = similar(u, Complex{eltype(u)})
recursivefill!(dw23, false)
recursivefill!(dw45, false)
cubuff = similar(u, Complex{eltype(u)})
recursivefill!(cubuff, false)
cont1 = zero(u)
cont2 = zero(u)
cont3 = zero(u)
cont4 = zero(u)

fsalfirst = zero(rate_prototype)
k = zero(rate_prototype)
k2 = zero(rate_prototype)
k3 = zero(rate_prototype)
k4 = zero(rate_prototype)
k5 = zero(rate_prototype)
fw1 = zero(rate_prototype)
fw2 = zero(rate_prototype)
fw3 = zero(rate_prototype)
fw4 = zero(rate_prototype)
fw5 = zero(rate_prototype)

J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by RadauIIA5.")
end
W2 = similar(J, Complex{eltype(W1)})
W3 = similar(J, Complex{eltype(W1)})
recursivefill!(W2, false)
recursivefill!(W3, false)

du1 = zero(rate_prototype)

tmp = zero(u)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)

linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))
linprob = LinearProblem(W2, _vec(cubuff); u0 = _vec(dw23))
linsolve2 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))
linprob = LinearProblem(W3, _vec(cubuff); u0 = _vec(dw45))
linsolve3 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
#Pr = Diagonal(_vec(weight)))


rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)

RadauIIA7Cache(u, uprev,
z1, z2, z3, z4, z5, w1, w2, w3, w4, w5,
dw1, ubuff, dw23, dw45, cubuff, cont1, cont2, cont3, cont4,
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
J, W1, W2, W3,
uf, tab, κ, one(uToltype), 10000,
tmp, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
Convergence)

Loading

0 comments on commit 4e974bb

Please sign in to comment.