From a0d23e78cef34a81bdd0beee92badd9889c8b6ae Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Wed, 27 Sep 2023 21:16:14 -0500 Subject: [PATCH] Add flexibility in dispatch for `iszero_tuple` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit IntervalArithmetic.jl may abandon support for `==` among intervals (https://github.com/JuliaIntervals/IntervalArithmetic.jl/pull/571). To support specialization for specific Number subtypes, this makes `iszero_tuple` into a "trait"-dispatched function, first unwrapping all the way down to the elementary numeric type and then jointly dispatching on that type and the actual tuple. This makes it possible to create an extension in IntervalArithmetic that specializes the implementation to use the new comparison operator `≛`. The use of recursive unwrapping enables support for higher-order derivatives. --- src/dual.jl | 3 +++ src/partials.jl | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/dual.jl b/src/dual.jl index 2ca4683f..367e1b6a 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -25,6 +25,9 @@ end ########## Base.ArithmeticStyle(::Type{<:Dual{T,V}}) where {T,V} = Base.ArithmeticStyle(V) +unwrap_dual(::Type{Dual{T,V,N}}) where {T,V,N} = unwrap_dual(V) +unwrap_dual(::Type{V}) where V = V + ############## # Exceptions # ############## diff --git a/src/partials.jl b/src/partials.jl index a5316e3e..d9d13f98 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -163,7 +163,11 @@ end @inline rand_tuple(::AbstractRNG, ::Type{Tuple{}}) = tuple() @inline rand_tuple(::Type{Tuple{}}) = tuple() -@generated function iszero_tuple(tup::NTuple{N,V}) where {N,V} +iszero_tuple(tup::NTuple{N,V}) where {N,V} = _iszero_tuple(unwrap_dual(V), tup) + +# default implementation; specific number types (e.g., Interval from IntervalArithmetic) +# can add specializations. +@generated function _iszero_tuple(::Type{V0}, tup::NTuple{N,V}) where {V0,N,V} ex = Expr(:&&, [:(z == tup[$i]) for i=1:N]...) return quote z = zero(V)