From a68502859bf7f5a245c91079ff068e1764cf591d Mon Sep 17 00:00:00 2001
From: Shreyas-Ekanathan
Date: Sat, 10 Aug 2024 18:12:22 -0400
Subject: [PATCH] add in-place

---
 .../src/firk_perform_step.jl                | 282 +++++++++++++++++-
 lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl |  22 +-
 2 files changed, 282 insertions(+), 22 deletions(-)

diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
index 7077b291ff..81a4f814c0 100644
--- a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
+++ b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
@@ -1024,7 +1024,7 @@ end
     @unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
     @unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
     @unpack J, W1, W2, W3 = cache
-    @unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
+    @unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache
     @unpack internalnorm, abstol, reltol, adaptive = integrator.opts
     alg = unwrap_alg(integrator, true)
     @unpack maxiters = alg
@@ -1078,26 +1078,26 @@ end
         c2′ = c2 * c5′
         c3′ = c3 * c5′
         c4′ = c4 * c5′
-        z1 = @.. broadcast=false c1′*(cont1 +
+        @.. broadcast=false z1 = c1′*(cont1 +
                                        (c1′ - c3m1) * (cont2 +
                                         (c1′ - c2m1) * (cont3 + (c1′ - c1m1) * cont4)))
-        z2 = @.. broadcast=false c2′*(cont1 +
+        @.. broadcast=false z2 = c2′*(cont1 +
                                        (c2′ - c3m1) * (cont2 +
                                         (c2′ - c2m1) * (cont3 + (c2′ - c1m1) * cont4)))
-        z3 = @.. broadcast=false c3′*(cont1 +
+        @.. broadcast=false z3 = c3′*(cont1 +
                                        (c3′ - c3m1) * (cont2 +
                                         (c3′ - c2m1) * (cont3 + (c3′ - c1m1) * cont4)))
-        z4 = @.. broadcast=false c4′*(cont1 +
+        @.. broadcast=false z4 = c4′*(cont1 +
                                        (c4′ - c3m1) * (cont2 +
                                         (c4′ - c2m1) * (cont3 + (c4′ - c1m1) * cont4)))
-        z5 = @.. broadcast=false c5′*(cont1 +
+        @.. broadcast=false z5 = c5′*(cont1 +
                                        (c5′ - c3m1) * (cont2 +
                                         (c5′ - c2m1) * (cont3 + (c5′ - c1m1) * cont4)))
-        w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
-        w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
-        w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
-        w4 = @.. broadcast=false TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
-        w5 = @.. broadcast=false TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
+        @.. broadcast=false w1 = TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
+        @.. broadcast=false w2 = TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
+        @.. broadcast=false w3 = TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
+        @.. broadcast=false w4 = TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
+        @.. broadcast=false w5 = TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
     end
 
     # Newton iteration
@@ -1376,7 +1376,7 @@ end
         for i in 1:s
             z[i] = w[i] = map(zero, u)
         end
-        for i in 1:s-1
+        for i in 1:(s-1)
             cont[i] = map(zero, u)
         end
     else
@@ -1463,9 +1463,8 @@ end
             end
         end
 
-        for i in 1 : s
-            w[i] = @.. w[i] - dw[i]
-        end
+        w = @.. w - dw
+
         # transform `w` to `z`
         z = @.. T * w
 
@@ -1536,3 +1535,256 @@ end
     integrator.u = u
     return
 end
+
+# In-place step for the adaptive-order Radau IIA method (`adaptiveRadauCache`).
+# Runs a Newton iteration on the transformed stage variables `w`, solving one real
+# system (`W1`) and `(s - 1) ÷ 2` complex systems (`W2[i]`) per iteration, then writes
+# the new state into `u` and, when `adaptive`, the error estimate into `integrator.EEst`.
+#
+# NOTE(review): still open in this draft — confirm against the cache definition:
+#   * `cubuff` is indexed below, but only `cubuff1`/`cubuff2` are unpacked from `cache`.
+#   * `TI * z`, `TI * k`, `T * w` and `dot(edt, z)` are `@..` broadcasts over vectors of
+#     arrays; they look like they are meant to be matrix-vector products.
+#   * `derivatives` is a `Matrix{eltype(u)}` but is assigned array-valued differences.
+@muladd function perform_step!(integrator, cache::adaptiveRadauCache, repeat_step = false)
+    @unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
+    @unpack T, TI, γ, α, β, c, e = cache.tab
+    @unpack κ, cont, z, w = cache
+    @unpack dw1, ubuff, dw2, cubuff1, cubuff2 = cache
+    @unpack k, fw, J, W1, W2 = cache
+    @unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
+    @unpack internalnorm, abstol, reltol, adaptive = integrator.opts
+    alg = unwrap_alg(integrator, true)
+    @unpack maxiters = alg
+    mass_matrix = integrator.f.mass_matrix
+    # Number of stages; assumes the tableau stores one node per stage — TODO confirm.
+    s = length(c)
+
+    # precalculations
+
+    γdt, αdt, βdt = γ / dt, α ./ dt, β ./ dt
+    (new_jac = do_newJ(integrator, alg, cache, repeat_step)) &&
+        (calc_J!(J, integrator, cache); cache.W_γdt = dt)
+    if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
+        @inbounds for II in CartesianIndices(J)
+            W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II]
+            for i in 1:((s - 1) ÷ 2)
+                W2[i][II] = -(αdt[i] + βdt[i] * im) * mass_matrix[Tuple(II)...] + J[II]
+            end
+        end
+        integrator.stats.nw += 1
+    end
+
+    # TODO better initial guess
+    if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
+        cache.dtprev = one(cache.dtprev)
+        uzero = zero(eltype(u))
+        for i in 1:s
+            @.. z[i] = w[i] = uzero
+        end
+        for i in 1:(s - 1)
+            @.. cache.cont[i] = uzero
+        end
+    else
+        c′ = Vector{eltype(u)}(undef, s) # time stepping
+        c′[s] = dt / cache.dtprev
+        for i in 1:(s - 1)
+            c′[i] = c[i] * c′[s]
+        end
+        for i in 1:s # collocation polynomial
+            @.. z[i] = cont[s - 1] * (c′[i] - c[1] + 1) + cont[s - 1]
+            j = s - 2
+            while j > 0
+                @.. z[i] = z[i] * (c′[i] - c[s - j] + 1) + cont[j]
+                j -= 1
+            end
+            @.. z[i] = z[i] * c′[i]
+        end
+        @.. w = TI * z
+    end
+
+    # Newton iteration
+    local ndw
+    η = max(cache.ηold, eps(eltype(integrator.opts.reltol)))^(0.8)
+    fail_convergence = true
+    iter = 0
+    # Per-stage Newton updates (`dw[1]` is `dw1`; the rest reuse `z` as storage).
+    # Assumes `dw1` has the same array type as `z[i]` — TODO confirm.
+    dw = similar(z)
+    while iter < maxiters
+        iter += 1
+        integrator.stats.nnonliniter += 1
+
+        # evaluate function
+        k[1] = fsallast
+        for i in 1:s
+            @.. tmp = uprev + z[i]
+            f(k[i], tmp, p, t + c[i] * dt)
+        end
+        integrator.stats.nf += s
+
+        @.. fw = TI * k
+        if mass_matrix === I
+            Mw = w
+        elseif mass_matrix isa UniformScaling
+            for i in 1:s
+                mul!(z[i], mass_matrix.λ, w[i])
+            end
+            Mw = z
+        else
+            for i in 1:s
+                mul!(z[i], mass_matrix, w[i])
+            end
+            Mw = z
+        end
+
+        @.. ubuff = fw[1] - γdt * Mw[1]
+        needfactor = iter == 1 && new_W
+
+        linsolve1 = cache.linsolve1
+        if needfactor
+            linres1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff),
+                linu = _vec(dw1))
+        else
+            linres1 = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff),
+                linu = _vec(dw1))
+        end
+
+        cache.linsolve1 = linres1.cache
+
+        for i in 1:((s - 1) ÷ 2)
+            @.. broadcast=false cubuff[i]=complex(
+                fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1],
+                fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1])
+            linsolve2[i] = cache.linsolve2[i]
+            if needfactor
+                linres2 = dolinsolve(integrator, linsolve2[i]; A = W2[i],
+                    b = _vec(cubuff[i]), linu = _vec(dw2[i]))
+            else
+                linres2 = dolinsolve(integrator, linsolve2[i]; A = nothing,
+                    b = _vec(cubuff[i]), linu = _vec(dw2[i]))
+            end
+            cache.linsolve2[i] = linres2.cache
+        end
+
+        integrator.stats.nsolve += (s + 1) ÷ 2
+        dw[1] = dw1
+        i = 2
+        while i <= s
+            dw[i] = z[i]
+            dw[i + 1] = z[i + 1]
+            @.. dw[i] = real(dw2[i ÷ 2])
+            @.. dw[i + 1] = imag(dw2[i ÷ 2])
+            i += 2
+        end
+
+        # compute norm of residuals
+        iter > 1 && (ndwprev = ndw)
+        calculate_residuals!(atmp, dw[1], uprev, u, atol, rtol, internalnorm, t)
+        ndw = internalnorm(atmp, t)
+        for i in 2:s
+            calculate_residuals!(atmp, dw[i], uprev, u, atol, rtol, internalnorm, t)
+            ndw += internalnorm(atmp, t)
+        end
+
+        # check divergence (not in initial step)
+
+        if iter > 1
+            θ = ndw / ndwprev
+            (diverge = θ > 1) && (cache.status = Divergence)
+            (veryslowconvergence = ndw * θ^(maxiters - iter) > κ * (1 - θ)) &&
+                (cache.status = VerySlowConvergence)
+            if diverge || veryslowconvergence
+                break
+            end
+        end
+
+        @.. w = w - dw
+
+        # transform `w` to `z`
+        @.. z = T * w
+        # check stopping criterion
+
+        iter > 1 && (η = θ / (1 - θ))
+        if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
+            # Newton method converges
+            cache.status = η < alg.fast_convergence_cutoff ? FastConvergence :
+                           Convergence
+            fail_convergence = false
+            break
+        end
+    end
+    if fail_convergence
+        integrator.force_stepfail = true
+        integrator.stats.nnonlinconvfail += 1
+        return
+    end
+
+    cache.ηold = η
+    cache.iter = iter
+
+    @.. broadcast=false u=uprev + z[s]
+
+    step_limiter!(u, integrator, p, t + dt)
+
+    if adaptive
+        utilde = w[2]
+        edt = e ./ dt
+        @.. tmp = dot(edt, z)
+        mass_matrix != I && (mul!(w[1], mass_matrix, tmp); copyto!(tmp, w[1]))
+        @.. ubuff=integrator.fsalfirst + tmp
+
+        if alg.smooth_est
+            linres1 = dolinsolve(integrator, linres1.cache; b = _vec(ubuff),
+                linu = _vec(utilde))
+            cache.linsolve1 = linres1.cache
+            integrator.stats.nsolve += 1
+        end
+
+        # RadauIIA5 needs a transformed rtol and atol see
+        # https://github.com/luchr/ODEInterface.jl/blob/0bd134a5a358c4bc13e0fb6a90e27e4ee79e0115/src/radau5.f#L399-L421
+        calculate_residuals!(atmp, utilde, uprev, u, atol, rtol, internalnorm, t)
+        integrator.EEst = internalnorm(atmp, t)
+
+        if !(integrator.EEst < oneunit(integrator.EEst)) && integrator.iter == 1 ||
+           integrator.u_modified
+            @.. broadcast=false utilde=uprev + utilde
+            f(fsallast, utilde, p, t)
+            integrator.stats.nf += 1
+            @.. broadcast=false ubuff=fsallast + tmp
+
+            if alg.smooth_est
+                linres1 = dolinsolve(integrator, linres1.cache; b = _vec(ubuff),
+                    linu = _vec(utilde))
+                cache.linsolve1 = linres1.cache
+                integrator.stats.nsolve += 1
+            end
+
+            calculate_residuals!(atmp, utilde, uprev, u, atol, rtol, internalnorm, t)
+            integrator.EEst = internalnorm(atmp, t)
+        end
+    end
+
+    if integrator.EEst <= oneunit(integrator.EEst)
+        cache.dtprev = dt
+        if alg.extrapolant != :constant
+            derivatives = Matrix{eltype(u)}(undef, s - 1, s - 1)
+            for i in 1:(s - 1)
+                for j in i:(s - 1)
+                    if i == 1
+                        @.. derivatives[i, j] = (z[i] - z[i + 1]) / (c[i] - c[i + 1]) #first derivatives
+                    else
+                        @.. derivatives[i, j] = (derivatives[i - 1, j - 1] - derivatives[i - 1, j]) / (c[j - i + 1] - c[j + 1]) #all others
+                    end
+                end
+            end
+            for i in 1:(s - 1)
+                cache.cont[i] = derivatives[i, s - 1]
+            end
+        end
+    end
+
+    f(fsallast, u, p, t + dt)
+    integrator.stats.nf += 1
+    return
+end
diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl b/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl
index 217049d94c..478bdd4240 100644
--- a/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl
+++ b/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl
@@ -269,7 +269,7 @@ struct adaptiveRadauTableau(T, T2)
     e::AbstractVector{T}
 end
 
