news, docs
mcabbott committed Nov 6, 2024
1 parent 557ad49 commit ea86a2a
Showing 3 changed files with 29 additions and 7 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
@@ -11,6 +11,8 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
* The `Flux.Optimise` module has been deprecated in favor of the Optimisers.jl package.
Now Flux re-exports the optimisers from Optimisers.jl. Most users will be unaffected by this change.
The module is still available for now, but will be removed in a future release.
* Further support for Enzyme.jl, via methods of `Flux.gradient`.
This still defaults to Zygote, but is no longer `=== Zygote.gradient`.

## v0.14.22
* Data movement between devices is now provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl).
29 changes: 22 additions & 7 deletions docs/src/reference/training/enzyme.md
@@ -3,7 +3,7 @@

[Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) is a new package for automatic differentiation.
Like Zygote.jl, calling `gradient(f, x)` causes it to hook into the compiler and transform the code that is executed while calculating `f(x)`, in order to produce code for `∂f/∂x`.
But it does so much later in the optimisation process (on LLVM instead of Julia's untyped IR).
But it does so much later in the optimisation process (on LLVM instead of Julia's untyped IR), which you can [read about here](https://proceedings.nips.cc/paper/2020/file/9332c513ef44b682e9347822c2e457ac-Paper.pdf).
It needs far fewer custom rules than Zygote/ChainRules, and in particular is able to support mutation of arrays.
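
To make the mutation point concrete, here is a small sketch that is not part of this page: a function that writes into a freshly allocated buffer, which Zygote cannot differentiate but Enzyme handles (assuming Enzyme.jl is installed; the exact return shape of `Enzyme.gradient` varies between recent versions):

```julia
using Enzyme

function scale_sum(x)
    y = similar(x)
    y .= 2 .* x        # in-place mutation, which Zygote's rules cannot handle
    return sum(y)
end

Enzyme.gradient(Enzyme.Reverse, scale_sum, [1.0, 2.0, 3.0])   # gradient is [2.0, 2.0, 2.0]
```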

Flux now builds in support for this, using Enzyme's own `Duplicated` type.
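
The model and its `Duplicated` wrapper are defined in the collapsed lines of this diff; a rough sketch of that setup, with placeholder layers (the real definitions may differ), looks like:

```julia
using Flux, Enzyme

model = Chain(Dense(28*28 => 10, σ), softmax)             # placeholder layers, not the doc's exact model
dup_model = Duplicated(model, Enzyme.make_zero(model))    # the model plus zeroed storage for its gradient
```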
@@ -30,16 +30,32 @@ julia> x1 = randn32(28*28, 1); # fake image

julia> y1 = [i==3 for i in 0:9]; # fake label

julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))
julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1)) # uses Enzyme
((layers = ((weight = Float32[-0.010354728 0.032972857
-0.0014538406], σ = nothing), nothing),), nothing, nothing)
```

The gradient returned here is also stored within `dup_model`; it shares the same arrays. It will be set to zero when you call `gradient` again.
The gradient returned here is also stored within `dup_model`; it shares the same arrays.
They will all be set to zero when you call `gradient` again, then replaced with the new values.
Alternatively, `gradient(f, args...; zero=false)` will add the new gradient to what's already stored.
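
For example, accumulating over two batches might look like this (the helper `loss` and the second batch `x2`, `y2` are introduced here only for illustration, they are not part of the page's example):

```julia
loss(m, x, y) = sum(abs2, m(x) .- y)

x2 = randn32(28*28, 1); y2 = [i==7 for i in 0:9];   # hypothetical second batch

Flux.gradient(loss, dup_model, Const(x1), Const(y1))              # zeroes the stored gradient, then fills it
Flux.gradient(loss, dup_model, Const(x2), Const(y2); zero=false)  # adds this batch's gradient on top
```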

Writing `Const(x1)` is optional; a plain `x1` is implicitly treated as constant.
Any set of `Duplicated` and `Const` arguments may appear in any order, so long as there is at least one `Duplicated`.

The gradient `grads_f[1]` can be passed to `update!` as usual.
But for convenience, you may also use what is stored within `Duplicated`.
These are equivalent ways to perform an update step:

```julia
julia> opt_state = Flux.setup(Adam(), model)

julia> ans == Flux.setup(Adam(), dup_model)  # the new method ignores the wrapper, calling setup on model.val

julia> Flux.update!(opt_state, model, grads_f[1]) # exactly as for Zygote gradients

julia> Flux.update!(opt_state, dup_model) # equivalent new path, Enzyme only
```

Instead of using these Flux functions, you can also use Enzyme's own functions directly.
`Enzyme.gradient` works like this:
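
(The example itself is collapsed in this diff. As a rough sketch only, assuming Enzyme ≥ 0.13 where `gradient` takes a mode plus annotated arguments, such a call might look like the following; the helper `loss` and result `g` are names introduced here, not the page's own example.)

```julia
import Enzyme

loss(m, x, y) = sum(abs2, m(x) .- y)

g = Enzyme.gradient(Enzyme.Reverse, loss, model, Enzyme.Const(x1), Enzyme.Const(y1))
g[1]   # gradient with respect to the model, as a nested structure matching its fields
```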

@@ -63,11 +79,10 @@ julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_s
```
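
Putting the pieces together, one possible end-to-end run could look like the sketch below; the architecture, batch, and epoch count are assumptions for illustration rather than the page's own example.

```julia
using Flux, Enzyme

model = Chain(Dense(28*28 => 32, σ), Dense(32 => 10))      # assumed architecture
dup_model = Duplicated(model, Enzyme.make_zero(model))      # model plus gradient storage for Enzyme

x1 = randn32(28*28, 1)                                      # fake image
y1 = Float32.([i==3 for i in 0:9])                          # fake label

opt_state = Flux.setup(Adam(), dup_model)                   # uses the new Duplicated method of setup

for epoch in 1:3
    Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
end
```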



## Listing

```@docs
Flux.gradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...)
Flux.withgradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...)
Flux.train!(loss, model::EnzymeCore.Duplicated, data, opt)
Flux.gradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
Flux.withgradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
```
5 changes: 5 additions & 0 deletions src/train.jl
@@ -53,6 +53,11 @@ function setup(rule::Optimisers.AbstractRule, model)
state
end

"""
opt_state = setup(rule, model::Duplicated) = setup(rule, model.val)
Special method for use with Enzyme.jl, ignores the stored gradient.
"""
setup(rule::Optimisers.AbstractRule, model::Duplicated) = setup(rule, model.val)

"""
