From ee22f3ac829f46231874f1b196d373d486fc4c6d Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 21:38:11 -0400 Subject: [PATCH] Add tests --- test/chainrules.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/chainrules.jl b/test/chainrules.jl index 7e55720de..00fd1b0af 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -420,3 +420,15 @@ end @test z2d_compiled.c.a === z2d_fallback.c.a @test z2d_compiled.c.b === z2d_fallback.c.b end + +@testset "ChainRules translation" begin + @test Zygote.wrap_chainrules_input(nothing) == ChainRules.ZeroTangent() + @test Zygote.wrap_chainrules_input((nothing,)) == ChainRules.ZeroTangent() + @test Zygote.wrap_chainrules_input([nothing]) == ChainRules.ZeroTangent() + @test Zygote.wrap_chainrules_input(((1.0, 2.0), 3.0)) == ChainRules.Tangent{Any}(ChainRules.Tangent{Any}(1.0, 2.0), 3.0) + @test Zygote.wrap_chainrules_input((; a = 1.0, b = 2.0)) == ChainRules.Tangent{Any}(a = 1.0, b = 2.0) + @test Zygote.wrap_chainrules_input(Ref(1)) == 1 + @test Zygote.wrap_chainrules_input([2.0; 4.0]) == [2.0; 4.0] + @test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]] + @test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible +end