diff --git a/lib/OrdinaryDiffEqFIRK/src/alg_utils.jl b/lib/OrdinaryDiffEqFIRK/src/alg_utils.jl index f8ed7b9b63..428393e0bf 100644 --- a/lib/OrdinaryDiffEqFIRK/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqFIRK/src/alg_utils.jl @@ -3,7 +3,7 @@ qmax_default(alg::Union{RadauIIA3, RadauIIA5, RadauIIA9}) = 8 alg_order(alg::RadauIIA3) = 3 alg_order(alg::RadauIIA5) = 5 alg_order(alg::RadauIIA9) = 9 -alg_order(alg::AdaptiveRadau) = 9 +alg_order(alg::AdaptiveRadau) = 5 isfirk(alg::RadauIIA3) = true isfirk(alg::RadauIIA5) = true diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl b/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl index e17f486570..035c1f6815 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl @@ -287,6 +287,7 @@ mutable struct RadauIIA9ConstantCache{F, Tab, Tol, Dt, U, JType} <: cont2::U cont3::U cont4::U + cont5::U dtprev::Dt W_γdt::Dt status::NLStatus @@ -304,7 +305,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits}, κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100) J = false .* _vec(rate_prototype) .* _vec(rate_prototype)' - RadauIIA9ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, u, dt, dt, + RadauIIA9ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, u, u, dt, dt, Convergence, J) end @@ -333,6 +334,7 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty cont2::uType cont3::uType cont4::uType + cont5::uType du1::rateType fsalfirst::rateType k::rateType @@ -407,6 +409,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits}, cont2 = zero(u) cont3 = zero(u) cont4 = zero(u) + cont5 = zero(u) fsalfirst = zero(rate_prototype) k = zero(rate_prototype) @@ -462,7 +465,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits}, RadauIIA9Cache(u, uprev, z1, z2, z3, z4, z5, w1, w2, w3, w4, w5, - dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4, + dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4, cont5, du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5, J, W1, W2, W3, uf, tab, κ, one(uToltype), 10000, diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl index 1f0b418970..012f20f904 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl @@ -723,9 +723,9 @@ end if integrator.EEst <= oneunit(integrator.EEst) cache.dtprev = dt if alg.extrapolant != :constant - @.. broadcast=false cache.cont1=(z2 - z3) / c2m1 - @.. broadcast=false tmp=(z1 - z2) / c1mc2 - @.. broadcast=false cache.cont2=(tmp - cache.cont1) / c1m1 + @.. broadcast=false cache.cont1=(z2 - z3) / c2m1 + @.. broadcast=false tmp=(z1 - z2) / c1mc2 + @.. broadcast=false cache.cont2=(tmp - cache.cont1) / c1m1 @.. broadcast=false cache.cont3=cache.cont2 - (tmp - z1 / c1) / c2 end end @@ -740,7 +740,7 @@ end @unpack T11, T12, T13, T14, T15, T21, T22, T23, T24, T25, T31, T32, T33, T34, T35, T41, T42, T43, T44, T45, T51 = cache.tab #= T52 = 1, T53 = 0, T54 = 1, T55 = 0=# @unpack TI11, TI12, TI13, TI14, TI15, TI21, TI22, TI23, TI24, TI25, TI31, TI32, TI33, TI34, TI35, TI41, TI42, TI43, TI44, TI45, TI51, TI52, TI53, TI54, TI55 = cache.tab @unpack c1, c2, c3, c4, γ, α1, β1, α2, β2, e1, e2, e3, e4, e5 = cache.tab - @unpack κ, cont1, cont2, cont3, cont4 = cache + @unpack κ, cont1, cont2, cont3, cont4, cont5 = cache @unpack internalnorm, abstol, reltol, adaptive = integrator.opts alg = unwrap_alg(integrator, true) @unpack maxiters = alg @@ -785,27 +785,32 @@ end cache.cont2 = map(zero, u) cache.cont3 = map(zero, u) cache.cont4 = map(zero, u) + cache.cont5 = map(zero, u) else c5′ = dt / cache.dtprev c1′ = c1 * c5′ c2′ = c2 * c5′ c3′ = c3 * c5′ c4′ = c4 * c5′ - z1 = @.. broadcast=false c1′*(cont1 + - (c1′ - c3m1) * (cont2 + - (c1′ - c2m1) * (cont3 + (c1′ - c1m1) * cont4))) - z2 = @.. broadcast=false c2′*(cont1 + - (c2′ - c3m1) * (cont2 + - (c2′ - c2m1) * (cont3 + (c2′ - c1m1) * cont4))) - z3 = @.. broadcast=false c3′*(cont1 + - (c3′ - c3m1) * (cont2 + - (c3′ - c2m1) * (cont3 + (c3′ - c1m1) * cont4))) - z4 = @.. broadcast=false c4′*(cont1 + - (c4′ - c3m1) * (cont2 + - (c4′ - c2m1) * (cont3 + (c4′ - c1m1) * cont4))) - z5 = @.. broadcast=false c5′*(cont1 + - (c5′ - c3m1) * (cont2 + - (c5′ - c2m1) * (cont3 + (c5′ - c1m1) * cont4))) + z1 = @.. c1′ * (cont1 + + (c1′-c4m1) * (cont2 + + (c1′ - c3m1) * (cont3 + + (c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5)))) + z2 = @.. c2′ * (cont1 + + (c2′-c4m1) * (cont2 + + (c2′ - c3m1) * (cont3 + + (c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5)))) + z3 = @.. c3′ * (cont1 + + (c3′-c4m1) * (cont2 + + (c3′ - c3m1) * (cont3 + + (c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5)))) + z4 = @.. c4′ * (cont1 + + (c4′-c4m1) * (cont2 + + (c4′ - c3m1) * (cont3 + + (c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5)))) + z5 = @.. c5′ * (cont1 + + (c5′-c4m1) * (cont2 + + (c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5)))) w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5 w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5 w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5 @@ -898,7 +903,7 @@ end z3 = @.. broadcast=false T31*w1+T32*w2+T33*w3+T34*w4+T35*w5 z4 = @.. broadcast=false T41*w1+T42*w2+T43*w3+T44*w4+T45*w5 z5 = @.. broadcast=false T51*w1+w2+w4 #= T52=1, T53=0, T54=1, T55=0 =# - @show z1 + # check stopping criterion iter > 1 && (η = θ / (1 - θ)) if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter)) @@ -953,6 +958,11 @@ end tmp5 = @.. (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3] tmp6 = @.. (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4] cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1] + tmp7 = @.. z1 / c1 #first derivative on [0, c1] + tmp8 = @.. (tmp4 - tmp7) / c2 #second derivative on [0, c2] + tmp9 = @.. (tmp5 - tmp8) / c3 #third derivative on [0, c3] + tmp10 = @.. (tmp6 - tmp9) / c4 #fourth derivative on [0,c4] + cache.cont5 = @.. cache.cont4 - tmp10 #fifth derivative on [0,1] end end @@ -969,7 +979,7 @@ end @unpack T11, T12, T13, T14, T15, T21, T22, T23, T24, T25, T31, T32, T33, T34, T35, T41, T42, T43, T44, T45, T51 = cache.tab #= T52 = 1, T53 = 0, T54 = 1, T55 = 0=# @unpack TI11, TI12, TI13, TI14, TI15, TI21, TI22, TI23, TI24, TI25, TI31, TI32, TI33, TI34, TI35, TI41, TI42, TI43, TI44, TI45, TI51, TI52, TI53, TI54, TI55 = cache.tab @unpack c1, c2, c3, c4, γ, α1, β1, α2, β2, e1, e2, e3, e4, e5 = cache.tab - @unpack κ, cont1, cont2, cont3, cont4 = cache + @unpack κ, cont1, cont2, cont3, cont4, cont5 = cache @unpack z1, z2, z3, z4, z5, w1, w2, w3, w4, w5 = cache @unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache @unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache @@ -1022,32 +1032,37 @@ end @.. broadcast=false cache.cont2=uzero @.. broadcast=false cache.cont3=uzero @.. broadcast=false cache.cont4=uzero + @.. broadcast=false cache.cont5=uzero else c5′ = dt / cache.dtprev c1′ = c1 * c5′ c2′ = c2 * c5′ c3′ = c3 * c5′ c4′ = c4 * c5′ - @.. broadcast=false z1 = c1′*(cont1 + - (c1′ - c3m1) * (cont2 + - (c1′ - c2m1) * (cont3 + (c1′ - c1m1) * cont4))) - @.. broadcast=false z2 = c2′*(cont1 + - (c2′ - c3m1) * (cont2 + - (c2′ - c2m1) * (cont3 + (c2′ - c1m1) * cont4))) - @.. broadcast=false z3 = c3′*(cont1 + - (c3′ - c3m1) * (cont2 + - (c3′ - c2m1) * (cont3 + (c3′ - c1m1) * cont4))) - @.. broadcast=false z4 = c4′*(cont1 + - (c4′ - c3m1) * (cont2 + - (c4′ - c2m1) * (cont3 + (c4′ - c1m1) * cont4))) - @.. broadcast=false z5 = c5′*(cont1 + - (c5′ - c3m1) * (cont2 + - (c5′ - c2m1) * (cont3 + (c5′ - c1m1) * cont4))) - @.. broadcast=false w1 = TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5 - @.. broadcast=false w2 = TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5 - @.. broadcast=false w3 = TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5 - @.. broadcast=false w4 = TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5 - @.. broadcast=false w5 = TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5 + z1 = @.. c1′ * (cont1 + + (c1′-c4m1) * (cont2 + + (c1′ - c3m1) * (cont3 + + (c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5)))) + z2 = @.. c2′ * (cont1 + + (c2′-c4m1) * (cont2 + + (c2′ - c3m1) * (cont3 + + (c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5)))) + z3 = @.. c3′ * (cont1 + + (c3′-c4m1) * (cont2 + + (c3′ - c3m1) * (cont3 + + (c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5)))) + z4 = @.. c4′ * (cont1 + + (c4′-c4m1) * (cont2 + + (c4′ - c3m1) * (cont3 + + (c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5)))) + z5 = @.. c5′ * (cont1 + + (c5′-c4m1) * (cont2 + + (c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5)))) + w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5 + w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5 + w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5 + w4 = @.. broadcast=false TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5 + w5 = @.. broadcast=false TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5 end # Newton iteration @@ -1272,16 +1287,21 @@ end if integrator.EEst <= oneunit(integrator.EEst) cache.dtprev = dt if alg.extrapolant != :constant - @.. cache.cont1 = (z4 - z5) / c4m1 - @.. tmp = (z3 - z4) / c3mc4 - @.. cache.cont2 = (tmp - cache.cont1) / c3m1 - @.. tmp2 = (z2 - z3) / c2mc3 - @.. tmp3 = (tmp2 - tmp) / c2mc4 - @.. cache.cont3 = (tmp3 - cache.cont2) / c2m1 - @.. tmp4 = (z1 - z2) / c1mc2 - @.. tmp5 = (tmp4 - tmp2) / c1mc3 - @.. tmp6 = (tmp5 - tmp3) / c1mc4 - @.. cache.cont4 = (tmp6 - cache.cont3) / c1m1 + cache.cont1 = @.. (z4 - z5) / c4m1 # first derivative on [c4, 1] + tmp1 = @.. (z3 - z4) / c3mc4 # first derivative on [c3, c4] + cache.cont2 = @.. (tmp1 - cache.cont1) / c3m1 # second derivative on [c3, 1] + tmp2 = @.. (z2 - z3) / c2mc3 # first derivative on [c2, c3] + tmp3 = @.. (tmp2 - tmp1) / c2mc4 # second derivative on [c2, c4] + cache.cont3 = @.. (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1] + tmp4 = @.. (z1 - z2) / c1mc2 # first derivative on [c1, c2] + tmp5 = @.. (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3] + tmp6 = @.. (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4] + cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1] + tmp7 = @.. z1 / c1 #first derivative on [0, c1] + tmp8 = @.. (tmp4 - tmp7) / c2 #second derivative on [0, c2] + tmp9 = @.. (tmp5 - tmp8) / c3 #third derivative on [0, c3] + tmp10 = @.. (tmp6 - tmp9) / c4 #fourth derivative on [0,c4] + cache.cont5 = @.. cache.cont4 - tmp10 #fifth derivative on [0,1] end end @@ -1363,7 +1383,7 @@ end end integrator.stats.nf += num_stages - fw = @.. TI * ff + fw = TI * ff Mw = Vector{eltype(u)}(undef, num_stages) if mass_matrix isa UniformScaling # `UniformScaling` doesn't play nicely with broadcast for i in 1 : num_stages @@ -1417,8 +1437,7 @@ end w = @.. w - dw # transform `w` to `z` - z = @.. T * w - @show z[1] + z = T * w # check stopping criterion iter > 1 && (η = θ / (1 - θ)) diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl b/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl index a0b7404476..a7934174ce 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl @@ -308,7 +308,7 @@ function adaptiveRadauTableau(T, T2, num_stages::Int) b[i] = a[num_stages, i] end vals = eigvals(a_inverse) - γ = real(b[num_stages]) + γ = real(vals[num_stages]) α = Vector{BigFloat}(undef, floor(Int, num_stages/2)) β = Vector{BigFloat}(undef, floor(Int, num_stages/2)) index = 1 @@ -338,8 +338,10 @@ function adaptiveRadauTableau(T, T2, num_stages::Int) end end TI = inv(T) - + #= + p = num_stages eb = variables(:b, 1:num_stages + 1) + @variables y zz = zeros(size(a, 1) + 1) zz2 = zeros(size(a, 1)) eA = [zz' @@ -348,12 +350,12 @@ function adaptiveRadauTableau(T, T2, num_stages::Int) constraints = map(Iterators.flatten(RootedTreeIterator(i) for i in 1:p)) do t residual_order_condition(t, RungeKuttaMethod(eA, eb, ec)) end - AA, bb, islinear = Symbolics.linear_expansion(substitute.(constraints, (eb[1]=>γ,)), eb[2:end]) - AA = Float64.(map(unwrap, AA)) + AA, bb, islinear = Symbolics.linear_expansion(substitute.(constraints, (eb[1]=>y,)), eb[2:end]) + AA = BigFloat.(map(unwrap, AA)) idxs = qr(AA', ColumnNorm()).p[1:num_stages] @assert rank(AA[idxs, :]) == num_stages @assert islinear - Symbolics.expand.((AA[idxs, :] \ -bb[idxs]) - b) + Symbolics.expand.((AA[idxs, :] \ -bb[idxs]) - b)=# #e = b_hat - b adaptiveRadauTableau{Any, T2, Int}(T, TI, γ, α, β, c, num_stages) end \ No newline at end of file diff --git a/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl b/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl index 36289bdf7b..3d1e32989c 100644 --- a/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl +++ b/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl @@ -12,12 +12,11 @@ sol = solve(prob_ode_linear, AdaptiveRadau(), adaptive = false, dt = 1e-2) sol = solve(prob_ode_linear, RadauIIA9(), adaptive = false, dt = 1e-2) sol = solve(prob_ode_linear, RadauIIA5(), adaptive = false, dt = 1e-2) - -sim21 = test_convergence(1 ./ 10 .^ (4.5:-1:2.5), prob_ode_linear, AdaptiveRadau()) -@test sim21.𝒪est[:final]≈8 atol=testTol +sim21 = test_convergence(1 ./ 2 .^ (2.5:-1:0.5), prob_ode_linear, RadauIIA9()) +@test sim21.𝒪est[:final]≈9 atol=testTol sim21 = test_convergence(1 ./ 2 .^ (2.5:-1:0.5), prob_ode_2Dlinear, RadauIIA9()) -@test sim21.𝒪est[:final]≈8 atol=testTol +@test sim21.𝒪est[:final]≈9 atol=testTol # test adaptivity for iip in (true, false)