Skip to content

Commit

Permalink
add to tableau, create cache, oop method
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Aug 6, 2024
1 parent 5b0cbe1 commit c887941
Show file tree
Hide file tree
Showing 3 changed files with 397 additions and 7 deletions.
159 changes: 159 additions & 0 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,162 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
end

"""
    adaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType}

Out-of-place (constant) cache for the `adaptiveRadau` method.

Every type parameter is tied to a field so that the implicit outer constructor
`adaptiveRadauConstantCache(uf, tab, κ, ηold, iter, cont, dtprev, W_γdt, status, J)`
exists; a parameter that appears in no field cannot be inferred and Julia then
generates no such constructor.
"""
mutable struct adaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
               OrdinaryDiffEqConstantCache
    uf::F             # derivative wrapper around `f` built in `alg_cache`
    tab::Tab          # method tableau
    κ::Tol            # tolerance factor used in the Newton stopping test
    ηold::Tol         # convergence-rate estimate carried over from the previous step
    iter::Int         # Newton iteration count of the previous step
    cont::Vector{U}   # coefficients used to extrapolate the initial stage guess
    dtprev::Dt        # size of the last accepted step
    W_γdt::Dt         # initialised to `dt` by `alg_cache`
    status::NLStatus  # outcome of the last Newton solve
    J::JType          # Jacobian storage
end

"""
    alg_cache(alg::adaptiveRadau, s::Int64, u, rate_prototype, ..., ::Val{false})

Build the out-of-place cache for `adaptiveRadau` with `s` stages.

The returned `adaptiveRadauConstantCache` holds the derivative wrapper, the
tableau, the Newton tolerance factor `κ` (taken from `alg.κ`, or `1 // 100`
when that is `nothing`), `s - 1` zeroed extrapolation coefficients and a zero
Jacobian shaped after `rate_prototype`.
"""
function alg_cache(alg::adaptiveRadau, s::Int64, u, rate_prototype,
        ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits},
        uprev, uprev2, f, t, dt, reltol, p, calck,
        ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
    uf = UDerivativeWrapper(f, t, p)
    uToltype = constvalue(uBottomEltypeNoUnits)
    # NOTE(review): `adaptiveRadau` is also the algorithm type this method
    # dispatches on; confirm this call really returns the tableau.
    tab = adaptiveRadau(uToltype, constvalue(tTypeNoUnits), s)

    # One zero state per extrapolation coefficient.
    cont = Vector{typeof(u)}(undef, s - 1)
    foreach(i -> (cont[i] = zero(u)), eachindex(cont))

    κ_source = alg.κ === nothing ? 1 // 100 : alg.κ
    κ = convert(uToltype, κ_source)

    # Zero matrix with the Jacobian's shape and element type.
    flat_rate = _vec(rate_prototype)
    J = false .* flat_rate .* flat_rate'

    return adaptiveRadauConstantCache(uf, tab, κ, one(uToltype), 10000, cont, dt, dt,
        Convergence, J)
end

"""
    adaptiveRadauCache{...}

In-place (mutable) cache for the `adaptiveRadau` method.

Per-stage storage is held in concretely typed `Vector`s so that field access is
type-stable; abstract container fields would force dynamic dispatch on every
access inside the solver's hot loop.
"""
mutable struct adaptiveRadauCache{uType, cuType, uNoUnitsType, rateType, JType, W1Type,
    W2Type, UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol, StepLimiter} <:
               OrdinaryDiffEqMutableCache
    u::uType
    uprev::uType
    z::Vector{uType}          # one state-shaped array per stage
    w::Vector{uType}          # one state-shaped array per stage
    dw1::uType                # solution buffer of the real linear solve
    ubuff::uType              # right-hand-side buffer of the real linear solve
    dw2::Vector{cuType}       # solution buffers of the complex linear solves
    cubuff::Vector{cuType}    # right-hand-side buffers of the complex linear solves
    cont::Vector{uType}       # extrapolation coefficients for the initial guess
    du1::rateType
    fsalfirst::rateType
    k::Vector{rateType}       # one rate-shaped array per stage
    fw::Vector{rateType}      # one rate-shaped array per stage
    J::JType
    W1::W1Type                # real system matrix
    W2::Vector{W2Type}        # complex system matrices
    uf::UF
    tab::Tab
    κ::Tol                    # tolerance factor used in the Newton stopping test
    ηold::Tol
    iter::Int
    tmp::Vector{uType}        # state-shaped work arrays
    atmp::uNoUnitsType        # unitless buffer
    jac_config::JC
    linsolve1::F1             # linear-solver cache for the real system
    linsolve2::Vector{F2}     # linear-solver caches for the complex systems
    rtol::rTol
    atol::aTol
    dtprev::Dt
    W_γdt::Dt
    status::NLStatus
    step_limiter!::StepLimiter
end

"""
    alg_cache(alg::adaptiveRadau, u, rate_prototype, ..., ::Val{true}, s::Integer = 5)

Build the in-place cache for `adaptiveRadau` with `s` stages.

`s` is a trailing optional argument so that the standard `alg_cache` call
signature keeps working. Its default of 5 matches the stage count of the
RadauIIA9 tableau that this method still hard-codes.

Allocated storage:

  - `z`, `w` (state-shaped) and `k`, `fw` (rate-shaped): `s` independent arrays each.
  - `dw2`, `cubuff`, `W2`, `linsolve2`: one entry per complex pair, `s ÷ 2` in total.
  - `cont`: `s - 1` arrays; `tmp`: `binomial(s, 2)` arrays.

Throws an error when `build_J_W` returns an `AbstractSciMLOperator` Jacobian.
"""
function alg_cache(alg::adaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
        ::Type{uBottomEltypeNoUnits},
        ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
        ::Val{true},
        s::Integer = 5) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
    uf = UJacobianWrapper(f, t, p)
    uToltype = constvalue(uBottomEltypeNoUnits)
    # NOTE(review): the tableau is still the fixed RadauIIA9 one; switch to a
    # stage-count-aware tableau before calling this with `s != 5`.
    tab = RadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits))

    κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)

    npairs = s ÷ 2

    # Separate arrays for `z` and `w`: a chained assignment would store the
    # *same* array in both and every in-place update would clobber the other.
    z = [zero(u) for _ in 1:s]
    w = [zero(u) for _ in 1:s]

    dw1 = zero(u)
    ubuff = zero(u)

    # Complex buffers; the container element type follows the complex arrays.
    dw2 = [similar(u, Complex{eltype(u)}) for _ in 1:npairs]
    foreach(buf -> recursivefill!(buf, false), dw2)
    cubuff = [similar(u, Complex{eltype(u)}) for _ in 1:npairs]
    foreach(buf -> recursivefill!(buf, false), cubuff)

    cont = [zero(u) for _ in 1:(s - 1)]

    fsalfirst = zero(rate_prototype)
    fw = [zero(rate_prototype) for _ in 1:s]
    k = [zero(rate_prototype) for _ in 1:s]

    J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
    if J isa AbstractSciMLOperator
        error("Non-concrete Jacobian not yet supported by adaptiveRadau.")
    end
    W2 = [similar(J, Complex{eltype(W1)}) for _ in 1:npairs]
    foreach(W -> recursivefill!(W, false), W2)

    du1 = zero(rate_prototype)

    tmp = [zero(u) for _ in 1:binomial(s, 2)]

    atmp = similar(u, uEltypeNoUnits)
    recursivefill!(atmp, false)

    # NOTE(review): `tmp` is a vector of arrays here; confirm `build_jac_config`
    # accepts that rather than a single state-shaped array.
    jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)

    linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
    linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
        assumptions = LinearSolve.OperatorAssumptions(true))

    # One solver cache per complex system, stored per index instead of
    # overwriting the container.
    linsolve2 = [init(LinearProblem(W2[i], _vec(cubuff[i]); u0 = _vec(dw2[i])),
                     alg.linsolve, alias_A = true, alias_b = true,
                     assumptions = LinearSolve.OperatorAssumptions(true))
                 for i in 1:npairs]

    rtol = reltol isa Number ? reltol : zero(reltol)
    atol = reltol isa Number ? reltol : zero(reltol)

    return adaptiveRadauCache(u, uprev,
        z, w, dw1, ubuff, dw2, cubuff, cont,
        du1, fsalfirst, k, fw,
        J, W1, W2,
        uf, tab, κ, one(uToltype), 10000,
        tmp, atmp, jac_config,
        linsolve1, linsolve2, rtol, atol, dt, dt,
        Convergence, alg.step_limiter!)
end

197 changes: 197 additions & 0 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1339,3 +1339,200 @@ end
integrator.stats.nf += 1
return
end

# Apply the square stage-transformation matrix `M` to the vector of stage
# values `v`, i.e. `out[i] = Σ_j M[i, j] * v[j]`. Each `v[j]` may be a number
# or an array. Assumes `M` is at least `length(v) × length(v)` — TODO confirm
# against the tableau definition.
function _adaptive_radau_transform(M, v)
    n = length(v)
    return [sum(M[i, j] * v[j] for j in 1:n) for i in 1:n]
end