-using Polynomials, GenericLinearAlgebra, LinearAlgebra, LinearSolve, GenericSchur
+using Polynomials, GenericLinearAlgebra, LinearAlgebra, LinearSolve, GenericSchur, BSeries
 
 function adaptiveRadauTableau(T, T2, s::Int64)
     tmp = Vector{BigFloat}(undef, s-1)
@@ -301,15 +301,20 @@ function adaptiveRadauTableau(T, T2, s::Int64)
     end
     a = c_q * inverse_c_powers
     a_inverse = inv(a)
-    b = eigvals(a_inverse)
-    γ = real(b[s])
+    b = Vector{BigFloat}(undef, s)
+    for i in 1:s
+        b[i] = a[s, i]
+    end
+    vals = eigvals(a_inverse)
+    # `b` is now the last row of `a`, so γ is read from the eigenvalues `vals` as before.
+    γ = real(vals[s])
     α = Vector{BigFloat}(undef, floor(Int, s/2))
     β = Vector{BigFloat}(undef, floor(Int, s/2))
     index = 1
     i = 1
     while i <= (s-1)
-        α[index] = real(b[i])
-        β[index] = imag(b[i + 1])
+        α[index] = real(vals[i])
+        β[index] = imag(vals[i + 1])
         index = index + 1
         i = i + 2
     end
@@ -332,7 +337,12 @@ function adaptiveRadauTableau(T, T2, s::Int64)
         end
     end
     TI = inv(T)
-    #adaptiveRadauTableau(T, TI, γ, α, β, c, e)
+    # NOTE(review): `b_hat` is passed to `bseries` uninitialized (`undef`) — confirm intent.
+    b_hat = Vector{BigFloat}(undef, s)
+    embedded = bseries(a, b_hat, c, s - 2)
+
+    #e = b_hat - b
+    #adaptiveRadauTableau(T, TI, γ, α, β, c, e)
 end
 
-adaptiveRadauTableau(0, 0, 1)
+adaptiveRadauTableau(0, 0, 3)