Commit 6310548: tweaks & tests

mcabbott committed Nov 5, 2024 (1 parent 637fc86, commit 6310548)

Showing 5 changed files with 119 additions and 7 deletions.
68 changes: 66 additions & 2 deletions docs/src/reference/training/enzyme.md
@@ -1,9 +1,73 @@

# [Automatic Differentiation using Enzyme.jl](@id autodiff-enzyme)

Flux now builds in support for Enzyme.jl
[Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) is a new package for automatic differentiation.
Like Zygote.jl, calling `gradient(f, x)` causes it to hook into the compiler and transform the code executed while calculating `f(x)`, in order to produce code for `∂f/∂x`.
But it does so much later in the optimisation process (on LLVM IR instead of Julia's untyped IR).
It needs far fewer custom rules than Zygote/ChainRules, and in particular it is able to support mutation of arrays.
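
For instance, here is a minimal sketch (not part of the original page) using `Enzyme.gradient` directly, described further below, on a loss that overwrites its argument in place:

```julia
using Enzyme

function scaled_loss!(x)      # mutates its argument in place
    x .*= 2
    return sum(abs2, x)       # equals 4 * sum(abs2, original x)
end

x = Float32[1, 2, 3]
g = Enzyme.gradient(Reverse, scaled_loss!, x)[1]   # expected ≈ Float32[8, 16, 24]
# note that x itself is left modified by the forward pass
```

Zygote's `gradient`, by contrast, raises its mutation error on the same function.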

Flux now builds in support for this, using Enzyme's own `Duplicated` type.
Calling `Duplicated` on any Flux model which was defined using `@layer` will allocate space for the gradient,
and passing that to `gradient` (or `withgradient`, or `train!`) will then use Enzyme instead of Zygote.
The gradient functions still return the gradient as usual, which can then be passed to `update!`:

```julia
julia> using Flux, Enzyme

julia> model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax); # from model zoo

julia> dup_model = Enzyme.Duplicated(model) # this allocates space for the gradient
Duplicated(
  Chain(
    Dense(784 => 32, σ),                # 25_120 parameters
    Dense(32 => 10),                    # 330 parameters
    NNlib.softmax,
  ),
  # norm(∇) ≈ 0.0f0
)                 # Total: 4 arrays, 25_450 parameters, 199.391 KiB.

julia> x1 = randn32(28*28, 1); # fake image

julia> y1 = [i==3 for i in 0:9]; # fake label

julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))
((layers = ((weight = Float32[-0.010354728 0.032972857 … -0.0014538406], σ = nothing), nothing),), nothing, nothing)
```

The gradient returned here is also stored within `dup_model`: it shares the same arrays. It will be set to zero the next time you call `gradient`.
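
That gradient can then be fed to `update!` in the usual way; a minimal sketch continuing the example above (the learning rate here is just illustrative):

```julia
opt_state = Flux.setup(Adam(3e-4), model)     # optimiser state matching the model's structure
Flux.update!(opt_state, model, grads_f[1])    # updates the model (and hence dup_model.val) in place
```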

Writing `Const(x1)` is optional; a plain `x1` is treated as implicitly constant.
Any mix of `Duplicated` and `Const` arguments may appear in any order, so long as at least one is `Duplicated`.
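
For example, here is a rough sketch (an assumption, not from the original page) of asking for the gradient with respect to the input instead, by wrapping the array as `Duplicated(x1, zero(x1))` and holding the model constant:

```julia
dup_x1 = Duplicated(x1, zero(x1))   # the second array is space for ∂loss/∂x1
grads_x = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), Const(model), dup_x1, Const(y1))
# grads_x should be (nothing, ∂loss/∂x1, nothing), with the same array stored in dup_x1.dval
```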

Instead of using these Flux functions, you can also use Enzyme's own functions directly.
`Enzyme.gradient` works like this:

```julia
julia> grads_e = Enzyme.gradient(Reverse, (m,x,y) -> sum(abs2, m(x) .- y), model, Const(x1), Const(y1))
(Chain(Dense(784 => 32, σ), Dense(32 => 10), softmax), nothing, nothing)

julia> grads_f[1].layers[2].bias ≈ grads_e[1].layers[2].bias
true
```

Note that what `Enzyme.gradient` returns is an object like `deepcopy(model)`, of the same type (so `grads_e[1] isa Chain`), but whose fields contain the same gradient values.

There is also a method of `train!` which similarly takes `Duplicated(model)`:

```julia
julia> opt_state = Flux.setup(Adam(0), model);

julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
```
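
This commit also teaches `setup` to accept a `Duplicated` model directly (see `src/train.jl` below), so a fuller loop might look like this rough sketch, where the data iterator and epoch count are placeholders:

```julia
opt_state = Flux.setup(Adam(), dup_model)    # equivalent to setup on the inner model
data = [(x1, y1)]                            # in practice, an iterator of (input, label) batches
for epoch in 1:100
    Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, data, opt_state)
end
```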



## Listing

```@docs
gradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...)
Flux.gradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...)
Flux.withgradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...)
Flux.train!(loss, model::EnzymeCore.Duplicated, data, opt)
```
3 changes: 1 addition & 2 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -1,8 +1,7 @@
module FluxEnzymeExt

using Flux
using Flux: _make_zero!
import Flux.Train: train!, _rule_to_state
import Flux.Train: _make_zero!, _enzyme_train!, _rule_to_state
import Flux.Optimise
import Optimisers
import Enzyme
26 changes: 24 additions & 2 deletions src/gradient.jl
@@ -32,6 +32,11 @@ function gradient(f, args...; zero::Bool=true)
    for a in args
        a isa EnzymeCore.Duplicated && return _enzyme_gradient(f, map(_ensure_enzyme, args)...; zero)
    end
    for a in args
        a isa EnzymeCore.Const && throw(ArgumentError(
            "The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
        ))
    end
    Zygote.gradient(f, args...)
end

@@ -97,8 +102,14 @@ julia> Flux.gradient(dup_model, [1]; zero=false) do m, x # implicit Const([1]),
"""
gradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = _enzyme_gradient(f, args...; zero)

gradient(f, args::EnzymeCore.Const...; zero::Bool=true) = throw(ArgumentError(
    "The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
))

# FluxEnzymeExt defines more specific _enzyme_gradient(f, args::Union{Const, Duplicated}...; zero)
_enzyme_gradient(f, args...; zero) = error("methods like `gradient(f, x::Duplicated)` are only available when Enzyme is loaded.")
_enzyme_gradient(f, args...; zero) = throw(ArgumentError(
    "Methods like `gradient(f, x::Duplicated)` are only available when Enzyme is loaded."
))


"""
@@ -140,6 +151,11 @@ function withgradient(f, args...; zero::Bool=true)
    for a in args
        a isa EnzymeCore.Duplicated && return _enzyme_withgradient(f, map(_ensure_enzyme, args)...; zero)
    end
    for a in args
        a isa EnzymeCore.Const && throw(ArgumentError(
            "The method `withgradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
        ))
    end
    Zygote.withgradient(f, args...)
end

@@ -172,5 +188,11 @@ julia> Flux.withgradient(m -> m(3), Duplicated(model)) # this uses Enzyme
"""
withgradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = _enzyme_withgradient(f, args...; zero)

withgradient(f, args::EnzymeCore.Const...; zero::Bool=true) = throw(ArgumentError(
    "The method `withgradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
))

# FluxEnzymeExt defines more specific _enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero)
_enzyme_withgradient(f, args...; zero) = error("methods like `withgradient(f, x::Duplicated)` are only available when Enzyme is loaded.")
_enzyme_withgradient(f, args...; zero) = throw(ArgumentError(
    "Methods like `withgradient(f, x::Duplicated)` are only available when Enzyme is loaded."
))
4 changes: 3 additions & 1 deletion src/train.jl
@@ -53,6 +53,8 @@ function setup(rule::Optimisers.AbstractRule, model)
    state
end

setup(rule, model::Duplicated) = setup(rule, model.val)

"""
train!(loss, model, data, opt_state)
@@ -150,7 +152,7 @@ Only available when Enzyme is loaded.
train!(loss, model::Duplicated, data, opt; cb = nothing) = _enzyme_train!(loss, model, data, opt; cb = nothing)

# FluxEnzymeExt defines more specific _enzyme_train!(loss, model::Duplicated, data, opt; cb)
_enzyme_train!(loss, model, data, opt; cb = nothing) = error("The method `train!(loss, Duplicated(model), data, opt_state)` is only available when Enzyme.jl is loaded")
_enzyme_train!(loss, model, data, opt; cb = nothing) = throw(ArgumentError("The method `train!(loss, Duplicated(model), data, opt_state)` is only available when Enzyme.jl is loaded"))

# Following src/deprecations.jl
function train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing)
25 changes: 25 additions & 0 deletions test/ext_enzyme/enzyme.jl
@@ -146,3 +146,28 @@ end
end
end
end

@testset "gradient, withgradient, Duplicated" begin
    # Tests above are about how Enzyme digests Flux layers.
    # Tests here are just the interface Flux.gradient(f, Duplicated(model)) etc.
    m1 = Duplicated(Dense(3=>2))
    @test m1 isa Duplicated
    g1 = Flux.gradient(m -> sum(m.bias), m1) |> only
    @test iszero(g1.weight)
    @test g1.bias == [1, 1]
    @test m1.dval.bias == [1, 1]

    g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, [1,2,3f0])
    @test g2.val ≈ sum(m1([1,2,3f0]))
    @test g2.grad[1].weight ≈ [1 2 3; 1 2 3]
    @test g2.grad[2] === nothing  # implicitly Const

    # setup understands Duplicated:
    @test Flux.setup(Adam(), m1) == Flux.setup(Adam(), m1.val)

    # At least one Duplicated is required:
    @test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1))
    @test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1), [1,2,3f0])
    @test_throws ArgumentError Flux.withgradient(m -> sum(m.bias), Const(m1))
    @test_throws ArgumentError Flux.withgradient((m,x) -> sum(m(x)), Const(m1), [1,2,3f0])
end
