Skip to content

Commit

Permalink
fix collocation on radauIIA9
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Aug 23, 2024
1 parent 51d2012 commit 9eb4538
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 66 deletions.
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqFIRK/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ qmax_default(alg::Union{RadauIIA3, RadauIIA5, RadauIIA9}) = 8
alg_order(alg::RadauIIA3) = 3
alg_order(alg::RadauIIA5) = 5
alg_order(alg::RadauIIA9) = 9
alg_order(alg::AdaptiveRadau) = 9
alg_order(alg::AdaptiveRadau) = 5

isfirk(alg::RadauIIA3) = true
isfirk(alg::RadauIIA5) = true
Expand Down
7 changes: 5 additions & 2 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ mutable struct RadauIIA9ConstantCache{F, Tab, Tol, Dt, U, JType} <:
cont2::U
cont3::U
cont4::U
cont5::U
dtprev::Dt
W_γdt::Dt
status::NLStatus
Expand All @@ -304,7 +305,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'

RadauIIA9ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, u, dt, dt,
RadauIIA9ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, u, u, dt, dt,
Convergence, J)
end

Expand Down Expand Up @@ -333,6 +334,7 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
cont2::uType
cont3::uType
cont4::uType
cont5::uType
du1::rateType
fsalfirst::rateType
k::rateType
Expand Down Expand Up @@ -407,6 +409,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
cont2 = zero(u)
cont3 = zero(u)
cont4 = zero(u)
cont5 = zero(u)

fsalfirst = zero(rate_prototype)
k = zero(rate_prototype)
Expand Down Expand Up @@ -462,7 +465,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},

RadauIIA9Cache(u, uprev,
z1, z2, z3, z4, z5, w1, w2, w3, w4, w5,
dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4,
dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4, cont5,
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
J, W1, W2, W3,
uf, tab, κ, one(uToltype), 10000,
Expand Down
127 changes: 73 additions & 54 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -723,9 +723,9 @@ end
if integrator.EEst <= oneunit(integrator.EEst)
cache.dtprev = dt
if alg.extrapolant != :constant
@.. broadcast=false cache.cont1=(z2 - z3) / c2m1
@.. broadcast=false tmp=(z1 - z2) / c1mc2
@.. broadcast=false cache.cont2=(tmp - cache.cont1) / c1m1
@.. broadcast=false cache.cont1=(z2 - z3) / c2m1
@.. broadcast=false tmp=(z1 - z2) / c1mc2
@.. broadcast=false cache.cont2=(tmp - cache.cont1) / c1m1
@.. broadcast=false cache.cont3=cache.cont2 - (tmp - z1 / c1) / c2
end
end
Expand All @@ -740,7 +740,7 @@ end
@unpack T11, T12, T13, T14, T15, T21, T22, T23, T24, T25, T31, T32, T33, T34, T35, T41, T42, T43, T44, T45, T51 = cache.tab #= T52 = 1, T53 = 0, T54 = 1, T55 = 0=#
@unpack TI11, TI12, TI13, TI14, TI15, TI21, TI22, TI23, TI24, TI25, TI31, TI32, TI33, TI34, TI35, TI41, TI42, TI43, TI44, TI45, TI51, TI52, TI53, TI54, TI55 = cache.tab
@unpack c1, c2, c3, c4, γ, α1, β1, α2, β2, e1, e2, e3, e4, e5 = cache.tab
@unpack κ, cont1, cont2, cont3, cont4 = cache
@unpack κ, cont1, cont2, cont3, cont4, cont5 = cache
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
alg = unwrap_alg(integrator, true)
@unpack maxiters = alg
Expand Down Expand Up @@ -785,27 +785,32 @@ end
cache.cont2 = map(zero, u)
cache.cont3 = map(zero, u)
cache.cont4 = map(zero, u)
cache.cont5 = map(zero, u)
else
c5′ = dt / cache.dtprev
c1′ = c1 * c5′
c2′ = c2 * c5′
c3′ = c3 * c5′
c4′ = c4 * c5′
z1 = @.. broadcast=false c1′*(cont1 +
(c1′ - c3m1) * (cont2 +
(c1′ - c2m1) * (cont3 + (c1′ - c1m1) * cont4)))
z2 = @.. broadcast=false c2′*(cont1 +
(c2′ - c3m1) * (cont2 +
(c2′ - c2m1) * (cont3 + (c2′ - c1m1) * cont4)))
z3 = @.. broadcast=false c3′*(cont1 +
(c3′ - c3m1) * (cont2 +
(c3′ - c2m1) * (cont3 + (c3′ - c1m1) * cont4)))
z4 = @.. broadcast=false c4′*(cont1 +
(c4′ - c3m1) * (cont2 +
(c4′ - c2m1) * (cont3 + (c4′ - c1m1) * cont4)))
z5 = @.. broadcast=false c5′*(cont1 +
(c5′ - c3m1) * (cont2 +
(c5′ - c2m1) * (cont3 + (c5′ - c1m1) * cont4)))
z1 = @.. c1′ * (cont1 +
(c1′-c4m1) * (cont2 +
(c1′ - c3m1) * (cont3 +
(c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5))))
z2 = @.. c2′ * (cont1 +
(c2′-c4m1) * (cont2 +
(c2′ - c3m1) * (cont3 +
(c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5))))
z3 = @.. c3′ * (cont1 +
(c3′-c4m1) * (cont2 +
(c3′ - c3m1) * (cont3 +
(c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5))))
z4 = @.. c4′ * (cont1 +
(c4′-c4m1) * (cont2 +
(c4′ - c3m1) * (cont3 +
(c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5))))
z5 = @.. c5′ * (cont1 +
(c5′-c4m1) * (cont2 +
(c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5))))
w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
Expand Down Expand Up @@ -898,7 +903,7 @@ end
z3 = @.. broadcast=false T31*w1+T32*w2+T33*w3+T34*w4+T35*w5
z4 = @.. broadcast=false T41*w1+T42*w2+T43*w3+T44*w4+T45*w5
z5 = @.. broadcast=false T51*w1+w2+w4 #= T52=1, T53=0, T54=1, T55=0 =#
@show z1

