Skip to content

Commit

Permalink
IN PLACE
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Sep 1, 2024
1 parent bff13df commit 55259d5
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 41 deletions.
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by RadauIIA5.")
end
W2 = Vector{Any}(undef, floor(Int, num_stages/2))
W2 = Vector{Any}(undef, floor(Int, num_stages / 2))
for i in 1 : floor(Int, num_stages / 2)
W2[i] = similar(J, Complex{eltype(W1)})
recursivefill!(W2[i], false)
Expand Down
54 changes: 24 additions & 30 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,7 @@ end
@.. broadcast=false cache.cont3=cache.cont2 - (tmp - z1 / c1) / c2
end
end

f(fsallast, u, p, t + dt)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
return
Expand Down Expand Up @@ -1175,11 +1176,9 @@ end
linsolve1 = cache.linsolve1

if needfactor
linres1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff),
linu = _vec(dw1))
linres1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1))
else
linres1 = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff),
linu = _vec(dw1))
linres1 = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1))
end

cache.linsolve1 = linres1.cache
Expand All @@ -1190,11 +1189,9 @@ end
linsolve2 = cache.linsolve2

if needfactor
linres2 = dolinsolve(integrator, linsolve2; A = W2, b = _vec(cubuff1),
linu = _vec(dw23))
linres2 = dolinsolve(integrator, linsolve2; A = W2, b = _vec(cubuff1), linu = _vec(dw23))
else
linres2 = dolinsolve(integrator, linsolve2; A = nothing, b = _vec(cubuff1),
linu = _vec(dw23))
linres2 = dolinsolve(integrator, linsolve2; A = nothing, b = _vec(cubuff1), linu = _vec(dw23))
end

cache.linsolve2 = linres2.cache
Expand All @@ -1205,11 +1202,9 @@ end
linsolve3 = cache.linsolve3

if needfactor
linres3 = dolinsolve(integrator, linsolve3; A = W3, b = _vec(cubuff2),
linu = _vec(dw45))
linres3 = dolinsolve(integrator, linsolve3; A = W3, b = _vec(cubuff2), linu = _vec(dw45))
else
linres3 = dolinsolve(integrator, linsolve3; A = nothing, b = _vec(cubuff2),
linu = _vec(dw45))
linres3 = dolinsolve(integrator, linsolve3; A = nothing, b = _vec(cubuff2), linu = _vec(dw45))
end

cache.linsolve3 = linres3.cache
Expand Down Expand Up @@ -1596,7 +1591,7 @@ end
end
z[i] = @.. z[i] * c_prime[i]
end
w = @.. TI * z
w = TI * z
end

# Newton iteration
Expand Down Expand Up @@ -1636,48 +1631,45 @@ end

linsolve1 = cache.linsolve1
if needfactor
linres = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff),
linu = _vec(dw1))
linres = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1))
else
linres = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff),
linu = _vec(dw1))
linres = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1))
end

cache.linsolve1 = linres.cache

linres2 = Vector{Any}(undef, Int((num_stages - 1) / 2))

for i in 1 : Int((num_stages - 1) / 2)
@.. cubuff[i]=complex(
fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1], fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1])
linsolve2[i] = cache.linsolve2[i]
if needfactor
linres2[i] = dolinsolve(integrator, linsolve2[i]; A = W2[i], b = _vec(cubuff[i]),
linu = _vec(dw2[i]))
linres2[i] = dolinsolve(integrator, linsolve2[i]; A = W2[i], b = _vec(cubuff[i]), linu = _vec(dw2[i]))
else
linres2[i] = dolinsolve(integrator, linsolve2[i]; A = nothing, b = _vec(cubuff[i]),
linu = _vec(dw2[i]))
linres2[i] = dolinsolve(integrator, linsolve2[i]; A = nothing, b = _vec(cubuff[i]), linu = _vec(dw2[i]))
end
cache.linsolve2[i] = linres2[i].cache
end

integrator.stats.nsolve += (num_stages + 1) / 2
dw = Vector{Any}(undef, num_stages - 1)
i = 1

