diff --git a/src/dual.jl b/src/dual.jl index 2ca4683f..367e1b6a 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -25,6 +25,9 @@ end ########## Base.ArithmeticStyle(::Type{<:Dual{T,V}}) where {T,V} = Base.ArithmeticStyle(V) +unwrap_dual(::Type{Dual{T,V,N}}) where {T,V,N} = unwrap_dual(V) +unwrap_dual(::Type{V}) where V = V + ############## # Exceptions # ############## diff --git a/src/partials.jl b/src/partials.jl index a5316e3e..d9d13f98 100644 --- a/src/partials.jl +++ b/src/partials.jl @@ -163,7 +163,11 @@ end @inline rand_tuple(::AbstractRNG, ::Type{Tuple{}}) = tuple() @inline rand_tuple(::Type{Tuple{}}) = tuple() -@generated function iszero_tuple(tup::NTuple{N,V}) where {N,V} +iszero_tuple(tup::NTuple{N,V}) where {N,V} = _iszero_tuple(unwrap_dual(V), tup) + +# default implementation; specific number types (e.g., Interval from IntervalArithmetic) +# can add specializations. +@generated function _iszero_tuple(::Type{V0}, tup::NTuple{N,V}) where {V0,N,V} ex = Expr(:&&, [:(z == tup[$i]) for i=1:N]...) return quote z = zero(V)