Skip to content

Commit

Permalink
let Flux own the function update! to avoid piracy
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 8, 2024
1 parent 7a1483c commit ca5a20f
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 34 deletions.
30 changes: 2 additions & 28 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module FluxEnzymeExt

using Flux
using Flux: _make_zero!
import Flux.Train: _enzyme_train!, _rule_to_state
import Flux.Train: _enzyme_train!, _rule_to_state, _grad_or_nothing
import Flux.Optimise
import Optimisers
import Enzyme
Expand All @@ -21,11 +21,6 @@ function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=
map(_grad_or_nothing, args)

Check warning on line 21 in ext/FluxEnzymeExt/FluxEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FluxEnzymeExt/FluxEnzymeExt.jl#L16-L21

Added lines #L16 - L21 were not covered by tests
end

# Strip the gradient returned by Enzyme down to a Zygote-like form.
# For a `Duplicated`, walk its shadow `dval` and apply the leaf rule below to it.
function _grad_or_nothing(dup::Duplicated)
    return Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
end

# A `Const` argument contributes no gradient.
_grad_or_nothing(::Const) = nothing

# Leaf rule: keep whatever `Optimisers.isnumeric` accepts, replace anything else by `nothing`.
function _grad_or_nothing(x)
    Optimisers.isnumeric(x) && return x
    return nothing
end

function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
for x in args
zero && x isa Duplicated && _make_zero!(x.dval)
Expand All @@ -52,31 +47,10 @@ function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
if !isfinite(l)
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
end
opt, model2 = Optimisers.update!(opt, model.val, model.dval)
model = Duplicated(model2, model.dval)
Flux.update!(opt, model)

Check warning on line 50 in ext/FluxEnzymeExt/FluxEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FluxEnzymeExt/FluxEnzymeExt.jl#L50

Added line #L50 was not covered by tests

@logprogress Base.haslength(data) ? i/length(data) : nothing
end
end


### Optimisers.update!, piracy, for now!

"""
    Flux.update!(opt_state, model::Duplicated)

Method of `update!` for use with Enzyme, and in particular with
`gradient(loss, Duplicated(model))`.

Because `Duplicated(model)` stores the gradient next to the model, this method can read
the gradient itself: it passes `model.val`, together with the gradient extracted from
`model`, on to the three-argument `Flux.update!`, and then returns `model`.

!!! warning "Experimental"
    Enzyme support like this is new and somewhat experimental.
    This method is piracy, and must either move to Optimisers.jl
    or else Flux should own this function, and fall back to Optimisers.
"""
function Flux.update!(opt_state, model::Duplicated)
    # Gradient in Zygote-like form, read out of the Duplicated wrapper.
    grads = _grad_or_nothing(model)
    Flux.update!(opt_state, model.val, grads)
    return model
end

end # FluxEnzymeExt
5 changes: 5 additions & 0 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ With the keyword `zero=false`, the new gradient will instead be added to what is
!!! warning "Experimental"
Enzyme support like this is new and somewhat experimental.
This method was added in Flux 0.15.
# Example
```
Expand Down Expand Up @@ -169,6 +170,10 @@ Only available when Enzyme is loaded!
Does not at present allow `f` to return a tuple of `(loss, aux)` the way `Zygote.withgradient` does.
!!! warning "Experimental"
Enzyme support like this is new and somewhat experimental.
This method was added in Flux 0.15.
# Example
```
Expand Down
6 changes: 3 additions & 3 deletions src/optimise/train.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using ProgressLogging: @progress, @withprogress, @logprogress
import Zygote: Params, gradient, withgradient

# Add methods to Optimisers.jl's function, so that there is just one Flux.update!
# for both explicit and implicit parameters.
import Optimisers.update!
# Flux 0.13 and 0.14 added methods to Optimisers.jl's function, by means of
# import Optimisers.update!
# but Flux 0.15 owns `update!`: it is declared here as a function with no methods yet.
function update! end

"""
update!(opt, p, g)
Expand Down
103 changes: 100 additions & 3 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ module Train
using LinearAlgebra
using Optimisers: Optimisers
using Functors: fmap, fmapstructure
using ..Flux: Flux # used only in docstring
import ..Flux.Optimise: train!, update!, Optimise # during 0.13, we add methods to the old functions

export setup, train!
using ..Flux: Flux # used only in docstring

export setup, train!, update!

using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote, Params
using EnzymeCore: Duplicated
using EnzymeCore: EnzymeCore, Duplicated, Const

"""
opt_state = setup(rule, model)
Expand Down Expand Up @@ -169,4 +170,100 @@ function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb
train!(loss, model, data, _rule_to_state(model, rule); cb)

Check warning on line 170 in src/train.jl

View check run for this annotation

Codecov / codecov/patch

src/train.jl#L169-L170

Added lines #L169 - L170 were not covered by tests
end