while i <= Int((num_stages - 1) / 2)
dw[i] = z[i]
dw[i + 1] = z[i + 1]
@.. dw[i] = real(dw2[i])
@.. dw[i + 1] = imag(dw2[i])
i += 2
dw[2 * i - 1] = z[2 * i - 1]
dw[2 * i] = z[2 * i]
dw[2 * i - 1] = real(dw2[i])
dw[2 * i] = imag(dw2[i])
i = i + 1
end

# compute norm of residuals
iter > 1 && (ndwprev = ndw)
ndws = Vector{Any}(undef, num_stages)
ndws[1] = calculate_residuals!(atmp, dw1, uprev, u, atol, rtol, internalnorm, t)
ndws[1] = internalnorm(atmp, t)

for i in 2 : num_stages
@show i
calculate_residuals!(atmp, dw[i - 1], uprev, u, atol, rtol, internalnorm, t)
ndws[i] = internalnorm(atmp, t)
end
Expand Down Expand Up @@ -1705,7 +1697,7 @@ end
end

# transform `w` to `z`
z = T * w
z = vec(T * w)
# check stopping criterion

iter > 1 &&= θ / (1 - θ))
Expand Down Expand Up @@ -1773,16 +1765,18 @@ end
if alg.extrapolant != :constant
derivatives = Matrix{Any}(undef, num_stages, num_stages)
pushfirst!(c, 0)
pushfirst!(z, map(zero, u))
for i in 1 : num_stages
for j in i : num_stages
if i == 1
derivatives[i, j] = @.. (z[i] - z[i + 1]) / (c[i] - c[i + 1]) #first derivatives
derivatives[i, j] = @.. (z[j] - z[j + 1]) / (c[j] - c[j + 1]) #first derivatives
else
derivatives[i, j] = @.. (derivatives[i - 1, j - 1] - derivatives[i - 1, j]) / (c[j - i + 1] - c[j + 1]) #all others
end
end
end
popfirst!(c)
popfirst!(z)
for i in 1 : num_stages
cache.cont[i] = derivatives[i, num_stages]
end
Expand Down
39 changes: 35 additions & 4 deletions lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ struct RadauIIA5Tableau{T, T2}
T22::T
T23::T
T31::T
#T32::T = 1
#T33::T = 0
#T32::T
#T33::T
TI11::T
TI12::T
TI13::T
Expand All @@ -56,7 +56,7 @@ struct RadauIIA5Tableau{T, T2}
TI33::T
c1::T2
c2::T2
#c3::T2 = 1
#c3::T2
γ::T
α::T
β::T
Expand Down Expand Up @@ -110,7 +110,38 @@ function RadauIIA5Tableau(T, T2)
γ, α, β,
e1, e2, e3)
end

