Skip to content

Commit

Permalink
fix collocation on adaptive radau
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Aug 23, 2024
1 parent 9eb4538 commit 1d7a4bd
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 27 deletions.
8 changes: 4 additions & 4 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,8 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
num_stages = alg.num_stages
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)

cont = Vector{typeof(u)}(undef, num_stages - 1)
for i in 1: (num_stages - 1)
cont = Vector{typeof(u)}(undef, num_stages)
for i in 1: num_stages
cont[i] = zero(u)
end

Expand Down Expand Up @@ -576,8 +576,8 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
recursivefill!(cubuff[i], false)
end

cont = Vector{typeof(u)}(undef, num_stages - 1)
for i in 1: (num_stages - 1)
cont = Vector{typeof(u)}(undef, num_stages)
for i in 1: num_stages
cont[i] = zero(u)
end

Expand Down
43 changes: 20 additions & 23 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1344,10 +1344,7 @@ end
if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
cache.dtprev = one(cache.dtprev)
for i in 1 : num_stages
z[i] = w[i] = map(zero, u)
end
for i in 1 : (num_stages - 1)
cache.cont[i] = map(zero, u)
z[i] = w[i] = cache.cont[i] = map(zero, u)
end
else
c_prime = Vector{eltype(u)}(undef, num_stages) #time stepping
Expand All @@ -1356,8 +1353,8 @@ end
c_prime[i] = c[i] * c_prime[num_stages]
end
for i in 1 : num_stages # collocation polynomial
z[i] = @.. cont[num_stages - 1] * (c_prime[i] - c[1] + 1) + cont[num_stages - 2]
j = num_stages - 3
z[i] = @.. cont[num_stages] * (c_prime[i] - c[1] + 1) + cont[num_stages - 1]
j = num_stages - 2
while j > 0
z[i] = @.. z[i] * (c_prime[i] - c[num_stages- j - 1] + 1) + cont[j]
j = j - 1
Expand Down Expand Up @@ -1483,18 +1480,20 @@ end
if integrator.EEst <= oneunit(integrator.EEst)
cache.dtprev = dt
if alg.extrapolant != :constant
derivatives = Matrix{eltype(u)}(undef, num_stages - 1, num_stages - 1)
for i in 1 : (num_stages - 1)
for j in i : (num_stages - 1)
derivatives = Matrix{eltype(u)}(undef, num_stages, num_stages)
pushfirst!(c, 0)
for i in 1 : num_stages
for j in i : num_stages
if i == 1
derivatives[i, j] = @.. (z[i] - z[i + 1]) / (c[i] - c[i + 1]) #first derivatives
else
derivatives[i, j] = @.. (derivatives[i - 1, j - 1] - derivatives[i - 1, j]) / (c[j - i + 1] - c[j + 1]) #all others
end
end
end
popfirst!(c)
for i in 1 : (num_stages - 1)
cache.cont[i] = derivatives[i, num_stages - 1]
cache.cont[i] = derivatives[i, num_stages]
end
end
end
Expand Down Expand Up @@ -1537,12 +1536,8 @@ end
# TODO better initial guess
if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
cache.dtprev = one(cache.dtprev)
uzero = zero(eltype(u))
for i in 1 : num_stages
@.. z[i] = w[i] = uzero
end
for i in 1 : (num_stages-1)
@.. cache.cont[i] = uzero
z[i] = w[i] = cache.cont[i] = map(zero, u)
end
else
c_prime = Vector{eltype(u)}(undef, num_stages) #time stepping
Expand All @@ -1551,8 +1546,8 @@ end
c_prime[i] = c[i] * c_prime[num_stages]
end
for i in 1 : num_stages # collocation polynomial
z[i] = @.. cont[num_stages - 1] * (c_prime[i] - c[1] + 1) + cont[num_stages - 2]
j = num_stages - 3
z[i] = @.. cont[num_stages] * (c_prime[i] - c[1] + 1) + cont[num_stages - 1]
j = num_stages - 2
while j > 0
z[i] = @.. z[i] * (c_prime[i] - c[num_stages- j - 1] + 1) + cont[j]
j = j - 1
Expand Down Expand Up @@ -1728,18 +1723,20 @@ end
if integrator.EEst <= oneunit(integrator.EEst)
cache.dtprev = dt
if alg.extrapolant != :constant
derivatives = Matrix{eltype(u)}(undef, num_stages - 1, num_stages - 1)
for i in 1 : (num_stages - 1)
for j in i : (num_stages - 1)
derivatives = Matrix{eltype(u)}(undef, num_stages, num_stages)
pushfirst!(c, 0)
for i in 1 : num_stages
for j in i : num_stages
if i == 1
@.. derivatives[i, j] = (z[i] - z[i + 1]) / (c[i] - c[i + 1]) #first derivatives
derivatives[i, j] = @.. (z[i] - z[i + 1]) / (c[i] - c[i + 1]) #first derivatives
else
@.. derivatives[i, j] = (derivatives[i - 1, j - 1] - derivatives[i - 1, j]) / (c[j - i + 1] - c[j + 1]) #all others
derivatives[i, j] = @.. (derivatives[i - 1, j - 1] - derivatives[i - 1, j]) / (c[j - i + 1] - c[j + 1]) #all others
end
end
end
popfirst!(c)
for i in 1 : (num_stages - 1)
cache.cont[i] = derivatives[i, num_stages - 1]
cache.cont[i] = derivatives[i, num_stages]
end
end
end
Expand Down

0 comments on commit 1d7a4bd

Please sign in to comment.