Skip to content

Commit

Permalink
add in-place
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Aug 10, 2024
1 parent 45641a3 commit a685028
Show file tree
Hide file tree
Showing 2 changed files with 271 additions and 21 deletions.
272 changes: 257 additions & 15 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ end
@unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
@unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
@unpack J, W1, W2, W3 = cache
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
alg = unwrap_alg(integrator, true)
@unpack maxiters = alg
Expand Down Expand Up @@ -1078,26 +1078,26 @@ end
c2′ = c2 * c5′
c3′ = c3 * c5′
c4′ = c4 * c5′
z1 = @.. broadcast=false c1′*(cont1 +
@.. broadcast=false z1 = c1′*(cont1 +
(c1′ - c3m1) * (cont2 +
(c1′ - c2m1) * (cont3 + (c1′ - c1m1) * cont4)))
z2 = @.. broadcast=false c2′*(cont1 +
@.. broadcast=false z2 = c2′*(cont1 +
(c2′ - c3m1) * (cont2 +
(c2′ - c2m1) * (cont3 + (c2′ - c1m1) * cont4)))
z3 = @.. broadcast=false c3′*(cont1 +
@.. broadcast=false z3 = c3′*(cont1 +
(c3′ - c3m1) * (cont2 +
(c3′ - c2m1) * (cont3 + (c3′ - c1m1) * cont4)))
z4 = @.. broadcast=false c4′*(cont1 +
@.. broadcast=false z4 = c4′*(cont1 +
(c4′ - c3m1) * (cont2 +
(c4′ - c2m1) * (cont3 + (c4′ - c1m1) * cont4)))
z5 = @.. broadcast=false c5′*(cont1 +
@.. broadcast=false z5 = c5′*(cont1 +
(c5′ - c3m1) * (cont2 +
(c5′ - c2m1) * (cont3 + (c5′ - c1m1) * cont4)))
w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
w4 = @.. broadcast=false TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
w5 = @.. broadcast=false TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
@.. broadcast=false w1 = TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
@.. broadcast=false w2 = TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
@.. broadcast=false w3 = TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
@.. broadcast=false w4 = TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
@.. broadcast=false w5 = TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
end

# Newton iteration
Expand Down Expand Up @@ -1376,7 +1376,7 @@ end
for i in 1:s
z[i] = w[i] = map(zero, u)
end
for i in 1:s-1
for i in 1:(s-1)
cont[i] = map(zero, u)
end
else
Expand Down Expand Up @@ -1463,9 +1463,8 @@ end
end
end

for i in 1 : s
w[i] = @.. w[i] - dw[i]
end
w = @.. w - dw

# transform `w` to `z`
z = @.. T * w

Expand Down Expand Up @@ -1536,3 +1535,246 @@ end
integrator.u = u
return
end

# Apply a stage-coupling matrix to a vector of stage arrays, in place:
# `dest[i] = sum_j M[i, j] * src[j]`. `dest` and `src` must not share storage.
function _stage_transform!(dest, M, src)
    for i in eachindex(dest, src)
        @.. dest[i] = M[i, 1] * src[1]
        for j in 2:length(src)
            @.. dest[i] = dest[i] + M[i, j] * src[j]
        end
    end
    return dest
end

# In-place step for the variable-stage Radau cache.
#
# Outline (each phase is marked by a comment below):
#   1. refresh the Jacobian and the real/complex W matrices when requested,
#   2. build the initial guess for the stage increments `z` and their transform `w`,
#   3. run the simplified Newton iteration (one real solve plus one complex solve per
#      conjugate pair),
#   4. on convergence update `u`, compute the embedded error estimate `EEst` when
#      `adaptive`, and refresh the extrapolation data in `cache.cont`.
#
# On Newton failure, `integrator.force_stepfail` is set and the function returns early.
#
# NOTE(review): the number of stages `s` is taken from `length(c)`; this presumes the
# tableau stores one node per stage -- confirm against the tableau constructor.
# NOTE(review): the per-pair complex right-hand-side buffers are read from `cache.cubuff`
# (the body indexes `cubuff[i]`); confirm the cache provides that field.
@muladd function perform_step!(integrator, cache::adaptiveRadauCache, repeat_step = false)
    @unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
    @unpack T, TI, γ, α, β, c, e = cache.tab
    @unpack κ, cont, z, w = cache
    @unpack dw1, ubuff, dw2, cubuff = cache
    @unpack k, fw, J, W1, W2 = cache
    @unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
    @unpack internalnorm, abstol, reltol, adaptive = integrator.opts
    alg = unwrap_alg(integrator, true)
    @unpack maxiters = alg
    mass_matrix = integrator.f.mass_matrix

    s = length(c)            # number of stages (see NOTE(review) above)
    npairs = (s - 1) ÷ 2     # integer count of complex systems; `/` would give a Float64

    # precalculations

    γdt, αdt, βdt = γ / dt, α ./ dt, β ./ dt
    (new_jac = do_newJ(integrator, alg, cache, repeat_step)) &&
        (calc_J!(J, integrator, cache); cache.W_γdt = dt)
    if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
        @inbounds for II in CartesianIndices(J)
            W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II]
            for i in 1:npairs
                # use the dt-scaled coefficients computed above
                W2[i][II] = -(αdt[i] + βdt[i] * im) * mass_matrix[Tuple(II)...] + J[II]
            end
        end
        integrator.stats.nw += 1
    end

    # TODO better initial guess
    if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
        cache.dtprev = one(cache.dtprev)
        uzero = zero(eltype(u))
        for i in 1:s
            @.. z[i] = uzero
            @.. w[i] = uzero
        end
        for i in 1:(s - 1)
            @.. cont[i] = uzero
        end
    else
        # time stepping: nodes rescaled by the step-size ratio
        c_prime = Vector{eltype(u)}(undef, s)
        c_prime[s] = dt / cache.dtprev
        for i in 1:(s - 1)
            c_prime[i] = c[i] * c_prime[s]
        end
        for i in 1:s # collocation polynomial, evaluated Horner-style
            # NOTE(review): `cont[s - 1]` appears twice in the start term; check the
            # coefficients against the fixed-stage Radau implementations.
            @.. z[i] = cont[s - 1] * (c_prime[i] - c[1] + 1) + cont[s - 1]
            j = s - 2
            while j > 0
                @.. z[i] = z[i] * (c_prime[i] - c[s - j] + 1) + cont[j]
                j -= 1   # without this the loop never terminates
            end
            @.. z[i] = z[i] * c_prime[i]
        end
        _stage_transform!(w, TI, z)
    end

    # Newton iteration
    # `linres1` is declared here because the error estimator reuses it after the loop.
    local ndw, linres1
    η = max(cache.ηold, eps(eltype(integrator.opts.reltol)))^(0.8)
    fail_convergence = true
    iter = 0
    while iter < maxiters
        iter += 1
        integrator.stats.nnonliniter += 1

        # evaluate function
        k[1] = fsallast   # the first stage derivative is written into `fsallast`
        for i in 1:s
            @.. tmp = uprev + z[i]
            f(k[i], tmp, p, t + c[i] * dt)
        end
        integrator.stats.nf += s

        _stage_transform!(fw, TI, k)

        # `z` is used as scratch for M * w; it is rebuilt from `w` further below
        if mass_matrix === I
            Mw = w
        elseif mass_matrix isa UniformScaling
            for i in 1:s
                mul!(z[i], mass_matrix.λ, w[i])
            end
            Mw = z
        else
            for i in 1:s
                mul!(z[i], mass_matrix, w[i])   # a general matrix has no `λ` field
            end
            Mw = z
        end

        @.. ubuff = fw[1] - γdt * Mw[1]
        needfactor = iter == 1 && new_W

        # real system
        linsolve1 = cache.linsolve1
        if needfactor
            linres1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff),
                linu = _vec(dw1))
        else
            linres1 = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff),
                linu = _vec(dw1))
        end
        cache.linsolve1 = linres1.cache

        # one complex system per pair of transformed stages (2i, 2i + 1)
        for i in 1:npairs
            ire = 2 * i
            iim = ire + 1
            @.. broadcast=false cubuff[i]=complex(
                fw[ire] - αdt[i] * Mw[ire] + βdt[i] * Mw[iim],
                fw[iim] - βdt[i] * Mw[ire] - αdt[i] * Mw[iim])
            if needfactor
                linres2 = dolinsolve(integrator, cache.linsolve2[i]; A = W2[i],
                    b = _vec(cubuff[i]), linu = _vec(dw2[i]))
            else
                linres2 = dolinsolve(integrator, cache.linsolve2[i]; A = nothing,
                    b = _vec(cubuff[i]), linu = _vec(dw2[i]))
            end
            cache.linsolve2[i] = linres2.cache
        end

        integrator.stats.nsolve += 1 + npairs

        # Newton updates: stage 1 lives in `dw1`; stages 2:s are stored in `z[2:s]`,
        # which is free at this point because `z` is recomputed from `w` below.
        for i in 1:npairs
            ire = 2 * i
            iim = ire + 1
            @.. z[ire] = real(dw2[i])
            @.. z[iim] = imag(dw2[i])
        end

        # compute norm of residuals
        iter > 1 && (ndwprev = ndw)
        calculate_residuals!(atmp, dw1, uprev, u, atol, rtol, internalnorm, t)
        ndw = internalnorm(atmp, t)
        for i in 2:s
            calculate_residuals!(atmp, z[i], uprev, u, atol, rtol, internalnorm, t)
            ndw += internalnorm(atmp, t)
        end

        # check divergence (not in initial step)

        if iter > 1
            θ = ndw / ndwprev
            (diverge = θ > 1) && (cache.status = Divergence)
            (veryslowconvergence = ndw * θ^(maxiters - iter) > κ * (1 - θ)) &&
                (cache.status = VerySlowConvergence)
            if diverge || veryslowconvergence
                break
            end
        end

        @.. w[1] = w[1] - dw1
        for i in 2:s
            @.. w[i] = w[i] - z[i]
        end

        # transform `w` to `z`
        _stage_transform!(z, T, w)

        # check stopping criterion

        iter > 1 && (η = θ / (1 - θ))
        if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
            # Newton method converges
            cache.status = η < alg.fast_convergence_cutoff ? FastConvergence :
                           Convergence
            fail_convergence = false
            break
        end
    end
    if fail_convergence
        integrator.force_stepfail = true
        integrator.stats.nnonlinconvfail += 1
        return
    end

    cache.ηold = η
    cache.iter = iter

    @.. broadcast=false u=uprev + z[s]

    step_limiter!(u, integrator, p, t + dt)

    if adaptive
        # `w[1]` and `w[2]` serve as scratch buffers from here on
        utilde = w[2]
        edt = e ./ dt
        # tmp = sum_i edt[i] * z[i]
        # NOTE(review): presumes `e` holds one weight per stage -- confirm.
        @.. tmp = edt[1] * z[1]
        for i in 2:s
            @.. tmp = tmp + edt[i] * z[i]
        end
        mass_matrix != I && (mul!(w[1], mass_matrix, tmp); copyto!(tmp, w[1]))
        @.. ubuff=integrator.fsalfirst + tmp

        if alg.smooth_est
            linres1 = dolinsolve(integrator, linres1.cache; b = _vec(ubuff),
                linu = _vec(utilde))
            cache.linsolve1 = linres1.cache
            integrator.stats.nsolve += 1
        end

        # RadauIIA5 needs a transformed rtol and atol see
        # https://github.com/luchr/ODEInterface.jl/blob/0bd134a5a358c4bc13e0fb6a90e27e4ee79e0115/src/radau5.f#L399-L421
        calculate_residuals!(atmp, utilde, uprev, u, atol, rtol, internalnorm, t)
        integrator.EEst = internalnorm(atmp, t)

        if !(integrator.EEst < oneunit(integrator.EEst)) && integrator.iter == 1 ||
           integrator.u_modified
            @.. broadcast=false utilde=uprev + utilde
            f(fsallast, utilde, p, t)
            integrator.stats.nf += 1
            @.. broadcast=false ubuff=fsallast + tmp

            if alg.smooth_est
                linres1 = dolinsolve(integrator, linres1.cache; b = _vec(ubuff),
                    linu = _vec(utilde))
                cache.linsolve1 = linres1.cache
                integrator.stats.nsolve += 1
            end

            calculate_residuals!(atmp, utilde, uprev, u, atol, rtol, internalnorm, t)
            integrator.EEst = internalnorm(atmp, t)
        end
    end

    if integrator.EEst <= oneunit(integrator.EEst)
        cache.dtprev = dt
        if alg.extrapolant != :constant
            # NOTE(review): this divided-difference table is kept as written. Its
            # element type is `eltype(u)` while the right-hand sides are stage arrays,
            # and the first row reads `z[i]`/`c[i]` independent of `j`; both need to be
            # checked against the fixed-stage Radau implementations before relying on
            # the extrapolated initial guess.
            derivatives = Matrix{eltype(u)}(undef, s - 1, s - 1)
            for i in 1:(s - 1)
                for j in i:(s - 1)
                    if i == 1
                        # first derivatives
                        @.. derivatives[i, j] = (z[i] - z[i + 1]) / (c[i] - c[i + 1])
                    else
                        # all others
                        @.. derivatives[i, j] = (derivatives[i - 1, j - 1] -
                                                 derivatives[i - 1, j]) /
                                                (c[j - i + 1] - c[j + 1])
                    end
                end
            end
            for i in 1:(s - 1)
                cache.cont[i] = derivatives[i, s - 1]
            end
        end
    end

    f(fsallast, u, p, t + dt)
    integrator.stats.nf += 1
    return
