Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up radau tableau generation #2531

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions lib/OrdinaryDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ version = "1.3.0"
[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand All @@ -18,16 +18,14 @@ OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RootedTrees = "47965b36-3f3e-11e9-0dcf-4570dfd42a8c"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
DiffEqBase = "6.152.2"
DiffEqDevTools = "2.44.4"
FastBroadcast = "0.3.5"
FastGaussQuadrature = "1.0.2"
FastPower = "1"
GenericLinearAlgebra = "0.3.13"
GenericSchur = "0.5.4"
LinearAlgebra = "<0.0.1, 1"
LinearSolve = "2.32.0"
Expand All @@ -40,10 +38,8 @@ Polynomials = "4.0.11"
Random = "<0.0.1, 1"
RecursiveArrayTools = "3.27.0"
Reexport = "1.2.2"
RootedTrees = "2.23.1"
SafeTestsets = "0.1.0"
SciMLOperators = "0.3.9"
Symbolics = "6.15.3"
Test = "<0.0.1, 1"
julia = "1.10"

Expand Down
1 change: 0 additions & 1 deletion lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
get_current_adaptive_order, get_fsalfirstlast,
isfirk, generic_solver_docstring
using MuladdMacro, DiffEqBase, RecursiveArrayTools
using Polynomials, GenericLinearAlgebra, GenericSchur
using SciMLOperators: AbstractSciMLOperator
using LinearAlgebra: I, UniformScaling, mul!, lu
import LinearSolve
Expand Down
14 changes: 8 additions & 6 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,13 +497,14 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
tTolType = constvalue(tTypeNoUnits)
num_stages = alg.min_stages
max = alg.max_stages
tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
tabs = [RadauIIATableau(uToltype, tTolType, 3), RadauIIATableau(uToltype, tTolType, 5), RadauIIATableau(uToltype, tTolType, 7)]

i = 9
while i <= alg.max_stages
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
push!(tabs, RadauIIATableau(uToltype, tTolType, i))
i += 2
end
cont = Vector{typeof(u)}(undef, max)
Expand Down Expand Up @@ -567,11 +568,12 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
tToltype = constvalue(tTypeNoUnits)

max = alg.max_stages
num_stages = alg.min_stages

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
tabs = [RadauIIATableau(uToltype, tToltype, 3), RadauIIATableau(uToltype, tToltype, 4), RadauIIATableau(uToltype, tToltype, 5)]
i = 9
while i <= max
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
Expand Down Expand Up @@ -609,7 +611,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
fsalfirst = zero(rate_prototype)
fw = [zero(rate_prototype) for i in 1 : max]
ks = [zero(rate_prototype) for i in 1 : max]

k = ks[1]

J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
Expand Down Expand Up @@ -641,7 +643,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
atol = reltol isa Number ? reltol : zero(reltol)

AdaptiveRadauCache(u, uprev,
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
du1, fsalfirst, ks, k, fw,
J, W1, W2,
uf, tabs, κ, one(uToltype), 10000, tmp,
Expand Down
24 changes: 12 additions & 12 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ end

@muladd function perform_step!(integrator, cache::RadauIIA3ConstantCache)
@unpack t, dt, uprev, u, f, p = integrator
@unpack T11, T12, T21, T22, TI11, TI12, TI21, TI22 = cache.tab
@unpack T11, T12, T21, TI12, TI21, TI22 = cache.tab
@unpack c1, c2, α, β, e1, e2 = cache.tab
@unpack κ, cont1, cont2 = cache
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
Expand Down Expand Up @@ -153,7 +153,7 @@ end
ff2 = f(uprev + z2, p, t + c2 * dt)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2)

fw1 = @. TI11 * ff1 + TI12 * ff2
fw1 = @. TI12 * ff2 #TI11 = 0
fw2 = @. TI21 * ff1 + TI22 * ff2

if mass_matrix isa UniformScaling
Expand Down Expand Up @@ -193,7 +193,7 @@ end

# transform `w` to `z`
z1 = @. T11 * w1 + T12 * w2
z2 = @. T21 * w1 + T22 * w2
z2 = @. T21 * w1 # T22 = 0

# check stopping criterion
iter > 1 && (η = θ / (1 - θ))
Expand Down Expand Up @@ -226,7 +226,7 @@ end

@muladd function perform_step!(integrator, cache::RadauIIA3Cache, repeat_step = false)
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
@unpack T11, T12, T21, T22, TI11, TI12, TI21, TI22 = cache.tab
@unpack T11, T12, T21, TI12, TI21, TI22 = cache.tab
@unpack c1, c2, α, β, e1, e2 = cache.tab
@unpack κ, cont1, cont2 = cache
@unpack z1, z2, w1, w2,
Expand Down Expand Up @@ -273,7 +273,7 @@ end
f(k2, tmp, p, t + c2 * dt)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2)

@. fw1 = TI11 * fsallast + TI12 * k2
@. fw1 = TI12 * k2 # TI11=0
@. fw2 = TI21 * fsallast + TI22 * k2

if mass_matrix === I
Expand Down Expand Up @@ -332,7 +332,7 @@ end

# transform `w` to `z`
@. z1 = T11 * w1 + T12 * w2
@. z2 = T21 * w1 + T22 * w2
@. z2 = T21 * w1 #T22 = 0

# check stopping criterion
iter > 1 && (η = θ / (1 - θ))
Expand Down Expand Up @@ -1493,7 +1493,7 @@ end
break
end
end

for i in 1 : num_stages
w[i] = @.. w[i] - z[i]
end
Expand All @@ -1513,7 +1513,7 @@ end
i += 2
end


# check stopping criterion
iter > 1 && (η = θ / (1 - θ))
if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
Expand All @@ -1534,7 +1534,7 @@ end
cache.iter = iter

u = @.. uprev + z[num_stages]

if adaptive
tmp = 0
for i in 1 : num_stages
Expand Down Expand Up @@ -1638,7 +1638,7 @@ end
@.. z[i] = cont[num_stages] * (c_prime[i] - c[1] + 1) + cont[num_stages - 1]
j = num_stages - 2
while j > 0
@.. z[i] *= (c_prime[i] - c[num_stages - j] + 1)
@.. z[i] *= (c_prime[i] - c[num_stages - j] + 1)
@.. z[i] += cont[j]
j = j - 1
end
Expand Down Expand Up @@ -1682,7 +1682,7 @@ end
Mw = w
elseif mass_matrix isa UniformScaling
for i in 1 : num_stages
mul!(z[i], mass_matrix.λ, w[i])
mul!(z[i], mass_matrix.λ, w[i])
end
Mw = z
else
Expand Down Expand Up @@ -1784,7 +1784,7 @@ end
@.. broadcast=false u=uprev + z[num_stages]

step_limiter!(u, integrator, p, t + dt)

if adaptive
utilde = w[2]
@.. tmp = 0
Expand Down
Loading