From b4487e9df71135b2c522c8ebd3cca04fa99dad3f Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 20 Oct 2024 10:51:13 -0400 Subject: [PATCH 1/2] Improve float conversions in PI controllers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is fairly hard to test, but it basically stems from `float(1//50)::Float64`, which means that this was always doing some float64 stuff, even when it should've been doing fastpow stuff. The test case just really required a Float32: ```julia using DiffEqCallbacks, OrdinaryDiffEq, Tracker Base.prevfloat(r::Tracker.TrackedReal) = Tracker.track(prevfloat, r) Tracker.@grad function prevfloat(r::Real) prevfloat(Tracker.data(r)), Δ -> (Δ,) end Base.nextfloat(r::Tracker.TrackedReal) = Tracker.track(nextfloat, r) Tracker.@grad function nextfloat(r::Real) nextfloat(Tracker.data(r)), Δ -> (Δ,) end function rober(u, p::TrackedArray, t) y₁, y₂, y₃ = u k₁, k₂, k₃ = p return Tracker.collect([-k₁ * y₁ + k₃ * y₂ * y₃, k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃, k₂ * y₂^2]) end p = TrackedArray([1.9f0, 1.0f0, 3.0f0]) u0 = TrackedArray([1.0f0, 0.0f0, 0.0f0]) tspan = TrackedArray([0.0f0, 1.0f0]) prob = ODEProblem{false}(rober, u0, tspan, p) p = TrackedArray([1.9f0, 1.0f0, 3.0f0]) u0 = TrackedArray([1.0f0, 0.0f0, 0.0f0]) tspan = TrackedArray([0.0f0, 1.0f0]) prob = ODEProblem{false}(rober, u0, tspan, p) saved_values = SavedValues(eltype(tspan), eltype(p)) cb = SavingCallback((u, t, integrator) -> integrator.EEst * integrator.dt, saved_values) solve(remake(prob, u0 = u0, p = p, tspan = tspan), Tsit5(), sensealg = SensitivityADPassThrough(), callback = cb) @test !all(iszero.(Tracker.gradient( p -> begin solve(remake(prob, u0 = u0, p = p, tspan = tspan), Tsit5(), sensealg = SensitivityADPassThrough(), callback = cb) return sum(saved_values.saveval) end, p)[1])) ``` Thus downstream tests with FastPower.jl used catches this, and it's somewhat hard to construct a case that's this sensitive to the type. --- lib/OrdinaryDiffEqCore/src/integrators/controllers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl index 1f47f61ebc..9a0bb53267 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/controllers.jl @@ -141,8 +141,8 @@ end if iszero(EEst) q = inv(qmax) else - q11 = FastPower.fastpower(EEst, float(beta1)) - q = q11 / FastPower.fastpower(qold, float(beta2)) + q11 = FastPower.fastpower(EEst, convert(typeof(EEst),beta1)) + q = q11 / FastPower.fastpower(qold, convert(typeof(EEst),beta2)) integrator.q11 = q11 @fastmath q = max(inv(qmax), min(inv(qmin), q / gamma)) end From c755f9d56e706708465a6e287933baeac58ccadf Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 21 Oct 2024 04:23:43 -0400 Subject: [PATCH 2/2] Update ode_firk_tests.jl --- lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl b/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl index fc2bbbd7e6..b90c9da113 100644 --- a/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl +++ b/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl @@ -1,7 +1,7 @@ using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra import ODEProblemLibrary: prob_ode_linear, prob_ode_2Dlinear, van -testTol = 0.3 +testTol = 0.35 for prob in [prob_ode_linear, prob_ode_2Dlinear] sim21 = test_convergence(1 .// 2 .^ (6:-1:3), prob, RadauIIA5())