From ca5a20f24bcd60560f4961da3646b8f44852fb31 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 5 Nov 2024 20:38:46 -0500
Subject: [PATCH] let Flux own the function update! to avoid piracy

---
 ext/FluxEnzymeExt/FluxEnzymeExt.jl |  30 +--------
 src/gradient.jl                    |   5 ++
 src/optimise/train.jl              |   6 +-
 src/train.jl                       | 103 ++++++++++++++++++++++++++++-
 4 files changed, 110 insertions(+), 34 deletions(-)

diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl
index 3c807b02e2..eff45bb95c 100644
--- a/ext/FluxEnzymeExt/FluxEnzymeExt.jl
+++ b/ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -2,7 +2,7 @@ module FluxEnzymeExt

 using Flux
 using Flux: _make_zero!
-import Flux.Train: _enzyme_train!, _rule_to_state
+import Flux.Train: _enzyme_train!, _rule_to_state, _grad_or_nothing
 import Flux.Optimise
 import Optimisers
 import Enzyme
@@ -21,11 +21,6 @@ function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=
   map(_grad_or_nothing, args)
 end

-# This function strips the returned gradient to be Zygote-like:
-_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
-_grad_or_nothing(::Const) = nothing
-_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing
-
 function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
   for x in args
     zero && x isa Duplicated && _make_zero!(x.dval)
@@ -52,31 +47,10 @@ function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)

       if !isfinite(l)
         throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
       end
-      opt, model2 = Optimisers.update!(opt, model.val, model.dval)
-      model = Duplicated(model2, model.dval)
+      Flux.update!(opt, model)

       @logprogress Base.haslength(data) ? i/length(data) : nothing
     end
   end
-
-### Optimisers.update!, piracy, for now!
-
-"""
-    Flux.update!(opt_state, model::Duplicated)
-
-Method of `update!` for use with Enzyme, and in particular with `gradient(loss, Duplicated(model))`.
-Since `Duplicated(model)` stores the gradient, `update!` can read it & update the model itself,
-by calling `Flux.update!(opt_state, model.val, model.dval)`.
-
-!!! warning "Experimental"
-    Enzyme support like this is new and somewhat experimental.
-    This method is piracy, and must either move to Optimisers.jl
-    or else Flux should own this function, and fall back to Optimisers.
-"""
-function Flux.update!(opt_state, model::Duplicated)
-  Flux.update!(opt_state, model.val, _grad_or_nothing(model))
-  model
-end
-
 end # FluxEnzymeExt
diff --git a/src/gradient.jl b/src/gradient.jl
index 0c18dccc64..71d7ea590f 100644
--- a/src/gradient.jl
+++ b/src/gradient.jl
@@ -62,6 +62,7 @@ With the keyword `zero=false`, the new gradient will instead be added to what is

 !!! warning "Experimental"
     Enzyme support like this is new and somewhat experimental.
+    This method was added in Flux 0.15.

 # Example
 ```
@@ -169,6 +170,10 @@ Only available when Enzyme is loaded!
 Does not at present allow `f` to return a tuple of `(loss, aux)` the way `Zygote.withgradient` does.

+!!! warning "Experimental"
+    Enzyme support like this is new and somewhat experimental.
+    This method was added in Flux 0.15.
+
 # Example
 ```
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 7bd3f9b277..e6dd10a5d6 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -1,9 +1,9 @@
 using ProgressLogging: @progress, @withprogress, @logprogress
 import Zygote: Params, gradient, withgradient

-# Add methods to Optimisers.jl's function, so that there is just one Flux.update!
-# for both explicit and implicit parameters.
-import Optimisers.update!
+# Flux 0.13 and 0.14 used Optimisers.jl's function, but Flux 0.15 owns it:
+# import Optimisers.update!
+function update! end

 """
     update!(opt, p, g)
diff --git a/src/train.jl b/src/train.jl
index 7c6dde483a..e21a00cbe6 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -3,14 +3,15 @@ module Train

 using LinearAlgebra
 using Optimisers: Optimisers
 using Functors: fmap, fmapstructure
-using ..Flux: Flux # used only in docstring
 import ..Flux.Optimise: train!, update!, Optimise  # during 0.13, we add methods to the old functions
-export setup, train!
+using ..Flux: Flux # used only in docstring
+
+export setup, train!, update!

 using ProgressLogging: @progress, @withprogress, @logprogress
 using Zygote: Zygote, Params
-using EnzymeCore: Duplicated
+using EnzymeCore: EnzymeCore, Duplicated, Const

 """
     opt_state = setup(rule, model)
@@ -169,4 +170,100 @@ function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb
   train!(loss, model, data, _rule_to_state(model, rule); cb)
 end

