Skip to content

Commit

Permalink
Merge pull request #2394 from oscardssmith/os/fix-Float32-oop-bdf
Browse files Browse the repository at this point in the history
fix oop BDF gamma type and terk_tmp type
  • Loading branch information
ChrisRackauckas authored Aug 19, 2024
2 parents cb87c1c + 7f00225 commit f464308
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 21 deletions.
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqBDF/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ end
function QNDF1(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(),
concrete_jac = nothing, diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
extrapolant = :linear, kappa = -0.1850,
extrapolant = :linear, kappa = -37//200,
controller = :Standard, step_limiter! = trivial_limiter!)
QNDF1{
_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
Expand Down Expand Up @@ -233,7 +233,7 @@ function QNDF(; max_order::Val{MO} = Val{5}(), chunk_size = Val{0}(),
diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
tol = nothing,
extrapolant = :linear, kappa = promote(-0.1850, -1 // 9, -0.0823, -0.0415, 0),
extrapolant = :linear, kappa = (-37//200, -1//9, -823//10000, -83//2000, 0//1),
controller = :Standard, step_limiter! = trivial_limiter!) where {MO}
QNDF{MO, _unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqBDF/src/bdf_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ function alg_cache(alg::QNDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits
} where {MO}
max_order = MO
γ, c = one(eltype(alg.kappa)), 1
γ, c = one(uEltypeNoUnits), 1
nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))
dtprev = one(dt)
Expand Down Expand Up @@ -539,7 +539,7 @@ function alg_cache(alg::FBDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits
} where {MO}
γ, c = 1.0, 1.0
γ, c = one(uEltypeNoUnits), 1
max_order = MO
nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))
Expand Down Expand Up @@ -614,7 +614,7 @@ function alg_cache(alg::FBDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
dt, reltol, p, calck,
::Val{true}) where {MO, uEltypeNoUnits, uBottomEltypeNoUnits,
tTypeNoUnits}
γ, c = 1.0, 1.0
γ, c = one(uEltypeNoUnits), 1
fsalfirst = zero(rate_prototype)
max_order = MO
nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
Expand Down
18 changes: 10 additions & 8 deletions lib/OrdinaryDiffEqBDF/src/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ function choose_order!(alg::FBDF, integrator,
terk_tmp = @.. broadcast=false fd_weights[k - 2, 1]*u
vc = _vec(terk_tmp)
for i in 2:(k - 2)
@.. broadcast=false @views vc += fd_weights[i, k - 2] * u_history[:, i - 1]
@.. @views vc += fd_weights[i, k - 2] * u_history[:, i - 1]
end
@.. broadcast=false terk_tmp*=abs(dt^(k - 2))
calculate_residuals!(atmp, _vec(terk_tmp), _vec(uprev), _vec(u),
Expand Down Expand Up @@ -204,22 +204,24 @@ function choose_order!(alg::FBDF, integrator,
terkm1 = terkm2
fd_weights = calc_finite_difference_weights(ts_tmp, t + dt, k - 2,
Val(max_order))
terk_tmp = @.. broadcast=false fd_weights[k - 2, 1]*u
local terk_tmp
if u isa Number
terk_tmp = fd_weights[k - 2, 1]*u
for i in 2:(k - 2)
terk_tmp += fd_weights[i, k - 2] * u_history[i - 1]
end
terk_tmp *= abs(dt^(k - 2))
else
vc = _vec(terk_tmp)
# we need terk_tmp to be mutable.
# so it can be updated
terk_tmp = similar(u)
@.. terk_tmp = fd_weights[k - 2, 1]*_vec(u)
for i in 2:(k - 2)
@.. broadcast=false @views vc += fd_weights[i, k - 2] *
u_history[:, i - 1]
@.. @views terk_tmp += fd_weights[i, k - 2] * u_history[:, i - 1]
end
terk_tmp = reshape(vc, size(terk_tmp))
terk_tmp *= @.. broadcast=false abs(dt^(k - 2))
@.. terk_tmp *= abs(dt^(k - 2))
end
atmp = calculate_residuals(_vec(terk_tmp), _vec(uprev), _vec(u),
atmp = calculate_residuals(terk_tmp, _vec(uprev), _vec(u),
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)
terkm2 = integrator.opts.internalnorm(atmp, t)
Expand Down
30 changes: 22 additions & 8 deletions test/interface/linear_solver_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,14 @@ end
using OrdinaryDiffEq, StaticArrays, LinearSolve, ParameterizedFunctions

hires = @ode_def Hires begin
dy1 = -1.71 * y1 + 0.43 * y2 + 8.32 * y3 + 0.0007
dy2 = 1.71 * y1 - 8.75 * y2
dy3 = -10.03 * y3 + 0.43 * y4 + 0.035 * y5
dy4 = 8.32 * y2 + 1.71 * y3 - 1.12 * y4
dy5 = -1.745 * y5 + 0.43 * y6 + 0.43 * y7
dy6 = -280.0 * y6 * y8 + 0.69 * y4 + 1.71 * y5 - 0.43 * y6 + 0.69 * y7
dy7 = 280.0 * y6 * y8 - 1.81 * y7
dy8 = -280.0 * y6 * y8 + 1.81 * y7
dy1 = -1.71f0 * y1 + 0.43f0 * y2 + 8.32f0 * y3 + 0.0007f0 + 1f-18*t
dy2 = 1.71f0 * y1 - 8.75f0 * y2
dy3 = -10.03f0 * y3 + 0.43f0 * y4 + 0.035f0 * y5
dy4 = 8.32f0 * y2 + 1.71f0 * y3 - 1.12f0 * y4
dy5 = -1.745f0 * y5 + 0.43f0 * y6 + 0.43f0 * y7
dy6 = -280.0f0 * y6 * y8 + 0.69f0 * y4 + 1.71f0 * y5 - 0.43f0 * y6 + 0.69f0 * y7
dy7 = 280.0f0 * y6 * y8 - 1.81f0 * y7
dy8 = -280.0f0 * y6 * y8 + 1.81f0 * y7
end

u0 = zeros(8)
Expand All @@ -178,7 +178,11 @@ u0[8] = 0.0057
probiip = ODEProblem{true}(hires, u0, (0.0, 10.0))
proboop = ODEProblem{false}(hires, u0, (0.0, 10.0))
probstatic = ODEProblem{false}(hires, SVector{8}(u0), (0.0, 10.0))
probiipf32 = ODEProblem{true}(hires, Float32.(u0), (0f0, 10f0))
proboopf32 = ODEProblem{false}(hires, Float32.(u0), (0f0, 10f0))
probstaticf32 = ODEProblem{false}(hires, SVector{8}(Float32.(u0)), (0f0, 10f0))
probs = (; probiip, proboop, probstatic)
probsf32 = (;probiipf32, proboopf32, probstaticf32)
qndf = QNDF()
krylov_qndf = QNDF(linsolve = KrylovJL_GMRES())
fbdf = FBDF()
Expand All @@ -197,3 +201,13 @@ refsol = solve(probiip, FBDF(), abstol = 1e-12, reltol = 1e-12)
end
end
end

@testset "Hires Float32 calc_W tests" begin
@testset "$probname" for (probname, prob) in pairs(probsf32)
@testset "$solname" for (solname, solver) in pairs(solvers)
sol = solve(prob, solver, maxiters = 2e4)
@test sol.retcode == ReturnCode.Success
@test isapprox(sol.u[end], refsol.u[end], rtol = 2e-3, atol = 1e-6)
end
end
end

0 comments on commit f464308

Please sign in to comment.