end
20 changes: 14 additions & 6 deletions lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ struct adaptiveRadauTableau(T, T2)
e::AbstractVector{T}
end

using Polynomials, GenericLinearAlgebra, LinearAlgebra, LinearSolve, GenericSchur
using Polynomials, GenericLinearAlgebra, LinearAlgebra, LinearSolve, GenericSchur, BSeries

function adaptiveRadauTableau(T, T2, s::Int64)
tmp = Vector{BigFloat}(undef, s-1)
Expand Down Expand Up @@ -301,15 +301,19 @@ function adaptiveRadauTableau(T, T2, s::Int64)
end
a = c_q * inverse_c_powers
a_inverse = inv(a)
b = eigvals(a_inverse)
b = Vector{Bigfloat}(undef, s)
for i in 1 : s
b[i] = a[s, i]
end
vals = eigvals(a_inverse)
γ = real(b[s])
α = Vector{BigFloat}(undef, floor(Int, s/2))
β = Vector{BigFloat}(undef, floor(Int, s/2))
index = 1
i = 1
while i <= (s-1)
α[index] = real(b[i])
β[index] = imag(b[i + 1])
α[index] = real(vals[i])
β[index] = imag(vals[i + 1])
index = index + 1
i = i + 2
end
Expand All @@ -332,7 +336,11 @@ function adaptiveRadauTableau(T, T2, s::Int64)
end
end
TI = inv(T)
#adaptiveRadauTableau(T, TI, γ, α, β, c, e)
b_hat = Vector{BigFloat}(undef, s)
embedded = bseries(a, b_hat, c, s - 2)

#e = b_hat - b
#adaptiveRadautableau(T, TI, γ, α, β, c, e)
end

adaptiveRadauTableau(0, 0, 1)
adaptiveRadauTableau(0, 0, 3)

0 comments on commit a685028

Please sign in to comment.