From 841ef9482d5ef2a6974f4d22dfbf682b1b4b0ac1 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Tue, 12 Sep 2023 18:50:11 -0400 Subject: [PATCH] Fix unwanted type promotion in InternalITP --- src/callbacks.jl | 18 ++++++------- src/internal_falsi.jl | 18 ++++++------- src/internal_itp.jl | 31 +++++++++++----------- src/solve.jl | 10 +++++--- test/callbacks.jl | 34 ++++++++++++------------- test/downstream/solve_error_handling.jl | 7 ++++- test/forwarddiff_dual_detection.jl | 3 ++- test/internal_rootfinder.jl | 4 +-- 8 files changed, 67 insertions(+), 58 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index a4333d2cd..25d5393bc 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -118,26 +118,26 @@ end # Use a generated function for type stability even when many callbacks are given @inline function find_first_continuous_callback(integrator, - callbacks::Vararg{ - AbstractContinuousCallback, - N}) where {N} + callbacks::Vararg{ + AbstractContinuousCallback, + N}) where {N} find_first_continuous_callback(integrator, tuple(callbacks...)) end @generated function find_first_continuous_callback(integrator, - callbacks::NTuple{N, - AbstractContinuousCallback - }) where {N} + callbacks::NTuple{N, + AbstractContinuousCallback, + }) where {N} ex = quote tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator, - callbacks[1], 1) + callbacks[1], 1) identified_idx = 1 end for i in 2:N ex = quote $ex tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator, - callbacks[$i], - $i) + callbacks[$i], + $i) if event_occurred2 && (tmin2 < tmin || !event_occurred) tmin = tmin2 upcrossing = upcrossing2 diff --git a/src/internal_falsi.jl b/src/internal_falsi.jl index 52e5e0d19..d8f693510 100644 --- a/src/internal_falsi.jl +++ b/src/internal_falsi.jl @@ -36,8 +36,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::InternalFalsi, arg if iszero(fr) return SciMLBase.build_solution(prob, alg, right, fr; - retcode = ReturnCode.ExactSolutionLeft, left = left, - right = right) + retcode = ReturnCode.ExactSolutionLeft, left = left, + right = right) end i = 1 @@ -129,7 +129,7 @@ function scalar_nlsolve_ad(prob, alg::InternalFalsi, args...; kwargs...) end function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, - <:ForwardDiff.Dual{T, V, P}}, + <:ForwardDiff.Dual{T, V, P}}, alg::InternalFalsi, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) @@ -140,15 +140,15 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, end function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, - <:AbstractArray{ - <:ForwardDiff.Dual{T, - V, - P}, - }}, + <:AbstractArray{ + <:ForwardDiff.Dual{T, + V, + P}, + }}, alg::InternalFalsi, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - + return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), sol.resid; retcode = sol.retcode, left = ForwardDiff.Dual{T, V, P}(sol.left, partials), diff --git a/src/internal_itp.jl b/src/internal_itp.jl index 8541997eb..0b389bf97 100644 --- a/src/internal_itp.jl +++ b/src/internal_itp.jl @@ -2,7 +2,7 @@ `InternalITP`: A non-allocating ITP method, internal to DiffEqBase for simpler dependencies. """ -struct InternalITP +struct InternalITP k1::Float64 k2::Float64 n0::Int @@ -10,7 +10,8 @@ end InternalITP() = InternalITP(0.007, 1.5, 10) -function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T,T}}, alg::InternalITP, args...; +function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T, T}}, alg::InternalITP, + args...; maxiters = 1000, kwargs...) where {IP, T} f = Base.Fix2(prob.f, prob.p) left, right = prob.tspan # a and b @@ -26,9 +27,9 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T,T}}, alg::In right = right) end #defining variables/cache - k1 = alg.k1 - k2 = alg.k2 - n0 = alg.n0 + k1 = T(alg.k1) + k2 = T(alg.k2) + n0 = T(alg.n0) n_h = ceil(log2(abs(right - left) / (2 * ϵ))) mid = (left + right) / 2 x_f = (fr * left - fl * right) / (fr - fl) @@ -46,7 +47,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T,T}}, alg::In δ = k1 * (span^k2) ## Interpolation step ## - x_f = left + (right - left) * (fl/(fl - fr)) + x_f = left + (right - left) * (fl / (fl - fr)) ## Truncation step ## σ = sign(mid - x_f) @@ -79,8 +80,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T,T}}, alg::In left = prevfloat_tdir(xp, prob.tspan...) right = xp return SciMLBase.build_solution(prob, alg, left, f(left); - retcode = ReturnCode.Success, left = left, - right = right) + retcode = ReturnCode.Success, left = left, + right = right) end i += 1 mid = (left + right) / 2 @@ -127,7 +128,7 @@ function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...) end function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, - <:ForwardDiff.Dual{T, V, P}}, + <:ForwardDiff.Dual{T, V, P}}, alg::InternalITP, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) @@ -138,15 +139,15 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, end function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, - <:AbstractArray{ - <:ForwardDiff.Dual{T, - V, - P}, - }}, + <:AbstractArray{ + <:ForwardDiff.Dual{T, + V, + P}, + }}, alg::InternalITP, args...; kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - + return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), sol.resid; retcode = sol.retcode, left = ForwardDiff.Dual{T, V, P}(sol.left, partials), diff --git a/src/solve.jl b/src/solve.jl index d6637bd8c..9371e7b96 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -203,7 +203,7 @@ const NOISE_SIZE_MESSAGE = """ be double checked. """ -struct NoiseSizeIncompatabilityError <: Exception +struct NoiseSizeIncompatabilityError <: Exception prototypesize::Int noisesize::Int end @@ -1025,7 +1025,7 @@ function solve(prob::EnsembleProblem, args...; kwargs...) end end function solve(prob::SciMLBase.WeightedEnsembleProblem, args...; kwargs...) - SciMLBase.WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights) + SciMLBase.WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights) end function solve(prob::AbstractNoiseProblem, args...; kwargs...) __solve(prob, args...; kwargs...) @@ -1307,8 +1307,10 @@ function check_prob_alg_pairing(prob, alg) end if prob isa SDEProblem && prob.noise_rate_prototype !== nothing && - prob.noise !== nothing && size(prob.noise_rate_prototype,2) != length(prob.noise.W[1]) - throw(NoiseSizeIncompatabilityError(size(prob.noise_rate_prototype,2), length(prob.noise.W[1]))) + prob.noise !== nothing && + size(prob.noise_rate_prototype, 2) != length(prob.noise.W[1]) + throw(NoiseSizeIncompatabilityError(size(prob.noise_rate_prototype, 2), + length(prob.noise.W[1]))) end # Complex number support comes before arbitrary number support for a more direct diff --git a/test/callbacks.jl b/test/callbacks.jl index 252af55bd..98bcc2824 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -58,11 +58,11 @@ struct EmptyIntegrator u::Vector{Float64} end function DiffEqBase.find_callback_time(integrator::EmptyIntegrator, - callback::ContinuousCallback, counter) + callback::ContinuousCallback, counter) 1.0 + counter, 0.9 + counter, true, counter end function DiffEqBase.find_callback_time(integrator::EmptyIntegrator, - callback::VectorContinuousCallback, counter) + callback::VectorContinuousCallback, counter) 1.0 + counter, 0.9 + counter, true, counter end find_first_integrator = EmptyIntegrator([1.0, 2.0]) @@ -82,21 +82,21 @@ cond_9(u, t, integrator) = t - 1.8 cond_10(u, t, integrator) = t - 1.9 # Setup a lot of callbacks so the recursive inference failure happens callbacks = (ContinuousCallback(cond_1, affect!), - ContinuousCallback(cond_2, affect!), - ContinuousCallback(cond_3, affect!), - ContinuousCallback(cond_4, affect!), - ContinuousCallback(cond_5, affect!), - ContinuousCallback(cond_6, affect!), - ContinuousCallback(cond_7, affect!), - ContinuousCallback(cond_8, affect!), - ContinuousCallback(cond_9, affect!), - ContinuousCallback(cond_10, affect!), - VectorContinuousCallback(cond_1, vector_affect!, 2), - VectorContinuousCallback(cond_2, vector_affect!, 2), - VectorContinuousCallback(cond_3, vector_affect!, 2), - VectorContinuousCallback(cond_4, vector_affect!, 2), - VectorContinuousCallback(cond_5, vector_affect!, 2), - VectorContinuousCallback(cond_6, vector_affect!, 2)); + ContinuousCallback(cond_2, affect!), + ContinuousCallback(cond_3, affect!), + ContinuousCallback(cond_4, affect!), + ContinuousCallback(cond_5, affect!), + ContinuousCallback(cond_6, affect!), + ContinuousCallback(cond_7, affect!), + ContinuousCallback(cond_8, affect!), + ContinuousCallback(cond_9, affect!), + ContinuousCallback(cond_10, affect!), + VectorContinuousCallback(cond_1, vector_affect!, 2), + VectorContinuousCallback(cond_2, vector_affect!, 2), + VectorContinuousCallback(cond_3, vector_affect!, 2), + VectorContinuousCallback(cond_4, vector_affect!, 2), + VectorContinuousCallback(cond_5, vector_affect!, 2), + VectorContinuousCallback(cond_6, vector_affect!, 2)); function test_find_first_callback(callbacks, int) @timed(DiffEqBase.find_first_continuous_callback(int, callbacks...)) end diff --git a/test/downstream/solve_error_handling.jl b/test/downstream/solve_error_handling.jl index 12e8a79e2..4090bd94a 100644 --- a/test/downstream/solve_error_handling.jl +++ b/test/downstream/solve_error_handling.jl @@ -61,5 +61,10 @@ function g(du, u, p, t) du[2, 4] = 1.8u[2] end -prob = SDEProblem(f, g, randn(ComplexF64,2), (0.0, 1.0), noise_rate_prototype =complex(zeros(2, 4)),noise=StochasticDiffEq.RealWienerProcess(0.0,zeros(3))) +prob = SDEProblem(f, + g, + randn(ComplexF64, 2), + (0.0, 1.0), + noise_rate_prototype = complex(zeros(2, 4)), + noise = StochasticDiffEq.RealWienerProcess(0.0, zeros(3))) @test_throws DiffEqBase.NoiseSizeIncompatabilityError solve(prob, LambaEM()) diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index adfaac8f3..d880983ff 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -85,7 +85,8 @@ p_possibilities17 = [ (Mod, ForwardDiff.Dual(2.0)), (() -> 2.0, ForwardDiff.Dual(2.0)), (Base.pointer([2.0]), ForwardDiff.Dual(2.0)), ] -VERSION >= v"1.7" && push!(p_possibilities17, Returns((a = 2, b = 1.3, c = ForwardDiff.Dual(2.0f0)))) +VERSION >= v"1.7" && + push!(p_possibilities17, Returns((a = 2, b = 1.3, c = ForwardDiff.Dual(2.0f0)))) for p in p_possibilities17 @show p diff --git a/test/internal_rootfinder.jl b/test/internal_rootfinder.jl index e84b38294..323cc25be 100644 --- a/test/internal_rootfinder.jl +++ b/test/internal_rootfinder.jl @@ -20,7 +20,7 @@ for Rootfinder in (InternalFalsi, InternalITP) # https://github.com/SciML/DiffEqBase.jl/issues/916 inp = IntervalNonlinearProblem((t, p) -> min(-1.0 + 0.001427344607477125 * t, 1e-9), - (699.0079267259368, 700.6176418816023)) + (699.0079267259368, 700.6176418816023)) @test solve(inp, rf).u ≈ 700.6016590257979 # Flipped signs & reversed tspan test for bracketing algorithms @@ -36,4 +36,4 @@ for Rootfinder in (InternalFalsi, InternalITP) @test abs.(solve(inp3, rf).u) ≈ sqrt.(p) @test abs.(solve(inp4, rf).u) ≈ sqrt.(p) end -end \ No newline at end of file +end