From 7a5453105ead8354607e4b884485cc7e4bb5b9a9 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 24 Nov 2024 10:53:13 -0500
Subject: [PATCH] another explicit error for Zygote mistake

---
 src/layers/macro.jl       | 9 +++++++--
 test/ext_enzyme/enzyme.jl | 3 +++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/layers/macro.jl b/src/layers/macro.jl
index a650bb4e84..52e8b35d28 100644
--- a/src/layers/macro.jl
+++ b/src/layers/macro.jl
@@ -113,8 +113,13 @@ function _macro_enzyme(type)
     # One-arg method Duplicated(m::Layer) which allocates & zeros the gradient:
     $EnzymeCore.Duplicated(m::$type) = $EnzymeCore.Duplicated(m, $EnzymeCore.make_zero(m))

-    # Not sure we want this, but make Duplicated{<:Layer} callable?
-    (m::$EnzymeCore.Duplicated{<:$type})(xs...) = m.val(xs...)
+    # Make Duplicated{<:Layer} callable:
+    function (m::$EnzymeCore.Duplicated{<:$type})(xs...)
+      Zygote.isderiving() && error("""`Duplicated(flux_model)` is only for use with Enzyme.jl.
+        `Flux.gradient` should detect this, but calling `Zygote.gradient` directly on
+        such a wrapped model is not supported.""")
+      m.val(xs...)
+    end

     # Not sure but this does prevent printing of 2nd copy:
     $Optimisers.trainable(m::$EnzymeCore.Duplicated{<:$type}) = (; val = m.val)
diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl
index c6107cb117..3e7447fe6e 100644
--- a/test/ext_enzyme/enzyme.jl
+++ b/test/ext_enzyme/enzyme.jl
@@ -205,4 +205,7 @@ end
   @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, Const(m1.val), Active(3f0))
   # Duplicated
   @test_throws Exception Flux.gradient((m,z) -> sum(m.bias)/z, m1, Duplicated(3f0, 0f0))
+
+  # Using Duplicated within Zygote.gradient is not supported:
+  @test_throws ErrorException Zygote.gradient((m,x) -> sum(m(x)), m1, [1,2,3f0])
 end
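
For reference, a minimal usage sketch (not part of the patch) of the behaviour this change is meant to enforce, assuming Flux with its Enzyme extension loaded; the model `Dense(3 => 2)` and the input are illustrative:

    using Flux, Enzyme, Zygote

    model = Duplicated(Dense(3 => 2))   # one-arg method above allocates a zeroed shadow copy
    x = rand(Float32, 3)

    # Flux.gradient recognises the Duplicated wrapper and differentiates via Enzyme:
    Flux.gradient(m -> sum(m(x)), model)

    # Calling Zygote.gradient directly on the wrapped model now hits the new guard
    # and throws, instead of silently differentiating m.val:
    Zygote.gradient(m -> sum(m(x)), model)   # ErrorException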