Skip to content

Commit

Permalink
little things
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Sep 6, 2024
1 parent 40baa62 commit 12630e0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
10 changes: 8 additions & 2 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ mutable struct AdaptiveRadauCache{uType, cuType, uNoUnitsType, rateType, JType,
dw2::Vector{cuType}
cubuff::Vector{cuType}
cont::Vector{uType}
derivatives:: Matrix{uType}
du1::rateType
fsalfirst::rateType
ks::Vector{rateType}
Expand Down Expand Up @@ -594,10 +595,15 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
recursivefill!.(cubuff, false)

cont = Vector{typeof(u)}(undef, num_stages)
for i in 1: num_stages
for i in 1 : num_stages
cont[i] = zero(u)
end

derivatives = Matrix{typeof(u)}(undef, num_stages, num_stages)
for i in 1 : num_stages, j in 1 : num_stages
derivatives[i, j] = zero(u)
end

fsalfirst = zero(rate_prototype)
fw = Vector{typeof(rate_prototype)}(undef, num_stages)
ks = Vector{typeof(rate_prototype)}(undef, num_stages)
Expand Down Expand Up @@ -635,7 +641,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
atol = reltol isa Number ? reltol : zero(reltol)

AdaptiveRadauCache(u, uprev,
z, w, dw1, ubuff, dw2, cubuff, cont,
z, w, dw1, ubuff, dw2, cubuff, cont, derivatives,
du1, fsalfirst, ks, k, fw,
J, W1, W2,
uf, tab, κ, one(uToltype), 10000, tmp,
Expand Down
32 changes: 13 additions & 19 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1369,21 +1369,16 @@ end

J = calc_J(integrator, cache)

#if u isa Number
# LU1 = Complex(-γdt * mass_matrix + J)
# LU2 = -(αdt[1] + βdt[1] * im) * mass_matrix + J
#else
LU1 = lu(-γdt * mass_matrix + J)
LU2 = lu(-(αdt[1] + βdt[1] * im) * mass_matrix + J)
#end
LU = [LU2 for _ in 1:(num_stages + 1) ÷ 2]
LU2 = Vector{Complex{typeof(u)}}(undef, (num_stages - 1) ÷ 2)
if u isa Number
for i in 3 :(num_stages + 1) ÷ 2
LU[i] = -(αdt[i - 1] + βdt[i - 1] * im) * mass_matrix + J
LU1 = -γdt * mass_matrix + J
for i in 1 : (num_stages - 1) ÷ 2
LU2[i] = -(αdt[i] + βdt[i] * im) * mass_matrix + J
end
else
for i in 3 :(num_stages + 1) ÷ 2
LU[i] = lu(-(αdt[i - 1] + βdt[i - 1] * im) * mass_matrix + J)
LU1 = lu(-γdt * mass_matrix + J)
for i in 1 : (num_stages - 1) ÷ 2
LU2[i] = lu(-(αdt[i] + βdt[i] * im) * mass_matrix + J)
end
end

Expand Down Expand Up @@ -1453,7 +1448,7 @@ end
dw = Vector{typeof(u)}(undef, num_stages)
dw[1] = _reshape(LU1 \ _vec(rhs[1]), axes(u))
for i in 2 :(num_stages + 1) ÷ 2
tmp = _reshape(LU[i] \ _vec(@.. rhs[2 * i - 2] + rhs[2 * i - 1] * im), axes(u))
tmp = _reshape(LU2[i - 1] \ _vec(@.. rhs[2 * i - 2] + rhs[2 * i - 1] * im), axes(u))
dw[2 * i - 2] = real(tmp)
dw[2 * i - 1] = imag(tmp)
end
Expand Down Expand Up @@ -1555,7 +1550,7 @@ end
@muladd function perform_step!(integrator, cache::AdaptiveRadauCache, repeat_step = false)
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
@unpack T, TI, γ, α, β, c, #=e,=# num_stages = cache.tab
@unpack κ, cont, z, w = cache
@unpack κ, cont, derivatives, z, w = cache
@unpack dw1, ubuff, dw2, cubuff = cache
@unpack ks, k, fw, J, W1, W2 = cache
@unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
Expand Down Expand Up @@ -1763,15 +1758,14 @@ end
if integrator.EEst <= oneunit(integrator.EEst)
cache.dtprev = dt
if alg.extrapolant != :constant
derivatives = Matrix{typeof(u)}(undef, num_stages, num_stages)
derivatives[1, 1] = @.. z[1] / c[1]
@.. derivatives[1, 1] = z[1] / c[1]
for j in 2 : num_stages
derivatives[1, j] = @.. (z[j - 1] - z[j]) / (c[j - 1] - c[j]) #first derivatives
@.. derivatives[1, j] = (z[j - 1] - z[j]) / (c[j - 1] - c[j]) #first derivatives
end
for i in 2 : num_stages
derivatives[i, i] = @.. (derivatives[i - 1, i] - derivatives[i - 1, i - 1]) / c[i]
@.. derivatives[i, i] = (derivatives[i - 1, i] - derivatives[i - 1, i - 1]) / c[i]
for j in i+1 : num_stages
derivatives[i, j] = @.. (derivatives[i - 1, j - 1] - derivatives[i - 1, j]) / (c[j - i] - c[j]) #all others
@.. derivatives[i, j] = (derivatives[i - 1, j - 1] - derivatives[i - 1, j]) / (c[j - i] - c[j]) #all others
end
end
for i in 1 : num_stages
Expand Down

0 comments on commit 12630e0

Please sign in to comment.