diff --git a/lib/OrdinaryDiffEqBDF/src/algorithms.jl b/lib/OrdinaryDiffEqBDF/src/algorithms.jl index 387a9d6a72..1a4c960c7b 100644 --- a/lib/OrdinaryDiffEqBDF/src/algorithms.jl +++ b/lib/OrdinaryDiffEqBDF/src/algorithms.jl @@ -150,7 +150,7 @@ end function QNDF1(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), - extrapolant = :linear, kappa = -0.1850, + extrapolant = :linear, kappa = -37//200, controller = :Standard, step_limiter! = trivial_limiter!) QNDF1{ _unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve), @@ -233,7 +233,7 @@ function QNDF(; max_order::Val{MO} = Val{5}(), chunk_size = Val{0}(), diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing, tol = nothing, - extrapolant = :linear, kappa = promote(-0.1850, -1 // 9, -0.0823, -0.0415, 0), + extrapolant = :linear, kappa = (-37//200, -1//9, -823//10000, -83//2000, 0//1), controller = :Standard, step_limiter! = trivial_limiter!) where {MO} QNDF{MO, _unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_caches.jl b/lib/OrdinaryDiffEqBDF/src/bdf_caches.jl index 7cdc4b580f..9d60023c13 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_caches.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_caches.jl @@ -358,7 +358,7 @@ function alg_cache(alg::QNDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits } where {MO} max_order = MO - γ, c = one(eltype(alg.kappa)), 1 + γ, c = one(uEltypeNoUnits), 1 nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) dtprev = one(dt) @@ -539,7 +539,7 @@ function alg_cache(alg::FBDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits } where {MO} - γ, c = 1.0, 1.0 + γ, c = one(uEltypeNoUnits), 1 max_order = MO nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) @@ -614,7 +614,7 @@ function alg_cache(alg::FBDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{true}) where {MO, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - γ, c = 1.0, 1.0 + γ, c = one(uEltypeNoUnits), 1 fsalfirst = zero(rate_prototype) max_order = MO nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, diff --git a/lib/OrdinaryDiffEqBDF/src/controllers.jl b/lib/OrdinaryDiffEqBDF/src/controllers.jl index f3cbdf2d03..c001e6a586 100644 --- a/lib/OrdinaryDiffEqBDF/src/controllers.jl +++ b/lib/OrdinaryDiffEqBDF/src/controllers.jl @@ -172,7 +172,7 @@ function choose_order!(alg::FBDF, integrator, terk_tmp = @.. broadcast=false fd_weights[k - 2, 1]*u vc = _vec(terk_tmp) for i in 2:(k - 2) - @.. broadcast=false @views vc += fd_weights[i, k - 2] * u_history[:, i - 1] + @.. @views vc += fd_weights[i, k - 2] * u_history[:, i - 1] end @.. broadcast=false terk_tmp*=abs(dt^(k - 2)) calculate_residuals!(atmp, _vec(terk_tmp), _vec(uprev), _vec(u), @@ -204,22 +204,24 @@ function choose_order!(alg::FBDF, integrator, terkm1 = terkm2 fd_weights = calc_finite_difference_weights(ts_tmp, t + dt, k - 2, Val(max_order)) - terk_tmp = @.. broadcast=false fd_weights[k - 2, 1]*u + local terk_tmp if u isa Number + terk_tmp = fd_weights[k - 2, 1]*u for i in 2:(k - 2) terk_tmp += fd_weights[i, k - 2] * u_history[i - 1] end terk_tmp *= abs(dt^(k - 2)) else - vc = _vec(terk_tmp) + # we need terk_tmp to be mutable. + # so it can be updated + terk_tmp = similar(u) + @.. terk_tmp = fd_weights[k - 2, 1]*_vec(u) for i in 2:(k - 2) - @.. broadcast=false @views vc += fd_weights[i, k - 2] * - u_history[:, i - 1] + @.. @views terk_tmp += fd_weights[i, k - 2] * u_history[:, i - 1] end - terk_tmp = reshape(vc, size(terk_tmp)) - terk_tmp *= @.. broadcast=false abs(dt^(k - 2)) + @.. terk_tmp *= abs(dt^(k - 2)) end - atmp = calculate_residuals(_vec(terk_tmp), _vec(uprev), _vec(u), + atmp = calculate_residuals(terk_tmp, _vec(uprev), _vec(u), integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) terkm2 = integrator.opts.internalnorm(atmp, t) diff --git a/test/interface/linear_solver_test.jl b/test/interface/linear_solver_test.jl index 515fbd3459..12a449d603 100644 --- a/test/interface/linear_solver_test.jl +++ b/test/interface/linear_solver_test.jl @@ -161,14 +161,14 @@ end using OrdinaryDiffEq, StaticArrays, LinearSolve, ParameterizedFunctions hires = @ode_def Hires begin - dy1 = -1.71 * y1 + 0.43 * y2 + 8.32 * y3 + 0.0007 - dy2 = 1.71 * y1 - 8.75 * y2 - dy3 = -10.03 * y3 + 0.43 * y4 + 0.035 * y5 - dy4 = 8.32 * y2 + 1.71 * y3 - 1.12 * y4 - dy5 = -1.745 * y5 + 0.43 * y6 + 0.43 * y7 - dy6 = -280.0 * y6 * y8 + 0.69 * y4 + 1.71 * y5 - 0.43 * y6 + 0.69 * y7 - dy7 = 280.0 * y6 * y8 - 1.81 * y7 - dy8 = -280.0 * y6 * y8 + 1.81 * y7 + dy1 = -1.71f0 * y1 + 0.43f0 * y2 + 8.32f0 * y3 + 0.0007f0 + 1f-18*t + dy2 = 1.71f0 * y1 - 8.75f0 * y2 + dy3 = -10.03f0 * y3 + 0.43f0 * y4 + 0.035f0 * y5 + dy4 = 8.32f0 * y2 + 1.71f0 * y3 - 1.12f0 * y4 + dy5 = -1.745f0 * y5 + 0.43f0 * y6 + 0.43f0 * y7 + dy6 = -280.0f0 * y6 * y8 + 0.69f0 * y4 + 1.71f0 * y5 - 0.43f0 * y6 + 0.69f0 * y7 + dy7 = 280.0f0 * y6 * y8 - 1.81f0 * y7 + dy8 = -280.0f0 * y6 * y8 + 1.81f0 * y7 end u0 = zeros(8) @@ -178,7 +178,11 @@ u0[8] = 0.0057 probiip = ODEProblem{true}(hires, u0, (0.0, 10.0)) proboop = ODEProblem{false}(hires, u0, (0.0, 10.0)) probstatic = ODEProblem{false}(hires, SVector{8}(u0), (0.0, 10.0)) +probiipf32 = ODEProblem{true}(hires, Float32.(u0), (0f0, 10f0)) +proboopf32 = ODEProblem{false}(hires, Float32.(u0), (0f0, 10f0)) +probstaticf32 = ODEProblem{false}(hires, SVector{8}(Float32.(u0)), (0f0, 10f0)) probs = (; probiip, proboop, probstatic) +probsf32 = (;probiipf32, proboopf32, probstaticf32) qndf = QNDF() krylov_qndf = QNDF(linsolve = KrylovJL_GMRES()) fbdf = FBDF() @@ -197,3 +201,13 @@ refsol = solve(probiip, FBDF(), abstol = 1e-12, reltol = 1e-12) end end end + +@testset "Hires Float32 calc_W tests" begin + @testset "$probname" for (probname, prob) in pairs(probsf32) + @testset "$solname" for (solname, solver) in pairs(solvers) + sol = solve(prob, solver, maxiters = 2e4) + @test sol.retcode == ReturnCode.Success + @test isapprox(sol.u[end], refsol.u[end], rtol = 2e-3, atol = 1e-6) + end + end +end