#=
function BigRadauIIA5Tableau(T, T2)
T11 = convert(T, 0.091232394870892942791548135249436196118684699372210280712184363514099824021240149574725365814781580305065489937969163922775110463056339192206701819661425186)
T12 = convert(T, -0.141255295020954208427990383807797309409263248498594798844289981408804297900674604638610419147468875667691398225003133444988034605081071965848437945842767211)
T31 = convert(T, -0.0300291941051474244918611170890538666683842974606300802563717702200388818691214144173874588956764952224874407424115249418136547481236684478531215095064078994)
T21 = convert(T, 0.241717932707107018957474779310148232884879540532595279746187345714229132659465207414913313803429072060469564350914390845001169448350326344874859416624577348)
T22 = convert(T, 0.204129352293799931995990810298338174086540402523315938937516234649384944528706774788799548853122282827246947911905379230680096946800308693162079538975632443)
T23 = convert(T, 0.382942112757261937795438233599873210357792575012007744255205163027042915338009760005422153613194350161760232119048691964499888989151661861236831969497483828)
T31 = convert(T, 0.966048182615092936190567080794590794996748754810883844283183333914131408744555961195911605614405476210484499875001737558078500322423463946527349731087504518)
T32 = convert(T, 1.0)
T33 = convert(T, 0.0)
TI11 = convert(T, 4.32557989006315535102435095295614882731995158490590784287320458848019483341979047442263696495019938973156007686663488090615420049217658854859024016717169837)
TI12 = convert(T, 0.339199251815809869542824974053410987511771566126056902312311333553438988409693737874718833892037643701271502187763370262948704203562215007824701228014200056)
TI13 = convert(T, 0.541770539935874871186523033492089631898841317849243944095021379289933921771713116368931784890546144473788347538203807242114936998948954098533375649163016612)
TI21 = convert(T, -4.17871859155190472734646265851205623000038388214686525896709481539843195209360778128456932548583273459040707932166364293012713818843609182148794380267482041)
TI22 = convert(T, -0.327682820761062387082533272429616234245791838308340887801415258608836530255609335712523838667242449344879454518796849992049787172023800373390124427898159896)
TI23 = convert(T, 0.476623554500550451960069084091012497939942928625055897109833707684876604712862299049343675491204859381277636585708398915065951363736337328178192801074535132)
TI31 = convert(T, -0.502872634945786875951247343139544292859248429570937886791036339034110181540695221500843782634464164585836226038438397328726973424362168221527501738985822875)
TI32 = convert(T, 2.57192694985560542918678535360167505469448742842178326395573566888176471664393761903447163100353067504020263109067033226021288356347565113471227052083596358)
TI33 = convert(T, -0.596039204828224924968821911099302403289857517521591823052174732952989090998130905722763344484798508456930766594977798579939415052669401095404149917833710127)
γ = convert(T, 3.63783425274449573220841851357777579794593608687391153215117488565841871456727143375130115708511223004183651123208497057248238260532214672028700625775335843)
α = convert(T, 2.68108287362775213389579074321111210102703195656304423392441255717079064271636428312434942145744388497908174438395751471375880869733892663985649687112332242)
β = convert(T, 3.05043019924741056942637762478756790444070419917947659226291744751211727051786694870515117615266028855554735929171362769761399150862332538376382934625577549)
c1 = convert(T2, 0.155051025721682190180271592529410860803405251934332987156730743274903962254268497346014056689535976518140539877338581087514113454016224265837421604876272084)
c2 = convert(T2, 0.644948974278317809819728407470589139196594748065667012843269256725096037745731502653985943310464023481859460122661418912485886545983775734162578395123729143)
RadauIIA5Tableau{T, T2}(T11, T12, T13, T21, T22, T23, T31,T32, T33,
TI11, TI12, TI13, TI21, TI22, TI23, TI31, TI32, TI33,
c1, c2, #= c3 = 1 =#
γ, α, β,
e1, e2, e3)
end
=#
struct RadauIIA9Tableau{T, T2}
T11::T
T12::T
Expand Down
11 changes: 5 additions & 6 deletions lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ for prob in [prob_ode_linear, prob_ode_2Dlinear]
@test sim21.𝒪est[:final]5 atol=testTol
end

sol = solve(prob_ode_2Dlinear, AdaptiveRadau(), adaptive = false, dt = 1e-2)
sol = solve(prob_ode_linear, RadauIIA9(), adaptive = false, dt = 1e-2)
sol = solve(prob_ode_2Dlinear, RadauIIA5(), adaptive = false, dt = 1e-2)
sol = solve(prob_ode_2Dlinear, AdaptiveRadau(num_stages = 5), adaptive = false, dt = 1e-1)
sol = solve(prob_ode_2Dlinear, RadauIIA9(), adaptive = false, dt = 1e-1)
sol = solve(prob_ode_2Dlinear, RadauIIA5(), adaptive = false, dt = 1e-1)

sim21 = test_convergence(1 ./ 2 .^ (2.5:-1:0.5), prob_ode_linear, RadauIIA9())
@test sim21.𝒪est[:final]9 atol=testTol
Expand All @@ -26,9 +26,8 @@ for i in [3, 5, 7, 9]
@test sim21.𝒪est[:final] (2 * i - 1) atol=testTol
end

sol = solve(prob_ode_2Dlinear_big, RadauIIA9(), adaptive=false, dt = 1e-5)
for i in [5]
sim21 = test_convergence(1 ./ 10 .^ (5:-1:3), prob_ode_2Dlinear_big, AdaptiveRadau(num_stages = i))
for i in [3, 5, 7, 9]
sim21 = test_convergence(1 ./ 2 .^ (5:-1:3), prob_ode_2Dlinear_big, AdaptiveRadau(num_stages = i))
@test sim21.𝒪est[:final] (2 * i - 1) atol=testTol
end

Expand Down

0 comments on commit 55259d5

Please sign in to comment.