# check stopping criterion
iter > 1 &&= θ / (1 - θ))
if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
Expand Down Expand Up @@ -953,6 +958,11 @@ end
tmp5 = @.. (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
tmp6 = @.. (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
tmp7 = @.. z1 / c1 #first derivative on [0, c1]
tmp8 = @.. (tmp4 - tmp7) / c2 #second derivative on [0, c2]
tmp9 = @.. (tmp5 - tmp8) / c3 #third derivative on [0, c3]
tmp10 = @.. (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
cache.cont5 = @.. cache.cont4 - tmp10 #fifth derivative on [0,1]
end
end

Expand All @@ -969,7 +979,7 @@ end
@unpack T11, T12, T13, T14, T15, T21, T22, T23, T24, T25, T31, T32, T33, T34, T35, T41, T42, T43, T44, T45, T51 = cache.tab #= T52 = 1, T53 = 0, T54 = 1, T55 = 0=#
@unpack TI11, TI12, TI13, TI14, TI15, TI21, TI22, TI23, TI24, TI25, TI31, TI32, TI33, TI34, TI35, TI41, TI42, TI43, TI44, TI45, TI51, TI52, TI53, TI54, TI55 = cache.tab
@unpack c1, c2, c3, c4, γ, α1, β1, α2, β2, e1, e2, e3, e4, e5 = cache.tab
@unpack κ, cont1, cont2, cont3, cont4 = cache
@unpack κ, cont1, cont2, cont3, cont4, cont5 = cache
@unpack z1, z2, z3, z4, z5, w1, w2, w3, w4, w5 = cache
@unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
@unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
Expand Down Expand Up @@ -1022,32 +1032,37 @@ end
@.. broadcast=false cache.cont2=uzero
@.. broadcast=false cache.cont3=uzero
@.. broadcast=false cache.cont4=uzero
@.. broadcast=false cache.cont5=uzero
else
c5′ = dt / cache.dtprev
c1′ = c1 * c5′
c2′ = c2 * c5′
c3′ = c3 * c5′
c4′ = c4 * c5′
@.. broadcast=false z1 = c1′*(cont1 +
(c1′ - c3m1) * (cont2 +
(c1′ - c2m1) * (cont3 + (c1′ - c1m1) * cont4)))
@.. broadcast=false z2 = c2′*(cont1 +
(c2′ - c3m1) * (cont2 +
(c2′ - c2m1) * (cont3 + (c2′ - c1m1) * cont4)))
@.. broadcast=false z3 = c3′*(cont1 +
(c3′ - c3m1) * (cont2 +
(c3′ - c2m1) * (cont3 + (c3′ - c1m1) * cont4)))
@.. broadcast=false z4 = c4′*(cont1 +
(c4′ - c3m1) * (cont2 +
(c4′ - c2m1) * (cont3 + (c4′ - c1m1) * cont4)))
@.. broadcast=false z5 = c5′*(cont1 +
(c5′ - c3m1) * (cont2 +
(c5′ - c2m1) * (cont3 + (c5′ - c1m1) * cont4)))
@.. broadcast=false w1 = TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
@.. broadcast=false w2 = TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
@.. broadcast=false w3 = TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
@.. broadcast=false w4 = TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
@.. broadcast=false w5 = TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
z1 = @.. c1′ * (cont1 +
(c1′-c4m1) * (cont2 +
(c1′ - c3m1) * (cont3 +
(c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5))))
z2 = @.. c2′ * (cont1 +
(c2′-c4m1) * (cont2 +
(c2′ - c3m1) * (cont3 +
(c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5))))
z3 = @.. c3′ * (cont1 +
(c3′-c4m1) * (cont2 +
(c3′ - c3m1) * (cont3 +
(c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5))))
z4 = @.. c4′ * (cont1 +
(c4′-c4m1) * (cont2 +
(c4′ - c3m1) * (cont3 +
(c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5))))
z5 = @.. c5′ * (cont1 +
(c5′-c4m1) * (cont2 +
(c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5))))
w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
w4 = @.. broadcast=false TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
w5 = @.. broadcast=false TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
end

# Newton iteration
Expand Down Expand Up @@ -1272,16 +1287,21 @@ end
if integrator.EEst <= oneunit(integrator.EEst)
cache.dtprev = dt
if alg.extrapolant != :constant
@.. cache.cont1 = (z4 - z5) / c4m1
@.. tmp = (z3 - z4) / c3mc4
@.. cache.cont2 = (tmp - cache.cont1) / c3m1
@.. tmp2 = (z2 - z3) / c2mc3
@.. tmp3 = (tmp2 - tmp) / c2mc4
@.. cache.cont3 = (tmp3 - cache.cont2) / c2m1
@.. tmp4 = (z1 - z2) / c1mc2
@.. tmp5 = (tmp4 - tmp2) / c1mc3
@.. tmp6 = (tmp5 - tmp3) / c1mc4
@.. cache.cont4 = (tmp6 - cache.cont3) / c1m1
cache.cont1 = @.. (z4 - z5) / c4m1 # first derivative on [c4, 1]
tmp1 = @.. (z3 - z4) / c3mc4 # first derivative on [c3, c4]
cache.cont2 = @.. (tmp1 - cache.cont1) / c3m1 # second derivative on [c3, 1]
tmp2 = @.. (z2 - z3) / c2mc3 # first derivative on [c2, c3]
tmp3 = @.. (tmp2 - tmp1) / c2mc4 # second derivative on [c2, c4]
cache.cont3 = @.. (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
tmp4 = @.. (z1 - z2) / c1mc2 # first derivative on [c1, c2]
tmp5 = @.. (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
tmp6 = @.. (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
tmp7 = @.. z1 / c1 #first derivative on [0, c1]
tmp8 = @.. (tmp4 - tmp7) / c2 #second derivative on [0, c2]
tmp9 = @.. (tmp5 - tmp8) / c3 #third derivative on [0, c3]
tmp10 = @.. (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
cache.cont5 = @.. cache.cont4 - tmp10 #fifth derivative on [0,1]
end
end

Expand Down Expand Up @@ -1363,7 +1383,7 @@ end
end
integrator.stats.nf += num_stages

fw = @.. TI * ff
fw = TI * ff
Mw = Vector{eltype(u)}(undef, num_stages)
if mass_matrix isa UniformScaling # `UniformScaling` doesn't play nicely with broadcast
for i in 1 : num_stages
Expand Down Expand Up @@ -1417,8 +1437,7 @@ end
w = @.. w - dw

# transform `w` to `z`
z = @.. T * w
@show z[1]
z = T * w

# check stopping criterion
iter > 1 &&= θ / (1 - θ))
Expand Down
12 changes: 7 additions & 5 deletions lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ function adaptiveRadauTableau(T, T2, num_stages::Int)
b[i] = a[num_stages, i]
end
vals = eigvals(a_inverse)
γ = real(b[num_stages])
γ = real(vals[num_stages])
α = Vector{BigFloat}(undef, floor(Int, num_stages/2))
β = Vector{BigFloat}(undef, floor(Int, num_stages/2))
index = 1
Expand Down Expand Up @@ -338,8 +338,10 @@ function adaptiveRadauTableau(T, T2, num_stages::Int)
end
end
TI = inv(T)

#=
p = num_stages
eb = variables(:b, 1:num_stages + 1)
@variables y
zz = zeros(size(a, 1) + 1)
zz2 = zeros(size(a, 1))
eA = [zz'
Expand All @@ -348,12 +350,12 @@ function adaptiveRadauTableau(T, T2, num_stages::Int)
constraints = map(Iterators.flatten(RootedTreeIterator(i) for i in 1:p)) do t
residual_order_condition(t, RungeKuttaMethod(eA, eb, ec))
end
AA, bb, islinear = Symbolics.linear_expansion(substitute.(constraints, (eb[1]=>γ,)), eb[2:end])
AA = Float64.(map(unwrap, AA))
AA, bb, islinear = Symbolics.linear_expansion(substitute.(constraints, (eb[1]=>y,)), eb[2:end])
AA = BigFloat.(map(unwrap, AA))
idxs = qr(AA', ColumnNorm()).p[1:num_stages]
@assert rank(AA[idxs, :]) == num_stages
@assert islinear
Symbolics.expand.((AA[idxs, :] \ -bb[idxs]) - b)
Symbolics.expand.((AA[idxs, :] \ -bb[idxs]) - b)=#
#e = b_hat - b
adaptiveRadauTableau{Any, T2, Int}(T, TI, γ, α, β, c, num_stages)
end
7 changes: 3 additions & 4 deletions lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@ sol = solve(prob_ode_linear, AdaptiveRadau(), adaptive = false, dt = 1e-2)
sol = solve(prob_ode_linear, RadauIIA9(), adaptive = false, dt = 1e-2)
sol = solve(prob_ode_linear, RadauIIA5(), adaptive = false, dt = 1e-2)


sim21 = test_convergence(1 ./ 10 .^ (4.5:-1:2.5), prob_ode_linear, AdaptiveRadau())
@test sim21.𝒪est[:final]8 atol=testTol
sim21 = test_convergence(1 ./ 2 .^ (2.5:-1:0.5), prob_ode_linear, RadauIIA9())
@test sim21.𝒪est[:final]9 atol=testTol

sim21 = test_convergence(1 ./ 2 .^ (2.5:-1:0.5), prob_ode_2Dlinear, RadauIIA9())
@test sim21.𝒪est[:final]8 atol=testTol
@test sim21.𝒪est[:final]9 atol=testTol

# test adaptivity
for iip in (true, false)
Expand Down

0 comments on commit 9eb4538

Please sign in to comment.