Skip to content

Commit

Permalink
Merge pull request SciML#2450 from Shreyas-Ekanathan/master
Browse files Browse the repository at this point in the history
Implement Part 1 Of Adaptive Radau Method
  • Loading branch information
ChrisRackauckas authored Sep 20, 2024
2 parents e6ddc71 + bdb0a63 commit 733bfa5
Show file tree
Hide file tree
Showing 9 changed files with 1,211 additions and 95 deletions.
5 changes: 5 additions & 0 deletions lib/OrdinaryDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@ version = "1.1.1"
[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RootedTrees = "47965b36-3f3e-11e9-0dcf-4570dfd42a8c"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
DiffEqBase = "6.152.2"
Expand Down
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
get_current_adaptive_order, get_fsalfirstlast,
isfirk, generic_solver_docstring
using MuladdMacro, DiffEqBase, RecursiveArrayTools
using Polynomials, GenericLinearAlgebra, GenericSchur
using SciMLOperators: AbstractSciMLOperator
using LinearAlgebra: I, UniformScaling, mul!, lu
import LinearSolve
Expand All @@ -42,6 +43,6 @@ include("firk_tableaus.jl")
include("firk_perform_step.jl")
include("integrator_interface.jl")

export RadauIIA3, RadauIIA5, RadauIIA9
export RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau

end
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqFIRK/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ qmax_default(alg::Union{RadauIIA3, RadauIIA5, RadauIIA9}) = 8
alg_order(alg::RadauIIA3) = 3
alg_order(alg::RadauIIA5) = 5
alg_order(alg::RadauIIA9) = 9
alg_order(alg::AdaptiveRadau) = 5

isfirk(alg::RadauIIA3) = true
isfirk(alg::RadauIIA5) = true
isfirk(alg::RadauIIA9) = true
isfirk(alg::AdaptiveRadau) = true

alg_adaptive_order(alg::RadauIIA3) = 1
alg_adaptive_order(alg::RadauIIA5) = 3
Expand Down
39 changes: 39 additions & 0 deletions lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,42 @@ function RadauIIA9(; chunk_size = Val{0}(), autodiff = Val{true}(),
controller,
step_limiter!)
end

struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
smooth_est::Bool
extrapolant::Symbol
κ::Tol
maxiters::Int
fast_convergence_cutoff::C1
new_W_γdt_cutoff::C2
controller::Symbol
step_limiter!::StepLimiter
num_stages::Int
end

function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, num_stages = 3,
linsolve = nothing, precs = DEFAULT_PRECS,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
controller = :Predictive, κ = nothing, maxiters = 10, smooth_est = true,
step_limiter! = trivial_limiter!)
AdaptiveRadau{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(κ), typeof(fast_convergence_cutoff),
typeof(new_W_γdt_cutoff), typeof(step_limiter!)}(linsolve,
precs,
smooth_est,
extrapolant,
κ,
maxiters,
fast_convergence_cutoff,
new_W_γdt_cutoff,
controller,
step_limiter!, num_stages)
end

189 changes: 187 additions & 2 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ mutable struct RadauIIA9ConstantCache{F, Tab, Tol, Dt, U, JType} <:
cont2::U
cont3::U
cont4::U
cont5::U
dtprev::Dt
W_γdt::Dt
status::NLStatus
Expand All @@ -304,7 +305,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'

RadauIIA9ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, u, dt, dt,
RadauIIA9ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, u, u, dt, dt,
Convergence, J)
end

Expand Down Expand Up @@ -333,6 +334,7 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
cont2::uType
cont3::uType
cont4::uType
cont5::uType
du1::rateType
fsalfirst::rateType
k::rateType
Expand Down Expand Up @@ -407,6 +409,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
cont2 = zero(u)
cont3 = zero(u)
cont4 = zero(u)
cont5 = zero(u)

fsalfirst = zero(rate_prototype)
k = zero(rate_prototype)
Expand Down Expand Up @@ -462,11 +465,193 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},