+"""
+    update!(opt_state, model, grad)
+
+Uses the optimiser and the gradient to change the trainable parameters in the model.
+The optimiser's state comes from `setup(rule, model)`, and the gradient from
+`grad = gradient(loss, model, args...)[1]`.
+
+This is a version of `Optimisers.update!`, to which the three-argument method simply forwards.
+It differs in having a method which accepts a model and gradient packaged together as
+`Duplicated(model, grad)`, which is convenient for use with Enzyme.
+
+# Example
+
+```jldoctest
+julia> model = Chain(Embedding([1;;2;;3.0;;]), Dense([4;-5.0;;], true, relu))
+Chain(
+  Embedding(3 => 1),                    # 3 parameters
+  Dense(1 => 2, relu),                  # 4 parameters
+)                   # Total: 3 arrays, 7 parameters, 216 bytes.
+
+julia> opt_state = Flux.setup(Momentum(1/9), model)
+(layers = ((weight = Leaf(Momentum(0.111111, 0.9), [0.0 0.0 0.0]),), (weight = Leaf(Momentum(0.111111, 0.9), [0.0; 0.0;;]), bias = Leaf(Momentum(0.111111, 0.9), [0.0, 0.0]), σ = ())),)
+
+julia> val, grads = Flux.withgradient(m -> first(m(2)), model)
+(val = 8.0, grad = ((layers = ((weight = [0.0 4.0 0.0],), (weight = [2.0; 0.0;;], bias = [1.0, 0.0], σ = nothing)),),))
+
+julia> Flux.update!(opt_state, model, grads[1]);
+
+julia> round.(model(2); digits=3)  # has changed! Compare val = 8.0
+2-element Vector{Float64}:
+ 5.765
+ 0.0
+
+julia> opt_state  # has also changed
+(layers = ((weight = Leaf(Momentum(0.111111, 0.9), [0.0 0.444444 0.0]),), (weight = Leaf(Momentum(0.111111, 0.9), [0.222222; 0.0;;]), bias = Leaf(Momentum(0.111111, 0.9), [0.111111, 0.0]), σ = ())),)
+```
+
+"""
+update!(opt_state, model, grad) = Optimisers.update!(opt_state, model, grad)
+
+"""
+    update!(opt_state, model::Duplicated)
+
+Method of `update!` for use with Enzyme, and in particular with `gradient(loss, Duplicated(model))`.
+Since `Duplicated(model)` stores the gradient, `update!` can read it & update the model itself.
+Approximately equivalent to calling `Flux.update!(opt_state, model.val, model.dval)`,
+but more careful about shared parameters.
+
+!!! warning "Experimental"
warning "Experimental" + Enzyme support like this is new and somewhat experimental. + This method was added in Flux 0.15. + +# Example + +```julia +julia> using Flux, Enzyme + +julia> dup_model = Chain(Embedding([1;;2;;3.0;;]), Dense([4;-5.0;;], true, relu)) |> Duplicated +Duplicated( + Chain( + Embedding(3 => 1), # 3 parameters + Dense(1 => 2, relu), # 4 parameters + ), + # norm(∇) ≈ 0.0 +) # Total: 3 arrays, 7 parameters, 216 bytes. + +julia> dup_model(2) +2-element Vector{Float64}: + 8.0 + 0.0 + +julia> opt_state = Flux.setup(Momentum(1/9), dup_model) +(layers = ((weight = Leaf(Momentum(0.111111, 0.9), [0.0 0.0 0.0]),), (weight = Leaf(Momentum(0.111111, 0.9), [0.0; 0.0;;]), bias = Leaf(Momentum(0.111111, 0.9), [0.0, 0.0]), σ = ())),) + +julia> Flux.gradient(m -> first(m(2)), dup_model); # updates gradient within Duplicated + +julia> Flux.update!(opt_state, dup_model); + +julia> round.(dup_model(2); digits=3) # has changed! Compare val = 8.0 +2-element Vector{Float64}: + 5.765 + 0.0 + +julia> opt_state # has also changed +(layers = ((weight = Leaf(Momentum(0.111111, 0.9), [0.0 0.444444 0.0]),), (weight = Leaf(Momentum(0.111111, 0.9), [0.222222; 0.0;;]), bias = Leaf(Momentum(0.111111, 0.9), [0.111111, 0.0]), σ = ())),) +``` +""" +function update!(opt_state, model::Duplicated) + update!(opt_state, model.val, _grad_or_nothing(model)) + model +end + +# This function strips the returned gradient to be Zygote-like: +_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing) +_grad_or_nothing(::Const) = nothing +_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing + end # module Train