diff --git a/src/dual.jl b/src/dual.jl index 22129b68..8f557860 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -98,7 +98,7 @@ macro define_binary_dual_op(f, xy_body, x_body, y_body) @inline $(f)(x::Dual{Txy}, y::Dual{Txy}) where {Txy} = $xy_body @inline $(f)(x::Dual{Tx}, y::Dual{Ty}) where {Tx,Ty} = Ty ≺ Tx ? $x_body : $y_body end - for R in REAL_TYPES + for R in AMBIGUOUS_TYPES expr = quote @inline $(f)(x::Dual{Tx}, y::$R) where {Tx} = $x_body @inline $(f)(x::$R, y::Dual{Ty}) where {Ty} = $y_body @@ -124,7 +124,7 @@ macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_b end end end - for R in REAL_TYPES + for R in AMBIGUOUS_TYPES expr = quote @inline $(f)(x::Dual{Txy}, y::Dual{Txy}, z::$R) where {Txy} = $xy_body @inline $(f)(x::Dual{Tx}, y::Dual{Ty}, z::$R) where {Tx, Ty} = Ty ≺ Tx ? $x_body : $y_body @@ -134,7 +134,7 @@ macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_b @inline $(f)(x::$R, y::Dual{Ty}, z::Dual{Tz}) where {Ty,Tz} = Tz ≺ Ty ? $y_body : $z_body end append!(defs.args, expr.args) - for Q in REAL_TYPES + for Q in AMBIGUOUS_TYPES Q === R && continue expr = quote @inline $(f)(x::Dual{Tx}, y::$R, z::$Q) where {Tx} = $x_body diff --git a/src/prelude.jl b/src/prelude.jl index d3748fe4..4e3fa9b5 100644 --- a/src/prelude.jl +++ b/src/prelude.jl @@ -1,6 +1,6 @@ const NANSAFE_MODE_ENABLED = false -const REAL_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real) +const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode) const UNARY_PREDICATES = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger] diff --git a/test/DualTest.jl b/test/DualTest.jl index d2686a1d..4d6f1a8c 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -19,10 +19,9 @@ samerng() = MersenneTwister(1) intrand(V) = V == Int ? rand(2:10) : rand(V) dual_isapprox(a, b) = isapprox(a, b) -dual_isapprox(a::Dual{T,T1,T2}, b::Dual{T,T3,T4}) where {T,T1,T2,T3,T4} = - isapprox(value(a), value(b)) && isapprox(partials(a), partials(b)) -dual_isapprox(a::Dual{T,T1,T2}, b::Dual{T3,T4,T5}) where {T,T1,T2,T3,T4,T5} = - error("Tags don't match") +dual_isapprox(a::Dual{T,T1,T2}, b::Dual{T,T3,T4}) where {T,T1,T2,T3,T4} = isapprox(value(a), value(b)) && isapprox(partials(a), partials(b)) +dual_isapprox(a::Dual{T,T1,T2}, b::Dual{T3,T4,T5}) where {T,T1,T2,T3,T4,T5} = error("Tags don't match") + ForwardDiff.:≺(::Type{TestTag()}, ::Int) = true ForwardDiff.:≺(::Int, ::Type{TestTag()}) = false @@ -394,7 +393,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32) if V != Int for (M, f, arity) in DiffRules.diffrules() - in(f, (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/)) && continue + in(f, (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi)) && continue println(" ...auto-testing $(M).$(f) with $arity arguments") if arity == 1 deriv = DiffRules.diffrule(M, f, :x)