Skip to content

Commit

Permalink
lots of edits
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Aug 21, 2024
1 parent d836402 commit 3d09fa2
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 192 deletions.
3 changes: 3 additions & 0 deletions lib/OrdinaryDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"

[compat]
julia = "1.10"
Expand Down
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
get_current_adaptive_order,
isfirk
using MuladdMacro, DiffEqBase, RecursiveArrayTools
using Polynomials, GenericLinearAlgebra, GenericSchur
using SciMLOperators: AbstractSciMLOperator
using LinearAlgebra: I, UniformScaling, mul!, lu
import LinearSolve
Expand Down
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqFIRK/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ qmax_default(alg::Union{RadauIIA3, RadauIIA5, RadauIIA9}) = 8
alg_order(alg::RadauIIA3) = 3
alg_order(alg::RadauIIA5) = 5
alg_order(alg::RadauIIA9) = 9
alg_order(alg::AdaptiveRadau) = 9

isfirk(alg::RadauIIA3) = true
isfirk(alg::RadauIIA5) = true
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ end

function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, num_stages = 5,
diff_type = Val{:forward}, num_stages = 3,
linsolve = nothing, precs = DEFAULT_PRECS,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
Expand Down
202 changes: 102 additions & 100 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,41 +468,42 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
Convergence, alg.step_limiter!)
end

mutable struct adaptiveRadauConstantCache{S, F, Tab, Tol, Dt, U, JType} <:
mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
OrdinaryDiffEqConstantCache
uf::F
tab::Tab
κ::Tol
ηold::Tol
iter::Int
cont::AbstractVector{U}
dtprev::Dt
W_γdt::Dt
status::NLStatus
J::JType
uf::F
tab::Tab
κ::Tol
ηold::Tol
iter::Int
cont::AbstractVector{U}
dtprev::Dt
W_γdt::Dt
status::NLStatus
J::JType
end

function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
tab = adaptiveRadau(uToltype, constvalue(tTypeNoUnits), alg.num_stages)

cont = Vector{typeof(u)}(undef, num_stages - 1)
for i in 1: (num_stages - 1)
cont[i] = zero(u)
end
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.num_stages
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)

cont = Vector{typeof(u)}(undef, num_stages - 1)
for i in 1: (num_stages - 1)
cont[i] = zero(u)
end

κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'

adaptiveRadauConstantCache(uf, tab, κ, one(uToltype), 10000, cont, dt, dt,
Convergence, J)
AdaptiveRadauConstantCache(uf, tab, κ, one(uToltype), 10000, cont, dt, dt,
Convergence, J)
end

mutable struct adaptiveRadauCache{uType, cuType, uNoUnitsType, rateType, JType, W1Type, W2Type,
mutable struct AdaptiveRadauCache{uType, cuType, uNoUnitsType, rateType, JType, W1Type, W2Type,
UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol, StepLimiter} <:
OrdinaryDiffEqMutableCache
u::uType
Expand Down Expand Up @@ -540,89 +541,90 @@ mutable struct adaptiveRadauCache{uType, cuType, uNoUnitsType, rateType, JType,
end

function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
tab = RadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits))

κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)

z = Vector{typeof(u)}(undef, num_stages)
w = Vector{typeof(u)}(undef, num_stages)
for i in 1:s
z[i] = w[i] = zero(u)
end
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
alg.num_stages = num_stages
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)

dw1 = zero(u)
ubuff = zero(u)
dw2 = Vector{typeof(u)}(undef, floor(Int, num_stages/2))
for i in 1 : floor(Int, num_stages/2)
dw2[i] = similar(u, Complex{eltype(u)})
recursivefill!(dw[i], false)
end
cubuff = Vector{typeof(u)}(undef, floor(Int, num_stages/2))
for i in 1 :floor(Int, num_stages/2)
cubuff[i] = similar(u, Complex{eltype(u)})
recursivefill!(cubuff[i], false)
end
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)

