diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 36452da1e..7d06b1421 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -153,6 +153,17 @@ test_rrule(muladd, 10randn(), randn(), randn()) end + @testset "muladd ZeroTangent" begin + test_frule(muladd, 2.0, 3.0, ZeroTangent()) + test_frule(muladd, 2.0, ZeroTangent(), 4.0) + test_frule(muladd, ZeroTangent(), 3.0, 4.0) + + test_rrule(muladd, 2.0, 3.0, ZeroTangent()) + test_rrule(muladd, 2.0, ZeroTangent(), 4.0) + test_rrule(muladd, ZeroTangent(), 3.0, 4.0) + end + + @testset "fma" begin test_frule(fma, 10randn(), randn(), randn()) test_rrule(fma, 10randn(), randn(), randn())