diff --git a/src/rules.jl b/src/rules.jl index 5267097..436ad13 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -165,6 +165,7 @@ hasdiffrule(f::Symbol, arity::Int) = in((f, arity), DEFINED_DIFFRULES) @define_diffrule hypot(x, y) = :($x / hypot($x, $y)), :($y / hypot($x, $y)) @define_diffrule mod(x, y) = :(first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN))), :(z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -floor(z)), NaN))) @define_diffrule rem(x, y) = :(first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN))), :(z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -trunc(z)), NaN))) +@define_diffrule rem2pi(x, r) = :(1), :NaN #################### # SpecialFunctions # diff --git a/test/RulesTests.jl b/test/RulesTests.jl index f874976..fa968b2 100644 --- a/test/RulesTests.jl +++ b/test/RulesTests.jl @@ -38,3 +38,17 @@ for f in vcat(RealInterface.BINARY_MATH, RealInterface.BINARY_ARITHMETIC, RealIn end end end + +# Treat rem2pi separately because of its non-numeric second argument: +derivs = DiffBase.diffrule(:rem2pi, :x, :y) +for xtype in [:Float64, :BigFloat, :Int64] + for mode in [:RoundUp, :RoundDown, :RoundToZero, :RoundNearest] + @eval begin + x = $xtype(rand(1 : 10)) + y = $mode + dx, dy = $(derivs[1]), $(derivs[2]) + @test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05) + @test isnan(dy) + end + end +end