"""
    update!(opt_state, model, grad)

Uses the optimiser and the gradient to change the trainable parameters in the model.
The optimiser's state comes from `setup(rule, model)`, and the gradient from
`grad = gradient(loss, model, args...)[1]`.

This is Flux's version of `Optimisers.update!`: this method forwards its arguments
to that function and returns its result.
Flux's function differs in having a method which accepts a model and gradient packaged
together as `Duplicated(model, grad)`, which is convenient for use with Enzyme.

# Example

```jldoctest
julia> model = Chain(Embedding([1;;2;;3.0;;]), Dense([4;-5.0;;], true, relu))
Chain(
Embedding(3 => 1), # 3 parameters
Dense(1 => 2, relu), # 4 parameters
) # Total: 3 arrays, 7 parameters, 216 bytes.

julia> opt_state = Flux.setup(Momentum(1/9), model)
(layers = ((weight = Leaf(Momentum(0.111111, 0.9), [0.0 0.0 0.0]),), (weight = Leaf(Momentum(0.111111, 0.9), [0.0; 0.0;;]), bias = Leaf(Momentum(0.111111, 0.9), [0.0, 0.0]), σ = ())),)

julia> val, grads = Flux.withgradient(m -> first(m(2)), model)
(val = 8.0, grad = ((layers = ((weight = [0.0 4.0 0.0],), (weight = [2.0; 0.0;;], bias = [1.0, 0.0], σ = nothing)),),))

julia> Flux.update!(opt_state, model, grads[1]);

julia> round.(model(2); digits=3) # has changed! Compare val = 8.0
2-element Vector{Float64}:
5.765
0.0

julia> opt_state # has also changed
(layers = ((weight = Leaf(Momentum(0.111111, 0.9), [0.0 0.444444 0.0]),), (weight = Leaf(Momentum(0.111111, 0.9), [0.222222; 0.0;;]), bias = Leaf(Momentum(0.111111, 0.9), [0.111111, 0.0]), σ = ())),)
```
"""
function update!(opt_state, model, grad)
    # Thin forwarding method: Optimisers.jl does the actual work.
    return Optimisers.update!(opt_state, model, grad)
end

"""
    update!(opt_state, model::Duplicated)

Method of `update!` for use with Enzyme, and in particular with
`gradient(loss, Duplicated(model))`.

Since `Duplicated(model)` stores the gradient, `update!` can read it and update the model
itself; the `Duplicated` wrapper is then returned.
Approximately equivalent to calling `Flux.update!(opt_state, model.val, model.dval)`,
but more careful about shared parameters.

!!! warning "Experimental"
    Enzyme support like this is new and somewhat experimental.
    This method was added in Flux 0.15.

# Example

```julia
julia> using Flux, Enzyme

julia> dup_model = Chain(Embedding([1;;2;;3.0;;]), Dense([4;-5.0;;], true, relu)) |> Duplicated
Duplicated(
Chain(
Embedding(3 => 1), # 3 parameters
Dense(1 => 2, relu), # 4 parameters
),
# norm(∇) ≈ 0.0
) # Total: 3 arrays, 7 parameters, 216 bytes.

julia> dup_model(2)
2-element Vector{Float64}:
8.0
0.0

julia> opt_state = Flux.setup(Momentum(1/9), dup_model)
(layers = ((weight = Leaf(Momentum(0.111111, 0.9), [0.0 0.0 0.0]),), (weight = Leaf(Momentum(0.111111, 0.9), [0.0; 0.0;;]), bias = Leaf(Momentum(0.111111, 0.9), [0.0, 0.0]), σ = ())),)

julia> Flux.gradient(m -> first(m(2)), dup_model); # updates gradient within Duplicated

julia> Flux.update!(opt_state, dup_model);

julia> round.(dup_model(2); digits=3) # has changed! Compare val = 8.0
2-element Vector{Float64}:
5.765
0.0

julia> opt_state # has also changed
(layers = ((weight = Leaf(Momentum(0.111111, 0.9), [0.0 0.444444 0.0]),), (weight = Leaf(Momentum(0.111111, 0.9), [0.222222; 0.0;;]), bias = Leaf(Momentum(0.111111, 0.9), [0.111111, 0.0]), σ = ())),)
```
"""
function update!(opt_state, model::Duplicated)
    # Gradient in Zygote-like form, read out of the Duplicated wrapper.
    grads = _grad_or_nothing(model)
    update!(opt_state, model.val, grads)
    return model
end

# Strip the gradient returned by Enzyme down to a Zygote-like form.
# For a `Duplicated`, walk its shadow `dval` and apply the leaf rule below to it.
function _grad_or_nothing(dup::Duplicated)
    return Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
end

# A `Const` argument contributes no gradient.
_grad_or_nothing(::Const) = nothing

# Leaf rule: keep whatever `Optimisers.isnumeric` accepts, replace anything else by `nothing`.
function _grad_or_nothing(x)
    Optimisers.isnumeric(x) && return x
    return nothing
end

Check warning on line 267 in src/train.jl

View check run for this annotation

Codecov / codecov/patch

src/train.jl#L265-L267

Added lines #L265 - L267 were not covered by tests

end # module Train

0 comments on commit ca5a20f

Please sign in to comment.