cont = Vector{typeof(u)}(undef, num_stages - 1)
for i in 1: (num_stages - 1)
cont[i] = zero(u)
end
z = Vector{typeof(u)}(undef, num_stages)
w = Vector{typeof(u)}(undef, num_stages)
for i in 1:s
z[i] = w[i] = zero(u)
end

fsalfirst = zero(rate_prototype)
fw = Vector{typeof(rate_prototype)}(undef, num_stages)
k = Vector{typeof(rate_prototype)}(undef, num_stages)
for i in 1: num_stages
k[i] = fw[i] = zero(rate_prototype)
end
dw1 = zero(u)
ubuff = zero(u)
dw2 = Vector{typeof(u)}(undef, floor(Int, num_stages/2))
for i in 1 : floor(Int, num_stages/2)
dw2[i] = similar(u, Complex{eltype(u)})
recursivefill!(dw[i], false)
end
cubuff = Vector{typeof(u)}(undef, floor(Int, num_stages/2))
for i in 1 :floor(Int, num_stages/2)
cubuff[i] = similar(u, Complex{eltype(u)})
recursivefill!(cubuff[i], false)
end

J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by RadauIIA5.")
end
W2 = vector{typeof(Complex{W1})}(undef, floor(Int, num_stages/2))
for i in 1 : floor(Int, num_stages/2)
W2[i] = similar(J, Complex{eltype(W1)})
recursivefill!(w2[i], false)
end
cont = Vector{typeof(u)}(undef, num_stages - 1)
for i in 1: (num_stages - 1)
cont[i] = zero(u)
end

du1 = zero(rate_prototype)
fsalfirst = zero(rate_prototype)
fw = Vector{typeof(rate_prototype)}(undef, num_stages)
k = Vector{typeof(rate_prototype)}(undef, num_stages)
for i in 1: num_stages
k[i] = fw[i] = zero(rate_prototype)
end

tmp = Vector{typeof(u)}(undef, binomial(num_stages,2))
for i in 1 : binomial(num_stages,2)
tmp[i] = zero(u)
end
J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by RadauIIA5.")
end
W2 = vector{typeof(Complex{W1})}(undef, floor(Int, num_stages/2))
for i in 1 : floor(Int, num_stages/2)
W2[i] = similar(J, Complex{eltype(W1)})
recursivefill!(w2[i], false)
end

atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
du1 = zero(rate_prototype)

jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
tmp = Vector{typeof(u)}(undef, binomial(num_stages,2))
for i in 1 : binomial(num_stages,2)
tmp[i] = zero(u)
end

linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)

linsolve2 = Vector{typeof(linsolve1)}(undef, floor(Int, num_stages/2))
for i in 1 : floor(int, num_stages/2)
linprob = LinearProblem(W2[i], _vec(cubuff[i]); u0 = _vec(dw2[i]))
linsolve2 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
end
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)

linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))

linsolve2 = Vector{typeof(linsolve1)}(undef, floor(Int, num_stages/2))
for i in 1 : floor(int, num_stages/2)
linprob = LinearProblem(W2[i], _vec(cubuff[i]); u0 = _vec(dw2[i]))
linsolve2 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
end

rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)

rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)

adaptiveRadauCache(u, uprev,
z, w, dw1, ubuff, dw2, cubuff, cont,
du1, fsalfirst, k, fw,
J, W1, W2,
uf, tab, κ, one(uToltype), 10000,
tmp, atmp, jac_config,
linsolve1, linsolve2, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
AdaptiveRadauCache(u, uprev,
z, w, dw1, ubuff, dw2, cubuff, cont,
du1, fsalfirst, k, fw,
J, W1, W2,
uf, tab, κ, one(uToltype), 10000,
tmp, atmp, jac_config,
linsolve1, linsolve2, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
end

Loading

0 comments on commit 3d09fa2

Please sign in to comment.