# Out-of-place step of the `adaptiveRadau` method with `s` stages.
#
# Arguments
#   integrator  : the ODE integrator; `u`, `EEst`, `fsallast`, `k`, `stats` and
#                 `force_stepfail` are updated on it.
#   cache       : constant cache; `dtprev`, `ηold`, `iter`, `status` and `cont`
#                 are updated on it.
#   repeat_step : unused here.
#   s           : number of stages. Defaults to `length(cache.cont) + 1` because
#                 `alg_cache` allocates `s - 1` entries in `cont`.
#
# The stage layout is one real system (stage 1) followed by complex pairs
# (stages 2k, 2k + 1), so `s` must be odd; an `ArgumentError` is thrown otherwise.
#
# Returns `nothing`. When the Newton iteration fails, `integrator.force_stepfail`
# is set and the function returns before `u` is updated.
@muladd function perform_step!(integrator, cache::adaptiveRadauConstantCache,
        repeat_step = false, s::Int64 = length(cache.cont) + 1)
    @unpack t, dt, uprev, u, f, p = integrator
    @unpack T, TI, γ, α, β, c, e = cache.tab
    @unpack κ, cont = cache
    @unpack internalnorm, abstol, reltol, adaptive = integrator.opts
    alg = unwrap_alg(integrator, true)
    @unpack maxiters = alg
    mass_matrix = integrator.f.mass_matrix

    (s >= 1 && isodd(s)) ||
        throw(ArgumentError("adaptiveRadau needs an odd, positive stage count, got $s"))
    npairs = (s - 1) ÷ 2    # number of complex pairs
    nsys = npairs + 1       # linear systems solved per Newton iteration

    # precalculations rtol pow is (num stages + 1)/(2*num stages)
    rtol_pow = (s + 1) / (2 * s)
    rtol = @.. broadcast=false reltol^rtol_pow/10
    atol = @.. broadcast=false rtol*(abstol / reltol)

    γdt, αdt, βdt = γ / dt, α ./ dt, β ./ dt

    J = calc_J(integrator, cache)
    # The real and complex factorizations have different types, hence `Any`.
    LU = Vector{Any}(undef, nsys)
    if u isa Number
        LU[1] = -γdt * mass_matrix + J
        for i in 2:nsys
            LU[i] = -(αdt[i - 1] + βdt[i - 1] * im) * mass_matrix + J
        end
    else
        LU[1] = lu(-γdt * mass_matrix + J)
        for i in 2:nsys
            LU[i] = lu(-(αdt[i - 1] + βdt[i - 1] * im) * mass_matrix + J)
        end
    end
    integrator.stats.nw += 1

    # Initial guess for the stage values `z` and their transformed form `w`.
    z = Vector{typeof(u)}(undef, s)
    w = Vector{typeof(u)}(undef, s)
    if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
        cache.dtprev = one(cache.dtprev)
        for i in 1:s
            z[i] = map(zero, u)
            w[i] = map(zero, u)
        end
        for i in 1:(s - 1)
            cont[i] = map(zero, u)
        end
    else
        ratio = dt / cache.dtprev
        c_prime = Vector{typeof(ratio)}(undef, s) #time stepping
        c_prime[s] = ratio
        for i in 1:(s - 1)
            c_prime[i] = c[i] * ratio
        end
        for i in 1:s # collocation polynomial
            # NOTE(review): `cont[s - 1]` is used both as the leading coefficient
            # and as the next one, exactly as before; confirm this is intended
            # given that `cont` only holds `s - 1` entries.
            z[i] = @.. cont[s - 1] * (c_prime[i] - c[1] + 1) + cont[s - 1]
            j = s - 2
            while j > 0
                z[i] = @.. z[i] * (c_prime[i] - c[s - j] + 1) + cont[j]
                j -= 1 # without this decrement the loop never terminates
            end
            z[i] = @.. z[i] * c_prime[i]
        end
        w = _adaptive_radau_transform(TI, z)
    end

    # Newton iteration
    local ndw
    η = max(cache.ηold, eps(eltype(integrator.opts.reltol)))^(0.8)
    fail_convergence = true
    iter = 0
    while iter < maxiters
        iter += 1
        integrator.stats.nnonliniter += 1

        # evaluate function
        # NOTE(review): containers assume `f` returns values of `typeof(u)`;
        # revisit if rates carry different units.
        ff = Vector{typeof(u)}(undef, s)
        for i in 1:s
            ff[i] = f(uprev + z[i], p, t + c[i] * dt)
        end
        integrator.stats.nf += s

        fw = _adaptive_radau_transform(TI, ff)
        Mw = Vector{typeof(u)}(undef, s)
        if mass_matrix isa UniformScaling # `UniformScaling` doesn't play nicely with broadcast
            for i in 1:s
                Mw[i] = @.. mass_matrix.λ * w[i] #scaling by eigenvalues
            end
        else
            for i in 1:s
                Mw[i] = mass_matrix * w[i] #standard multiplication, stage by stage
            end
        end

        rhs = Vector{typeof(u)}(undef, s)
        rhs[1] = @.. fw[1] - γdt * Mw[1]
        for pair in 1:npairs #block by block multiplication
            i = 2 * pair
            a, b = αdt[pair], βdt[pair]
            rhs[i] = @.. fw[i] - a * Mw[i] + b * Mw[i + 1]
            rhs[i + 1] = @.. fw[i + 1] - b * Mw[i] - a * Mw[i + 1]
        end

        dw = Vector{typeof(u)}(undef, s)
        dw[1] = LU[1] \ rhs[1]
        for pair in 1:npairs
            i = 2 * pair
            sol = LU[pair + 1] \ (@.. rhs[i] + rhs[i + 1] * im)
            dw[i] = real(sol)
            dw[i + 1] = imag(sol)
        end
        integrator.stats.nsolve += nsys

        # compute norm of residuals
        iter > 1 && (ndwprev = ndw)
        ndw = internalnorm(
            calculate_residuals(dw[1], uprev, u, atol, rtol, internalnorm, t), t)
        for i in 2:s
            ndw = ndw + internalnorm(
                calculate_residuals(dw[i], uprev, u, atol, rtol, internalnorm, t), t)
        end

        # check divergence (not in initial step)
        if iter > 1
            θ = ndw / ndwprev
            (diverge = θ > 1) && (cache.status = Divergence)
            (veryslowconvergence = ndw * θ^(maxiters - iter) > κ * (1 - θ)) &&
                (cache.status = VerySlowConvergence)
            if diverge || veryslowconvergence
                break
            end
        end

        for i in 1:s
            w[i] = @.. w[i] - dw[i]
        end
        # transform `w` to `z`
        z = _adaptive_radau_transform(T, w)

        # check stopping criterion
        iter > 1 && (η = θ / (1 - θ))
        if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
            # Newton method converges
            cache.status = η < alg.fast_convergence_cutoff ? FastConvergence :
                           Convergence
            fail_convergence = false
            break
        end
    end

    if fail_convergence
        integrator.force_stepfail = true
        integrator.stats.nnonlinconvfail += 1
        return
    end
    cache.ηold = η
    cache.iter = iter

    u = @.. uprev + z[s]

    if adaptive
        edt = e ./ dt
        # Weighted sum of the stage values, Σ_i edt[i] * z[i].
        tmp = @.. edt[1] * z[1]
        for i in 2:s
            tmp = @.. tmp + edt[i] * z[i]
        end
        mass_matrix != I && (tmp = mass_matrix * tmp)
        utilde = @.. broadcast=false integrator.fsalfirst+tmp
        alg.smooth_est && (utilde = LU[1] \ utilde; integrator.stats.nsolve += 1)
        atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
        integrator.EEst = internalnorm(atmp, t)

        if !(integrator.EEst < oneunit(integrator.EEst)) && integrator.iter == 1 ||
           integrator.u_modified
            f0 = f(uprev .+ utilde, p, t)
            integrator.stats.nf += 1
            utilde = @.. broadcast=false f0+tmp
            alg.smooth_est && (utilde = LU[1] \ utilde; integrator.stats.nsolve += 1)
            atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
            integrator.EEst = internalnorm(atmp, t)
        end
    end

    if integrator.EEst <= oneunit(integrator.EEst)
        cache.dtprev = dt
        if alg.extrapolant != :constant
            # Divided-difference table over the stage values; only entries with
            # `j >= i` are filled.
            derivatives = Matrix{typeof(u)}(undef, s - 1, s - 1)
            for i in 1:(s - 1)
                for j in i:(s - 1)
                    if i == 1
                        # first differences, one per adjacent stage pair
                        derivatives[i, j] = @.. (z[j] - z[j + 1]) / (c[j] - c[j + 1])
                    else
                        derivatives[i, j] = @.. (derivatives[i - 1, j - 1] -
                                                 derivatives[i - 1, j]) /
                                                (c[j - i + 1] - c[j + 1]) #all others
                    end
                end
            end
            # NOTE(review): the diagonal is stored, as before; confirm these are
            # the differences the collocation polynomial above expects.
            for i in 1:(s - 1)
                cache.cont[i] = derivatives[i, i]
            end
        end
    end

    integrator.fsallast = f(u, p, t + dt)
    integrator.stats.nf += 1
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.u = u
    return
end
Loading

0 comments on commit c887941

Please sign in to comment.