Skip to content

Commit

Permalink
another explicit error for Zygote mistake
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 24, 2024
1 parent 55d24eb commit 7a54531
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/layers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,13 @@ function _macro_enzyme(type)
# One-arg method Duplicated(m::Layer) which allocates & zeros the gradient:
$EnzymeCore.Duplicated(m::$type) = $EnzymeCore.Duplicated(m, $EnzymeCore.make_zero(m))

# Not sure we want this, but make Duplicated{<:Layer} callable?
(m::$EnzymeCore.Duplicated{<:$type})(xs...) = m.val(xs...)
# Make Duplicated{<:Layer} callable:
function (m::$EnzymeCore.Duplicated{<:$type})(xs...)
Zygote.isderiving() && error("""`Duplicated(flux_model)` is only for use with Enzyme.jl.
`Flux.gradient` should detect this, but calling `Zygote.gradient` directly on
such a wrapped model is not supported.""")
m.val(xs...)
end

# Not sure but this does prevent printing of 2nd copy:
$Optimisers.trainable(m::$EnzymeCore.Duplicated{<:$type}) = (; val = m.val)
Expand Down
3 changes: 3 additions & 0 deletions test/ext_enzyme/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,7 @@ end
@test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, Const(m1.val), Active(3f0))
# Duplicated
@test_throws Exception Flux.gradient((m,z) -> sum(m.bias)/z, m1, Duplicated(3f0, 0f0))

# Using Duplicated within Zygote.gradient is not supported:
Zygote.gradient((m,x) -> sum(m(x)), m1, [1,2,3f0])
end

0 comments on commit 7a54531

Please sign in to comment.