diff --git a/NEWS.md b/NEWS.md
index 863107aa8c..6868754ec1 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -11,6 +11,8 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
 * The `Flux.Optimise` module has been deprecated in favor of the Optimisers.jl package.
   Now Flux re-exports the optimisers from Optimisers.jl. Most users will be uneffected by this change.
   The module is still available for now, but will be removed in a future release.
+* Further support for Enzyme.jl, via methods of `Flux.gradient`.
+  This still defaults to Zygote, but is no longer `=== Zygote.gradient`.
 
 ## v0.14.22
 * Data movement between devices is now provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl).
diff --git a/docs/src/reference/training/enzyme.md b/docs/src/reference/training/enzyme.md
index 148edc90ee..68ba3d5e8a 100644
--- a/docs/src/reference/training/enzyme.md
+++ b/docs/src/reference/training/enzyme.md
@@ -3,7 +3,7 @@
 
 [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) is a new package for automatic differentiation.
 Like Zygote.jl, calling `gradient(f, x)` causes it to hooks into the compiler and transform code that is executed while calculating `f(x)`, in order to produce code for `∂f/∂x`.
-But it does so much later in the optimisation process (on LLVM instead of Julia's untyped IR).
+But it does so much later in the optimisation process (on LLVM instead of Julia's untyped IR), which you can [read about here](https://proceedings.nips.cc/paper/2020/file/9332c513ef44b682e9347822c2e457ac-Paper.pdf).
 It needs far fewer custom rules than Zygote/ChainRules, and in particular is able to support mutation of arrays.
 
 Flux now builds in support for this, using Enzyme's own `Duplicated` type.
@@ -30,16 +30,32 @@ julia> x1 = randn32(28*28, 1); # fake image
 
 julia> y1 = [i==3 for i in 0:9]; # fake label
 
-julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))
+julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1)) # uses Enzyme
 ((layers = ((weight = Float32[-0.010354728 0.032972857 … -0.0014538406], σ = nothing), nothing),), nothing, nothing)
 ```
 
-The gradient returned here is also stored within `dup_model`, it shares the same arrays. It will be set to zero when you call `gradient` again.
+The gradient returned here is also stored within `dup_model`, it shares the same arrays.
+They will all be set to zero when you call `gradient` again, then replaced with the new values.
+Alternatively, `gradient(f, args...; zero=false)` will add the new gradient to what's already stored.
 
 Writing `Const(x1)` is optional, just plain `x1` is implicitly constant.
 Any set of `Duplicated` and `Const` arguments may appear in any order, so long as there is at least one `Duplicated`.
 
+The gradient `grads_f[1]` can be passed to `update!` as usual.
+But for convenience, you may also use what is stored within `Duplicated`.
+These are equivalent ways to perform an update step:
+
+```julia
+julia> opt_state = Flux.setup(Adam(), model)
+
+julia> ans == Flux.setup(Adam(), dup_model)
+
+julia> Flux.update!(opt_state, model, grads_f[1]) # exactly as for Zygote gradients
+
+julia> Flux.update!(opt_state, dup_model) # equivalent new path, Enzyme only
+```
+
 Instead of using these FLux functions, you can also use Enzyme's own functions directly.
 `Enzyme.gradient` works like this:
 
@@ -63,11 +79,10 @@ julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_s
 
 ```
 
-
 ## Listing
 
 ```@docs
-Flux.gradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...)
-Flux.withgradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...)
-Flux.train!(loss, model::EnzymeCore.Duplicated, data, opt)
+Flux.gradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
+Flux.withgradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
+Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
 ```
diff --git a/src/train.jl b/src/train.jl
index 58b94ef39f..7c6dde483a 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -53,6 +53,11 @@ function setup(rule::Optimisers.AbstractRule, model)
   state
 end
 
+"""
+    opt_state = setup(rule, model::Duplicated) = setup(rule, model.val)
+
+Special method for use with Enzyme.jl, ignores the stored gradient.
+"""
 setup(rule::Optimisers.AbstractRule, model::Duplicated) = setup(rule, model.val)
 
 """