RadauIIA9Cache(u, uprev,
z1, z2, z3, z4, z5, w1, w2, w3, w4, w5,
dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4,
dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4, cont5,
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
J, W1, W2, W3,
uf, tab, κ, one(uToltype), 10000,
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config,
linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
end

mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
OrdinaryDiffEqConstantCache
uf::F
tab::Tab
κ::Tol
ηold::Tol
iter::Int
cont::Vector{U}
dtprev::Dt
W_γdt::Dt
status::NLStatus
J::JType
end

function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.num_stages

if (num_stages == 3)
tab = BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits))
elseif (num_stages == 5)
tab = BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits))
elseif (num_stages == 7)
tab = BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))
elseif iseven(num_stages) || num_stages <3
error("num_stages must be odd and 3 or greater")
else
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
end

cont = Vector{typeof(u)}(undef, num_stages)
for i in 1: num_stages
cont[i] = zero(u)
end

κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'

AdaptiveRadauConstantCache(uf, tab, κ, one(uToltype), 10000, cont, dt, dt,
Convergence, J)
end

mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol, StepLimiter} <:
FIRKMutableCache
u::uType
uprev::uType
z::Vector{uType}
w::Vector{uType}
c_prime::Vector{tType}
dw1::uType
ubuff::uType
dw2::Vector{cuType}
cubuff::Vector{cuType}
dw::Vector{uType}
cont::Vector{uType}
derivatives:: Matrix{uType}
du1::rateType
fsalfirst::rateType
ks::Vector{rateType}
k::rateType
fw::Vector{rateType}
J::JType
W1::W1Type #real
W2::Vector{W2Type} #complex
uf::UF
tab::Tab
κ::Tol
ηold::Tol
iter::Int
tmp::uType
atmp::uNoUnitsType
jac_config::JC
linsolve1::F1 #real
linsolve2::Vector{F2} #complex
rtol::rTol
atol::aTol
dtprev::Dt
W_γdt::Dt
status::NLStatus
step_limiter!::StepLimiter
end

function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.num_stages

if (num_stages == 3)
tab = BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits))
elseif (num_stages == 5)
tab = BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits))
elseif (num_stages == 7)
tab = BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))
elseif iseven(num_stages) || num_stages < 3
error("num_stages must be odd and 3 or greater")
else
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
end

κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)

z = Vector{typeof(u)}(undef, num_stages)
w = Vector{typeof(u)}(undef, num_stages)
for i in 1 : num_stages
z[i] = w[i] = zero(u)
end

c_prime = Vector{typeof(t)}(undef, num_stages) #time stepping

dw1 = zero(u)
ubuff = zero(u)
dw2 = [similar(u, Complex{eltype(u)}) for _ in 1 : (num_stages - 1) ÷ 2]
recursivefill!.(dw2, false)
cubuff = [similar(u, Complex{eltype(u)}) for _ in 1 : (num_stages - 1) ÷ 2]
recursivefill!.(cubuff, false)
dw = Vector{typeof(u)}(undef, num_stages - 1)

cont = Vector{typeof(u)}(undef, num_stages)
for i in 1 : num_stages
cont[i] = zero(u)
end

derivatives = Matrix{typeof(u)}(undef, num_stages, num_stages)
for i in 1 : num_stages, j in 1 : num_stages
derivatives[i, j] = zero(u)
end

fsalfirst = zero(rate_prototype)
fw = Vector{typeof(rate_prototype)}(undef, num_stages)
ks = Vector{typeof(rate_prototype)}(undef, num_stages)
for i in 1: num_stages
ks[i] = fw[i] = zero(rate_prototype)
end
k = ks[1]

J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by AdaptiveRadau.")
end

W2 = [similar(J, Complex{eltype(W1)}) for _ in 1 : (num_stages - 1) ÷ 2]
recursivefill!.(W2, false)

du1 = zero(rate_prototype)

tmp = zero(u)

atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)

jac_config = build_jac_config(alg, f, uf, du1, uprev, u, zero(u), dw1)

linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))

linsolve2 = [
init(LinearProblem(W2[i], _vec(cubuff[i]); u0 = _vec(dw2[i])), alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true)) for i in 1 : (num_stages - 1) ÷ 2]

rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)

AdaptiveRadauCache(u, uprev,
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
du1, fsalfirst, ks, k, fw,
J, W1, W2,
uf, tab, κ, one(uToltype), 10000, tmp,
atmp, jac_config,
linsolve1, linsolve2, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
end

Loading

0 comments on commit 733bfa5

Please sign in to comment.