Skip to content

Commit

Permalink
small edits
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Sep 14, 2024
1 parent 9c15269 commit 75168c8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -520,14 +520,14 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
Convergence, J)
end

mutable struct AdaptiveRadauCache{uType, cuType, uNoUnitsType, rateType, JType, W1Type, W2Type,
mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol, StepLimiter} <:
FIRKMutableCache
u::uType
uprev::uType
z::Vector{uType}
w::Vector{uType}
c_prime::Vector{BigFloat}
c_prime::Vector{tType}
dw1::uType
ubuff::uType
dw2::Vector{cuType}
Expand Down Expand Up @@ -589,7 +589,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
z[i] = w[i] = zero(u)
end

c_prime = Vector{BigFloat}(undef, num_stages) #time stepping
c_prime = Vector{typeof(t)}(undef, num_stages) #time stepping

dw1 = zero(u)
ubuff = zero(u)
Expand Down
35 changes: 28 additions & 7 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1498,12 +1498,20 @@ end

# transform `w` to `z`
#z = T * w
for i in 1:num_stages
for i in 1:num_stages - 1
z[i] = zero(u)
for j in 1:num_stages
z[i] += T[i,j] * w[j]
end
end
z[num_stages] = T[num_stages, 1] * w[1]
i = 2
while i < num_stages
z[num_stages] += w[i]
i += 2
end


# check stopping criterion
iter > 1 &&= θ / (1 - θ))
if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
Expand All @@ -1524,13 +1532,16 @@ end
cache.iter = iter

u = @.. uprev + z[num_stages]
#=

if adaptive
edt = e ./ dt
tmp = @.. dot(edt, z)
tmp = dot(edt, z)
mass_matrix != I && (tmp = mass_matrix * tmp)
utilde = @.. broadcast=false integrator.fsalfirst+tmp
alg.smooth_est && (utilde = LU[1] \ utilde; integrator.stats.nsolve += 1)
if alg.smooth_est
utilde = _reshape(LU1 \ _vec(utilde), axes(u))
integrator.stats.nsolve += 1
end
atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
integrator.EEst = internalnorm(atmp, t)

Expand All @@ -1539,12 +1550,15 @@ end
f0 = f(uprev .+ utilde, p, t)
integrator.stats.nf += 1
utilde = @.. broadcast=false f0+tmp
alg.smooth_est && (utilde = LU[1] \ utilde; integrator.stats.nsolve += 1)
if alg.smooth_est
utilde = _reshape(LU1 \ _vec(utilde), axes(u))
integrator.stats.nsolve += 1
end
atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
integrator.EEst = internalnorm(atmp, t)
end
end
=#

if integrator.EEst <= oneunit(integrator.EEst)
cache.dtprev = dt
if alg.extrapolant != :constant
Expand Down Expand Up @@ -1729,12 +1743,19 @@ end

# transform `w` to `z`
#mul!(z, T, w)
for i in 1:num_stages
for i in 1:num_stages - 1
z[i] = zero(u)
for j in 1:num_stages
z[i] += T[i,j] * w[j]
end
end
z[num_stages] = T[num_stages, 1] * w[1]
i = 2
while i < num_stages
z[num_stages] += w[i]
i += 2
end

# check stopping criterion
iter > 1 &&= θ / (1 - θ))
if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
Expand Down

0 comments on commit 75168c8

Please sign in to comment.