Add explicit `train!`, unify `update!`, and auto-translate the two `Adam`s #2082
@@ -0,0 +1,142 @@
module Train

using LinearAlgebra
using Optimisers: Optimisers
using Functors: fmap

import ..Flux.Optimise: train!, update!  # during 0.13, we add methods to the old functions

export setup, train!

using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote, Params

"""
    opt = setup(rule, model)

This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
It differs from `Optimisers.setup` in that it:
* has one extra check for mutability
* has methods which accept Flux's old optimisers, and convert them.

# Example
```jldoctest
julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32);

julia> opt = Flux.setup(Momentum(0.1), model)  # this encodes the optimiser and its state
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ())

julia> x1, y1 = [0.2, -0.3], [0.4];  # use the same data for two steps:

julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y
         sum(abs.(m(x) .- y)) * 100
       end
2-element Vector{Float32}:
 40.1
 38.7

julia> model.bias  # was zero, mutated by Flux.train!
1-element Vector{Float32}:
 10.190001

julia> opt  # mutated by Flux.train!
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ())
```
"""
function setup(rule::Optimisers.AbstractRule, model)
  state = Optimisers.setup(rule, model)
  fmap(model, exclude = Optimisers.isnumeric) do x
    Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`.
                                       If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
  end
  state
end

"""
    train!(loss, model, data, opt)

Uses a `loss` function and training `data` to improve the `model`'s parameters
according to a particular optimisation rule `opt`. Iterates through `data` once,
evaluating `loss(model, d...)` for each `d` in data.

For example, with these definitions...
```
data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple

loss3(m, x, y) = norm(m(x) .- y)        # the model is the first argument

opt = Flux.setup(Adam(), model)         # explicit setup of optimiser momenta
```
...calling `Flux.train!(loss3, model, data, opt)` runs a loop much like this,
using Zygote's "explicit" mode for the gradient:
```
for d in data
    ∂L∂m = gradient(loss3, model, d...)[1]
    update!(opt, model, ∂L∂m)           # method for "explicit" gradient
end
```
You can also write this loop yourself, if you need more flexibility.
For this reason `train!` is not highly extensible.
It adds only a few features to the loop above:

* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.

* Return a vector containing the value of the loss function at each datapoint.

* Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl).

Note that the built-in loss functions accept 3 arguments, allowing for instance
`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.

!!! note
    This method has significant changes from the one in Flux ≤ 0.13:
    * It now takes the `model` itself, not the result of [`Flux.params`](@ref).
      (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
    * Instead of `loss` being a function which typically accepts two arguments
      (the input `x` and expected output `y` from each element of `data`)
      now it should typically accept three, the first of which is the `model` itself.

Review comment: We never restricted specifically to 2 arguments (and we don't seem to restrict to 3 now either). I think the change is ...
Reply: I think the distinction is important, since for things like language models, ...
Reply: Yeah, I don't know how to word this. I think all the doc examples have 2 and need 3, so wrote "typically". Maybe it needs more explanation.
Reply: It's weirder than that, because if ...
Reply: I've removed the "3 arguments" bit, to say just "Instead of ..."

    * `data` must iterate tuples, otherwise you get an error.
      (Previously non-tuple types were not splatted into the loss.
      Pass in `((d,) for d in data)` to simulate this.)
    * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
      such as `Adam()` without this step should give you a warning.
    * Callback functions are not supported.
      But any code can be included in the above `for` loop.
"""

Review comment: Does Documenter need ...
Reply: Maybe. I just tried, and I think neither this new ...
Reply: Doc update is #2114, will need to be rebased on this & then checked.

function train!(loss, model, data, opt; cb = nothing)
  isnothing(cb) || error("""train! does not support callback functions.
                            For more control use a loop with `gradient` and `update!`.""")
  losses = Float32[]
  @withprogress for (i, d) in enumerate(data)
    d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
                            Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
    # l, (g, _...) = explicit_withgradient(loss, model, d...)  # BTW this un-thunks gradient w.r.t. data. Could avoid that
    l, (g, _...) = explicit_withgradient(m -> loss(m, d...), model)
    isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
    opt, model = Optimisers.update!(opt, model, g)
    push!(losses, l)
    @logprogress Base.haslength(data) ? i/length(data) : nothing
  end
  return losses  # Not entirely sure returning losses is a good idea, as it may conflict with later returning immutable models à la Optimisers.jl
end

# This method lets you use Optimisers.Descent() without setup, when there is no state
function train!(loss, model, data, rule::Optimisers.AbstractRule)
  train!(loss, model, data, _rule_to_state(model, rule))
end

function _rule_to_state(model, rule::Optimisers.AbstractRule)
  state = setup(rule, model)
  @gensym warn_id
  name = typeof(rule).name.name
  fmap(state, exclude = x -> x isa Optimisers.Leaf) do leaf
    leaf.state isa Nothing || @warn """Optimiser $name has state which will be discarded after `train!` finishes.
                                       Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id
    leaf
  end
  state
end

explicit_withgradient(f, args...) = Zygote.withgradient(f, args...)  # can overload this to use e.g. Yota / Diffractor

end # module
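The docstring above notes that the `train!` loop can also be written by hand with `gradient` and `update!`. Below is a minimal sketch of that hand-written loop, assuming a Flux version that includes the `Flux.setup` added here; the model, data, loss and learning rate are invented purely for illustration.

```julia
using Flux, Optimisers, Zygote

model = Dense(2 => 1)                       # any mutable Flux model (illustrative)
data  = [(randn(Float32, 2), randn(Float32, 1)) for _ in 1:100]
loss3(m, x, y) = Flux.Losses.mse(m(x), y)   # the model is the first argument, as in the docstring

opt = Flux.setup(Adam(0.01), model)         # optimiser state tree, as produced by `setup` above

for (x, y) in data
  ∂L∂m = Zygote.gradient(m -> loss3(m, x, y), model)[1]  # explicit gradient w.r.t. the model
  Optimisers.update!(opt, model, ∂L∂m)                   # mutates both `opt` and `model`
end
```

This is essentially what `train!` does, minus the tuple check, the NaN guard and the progress bar listed in the docstring.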
@@ -0,0 +1,39 @@
using Flux
# using Flux.Train
import Optimisers

using Test
using Random

@testset "Explicit Flux.train! with Zygote" begin
  Random.seed!(84)
  w = randn(10, 10)
  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
  @testset for rule in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(),
                        NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(),
                        Nesterov(), RMSProp(), Momentum()]

    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
    @test loss(model, rand(10, 10)) > 1

    opt = Flux.setup(rule, model)
    Flux.train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
    @test loss(model, rand(10, 10)) < 0.01
  end

  # Test direct use of Optimisers.jl rule, only really OK for `Descent`:
  @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()]
    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
    @test loss(model, rand(10, 10)) > 1
    Flux.train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
    @test loss(model, rand(10, 10)) < 0.01
  end
end

@testset "Explicit Flux.train! features" begin
  # Test errors from wrong kind of iterator
  # Test NaN / Inf early stop
  # Test that loss is returned
end
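The last `@testset` above is only a placeholder listing three behaviours that the `train!` docstring promises (an error on non-tuple data, a `DomainError` on a non-finite loss, and a returned vector of losses). A possible sketch of such tests might look like the following; the toy model, loss and tolerances are invented, and it assumes `Flux.setup`'s documented conversion of old-style optimisers.

```julia
using Flux, Test

@testset "Explicit Flux.train! features (sketch)" begin
  loss(m, x) = sum(abs2, m.weight * x .+ m.bias)
  model = (weight = rand(2, 2), bias = zeros(2))
  opt = Flux.setup(Descent(0.1), model)

  # data must be an iterator of tuples
  @test_throws ErrorException Flux.train!(loss, model, [rand(2) for _ in 1:3], opt)

  # a NaN (or Inf) loss stops training with a DomainError
  badloss(m, x) = NaN
  @test_throws DomainError Flux.train!(badloss, model, [(rand(2),)], opt)

  # the value of the loss at each datapoint is returned
  losses = Flux.train!(loss, model, [(rand(2),) for _ in 1:3], opt)
  @test losses isa Vector && length(losses) == 3
end
```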
Review comment: I am hesitant to create a function in the Flux namespace that clashes with Optimisers.jl. It is hard enough already to keep track of where "Flux functions" actually come from. Why not extend `Optimisers.setup` for `Flux.Optimise.AbstractOptimiser` and remove the mutability check? I am guessing this is to guard against immutable models, since `train!` does not return the model?

Reply: I thought I did (in addition) extend `Optimisers.setup` that way, but in fact I made an error. Can change it. Yes, the guard against immutable models is the point. All Flux models are assumed mutable right now, and this just makes the check explicit. I don't love the collision, but neither name is exported, and the consequences of using the wrong one are (I think) slight. You lose the safety check, but any model which does work with `Flux.setup` will also work correctly with `Optimisers.setup`. We can of course make `train!` return the model. But this isn't enough, as you also have to re-do your code to keep, not discard, the returned model. It's a bit awkward.

Reply: BTW, I picture Flux 0.14 deleting Flux.Optimise and exporting Adam etc. from Optimisers.jl. Code that goes `using Flux; opt = Flux.setup(Adam(), model); train!(loss, model, data, opt)` will work equally well on 0.13 and 0.14. You don't have to load Optimisers.jl yourself at all, and all will be safe. If you do load Optimisers.jl yourself and use its functions, then you have opted into the `model, _ = update!(opt, model, grad)` thing where you are supposed to get back the new model.
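As a concrete illustration of the two styles described in that comment, here is a hedged sketch; the model, data and hyperparameters are invented, and it assumes the `Flux.setup` added by this PR.

```julia
# Style 1: plain Flux, never loading Optimisers.jl directly.
using Flux

model = Dense(3 => 1)
data  = [(randn(Float32, 3), randn(Float32, 1))]
opt   = Flux.setup(Adam(), model)          # wraps the rule and checks mutability
Flux.train!((m, x, y) -> Flux.Losses.mse(m(x), y), model, data, opt)  # mutates `model` in place

# Style 2: opting into Optimisers.jl's own API, where the new model is returned.
import Optimisers, Zygote
x, y = data[1]
state = Optimisers.setup(Optimisers.Adam(), model)
grad  = Zygote.gradient(m -> Flux.Losses.mse(m(x), y), model)[1]
state, model = Optimisers.update!(state, model, grad)  # you are supposed to capture the new model
```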
Reply: I guess the question is, should this be the case? I do think there should be a patch release of v0.13.X that accepts `Optimisers.Adam`, etc. and upgrades `Flux.Optimise.Adam` with a warning. This will allow `train!` to work like quoted above too. But in v0.14, I was expecting that we force people to start using `model = train!(...)`. Previously, `train!` and `update!` worked similarly (mutating optimisers and model), and we could say `train!` is "just" a loop. Diverging how they work seems worse than a minor code refactor on a breaking release. Especially given people will get warnings from before.

Reply: Yes, that's what this PR makes. But how? You want `train!` not to mutate, so that everyone will wonder why their model isn't training, and why it's called `train!`? Or worse, to make it return a copy and write NaN into the old model to trash it? These seem awful to me, deliberate breakage for which we gain nothing.

Reply: Okay, let's forget my suggestions about the warnings, as I agree about the orthogonality. One option here is to return the model from `train!`, which would allow immutable models to work. Mutable models still don't need to capture the return value to work. So we don't force people to do `model = train!(...)`. And we still have `Flux.setup` here to work in the reverse direction: warn if any leaf is immutable.

Reply: But I agree that if we are adding `Flux.setup`, then this seems like something that can be revisited later too.

Reply: One thing we should consider is changing Optimisers. Maybe its present function should be called `update!!`, as that's something of a convention for "tries to mutate but may fail". Then in [email protected], we can introduce a new function `update!` which demands mutability, and fails on immutable parameters. And that's the one we identify with Flux's function. That's now FluxML/Optimisers.jl#116.
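For readers unfamiliar with the `!!` suffix mentioned here, the sketch below illustrates the convention with a hypothetical `addone!` / `addone!!` pair; neither of these functions, nor the proposed `Optimisers.update!!`, exists at the time of this discussion.

```julia
addone!(x) = (x .+= 1; x)      # `!`  : promises in-place mutation, fails on immutable arrays

function addone!!(x)           # `!!` : "tries to mutate but may fail", returning the result either way
  if Base.ismutable(x)         # crude stand-in for a mutability check, for illustration only
    x .+= 1
    return x
  else
    return x .+ 1              # fall back to returning a new object
  end
end
```

The check actually used in `setup` above is `Optimisers.maywrite`, rather than `Base.ismutable`.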
Reply: I think `train!` should guarantee mutation; this is the widespread Julian convention. We can have a `train` and a `train!!` for non-mutating and mutate-if-possible versions. In that case, whether it returns the model or not hasn't great relevance. Base functions such as `replace!` and `map!` return the mutated input. Maybe just for REPL usage convenience? In our case returning the model in the REPL would just be an annoyance, I guess.

Reply: Yes, the present one returns `nothing`, which as you say means you don't get a screenful of stuff, and also serves as a reminder that it mutates the model. I think I'd be happiest if `update!` did the same. I mean `Flux.update!` does now, but after unifying with `Optimisers.update!` too. I understand the attraction of `state, model = update(state, model, grad)`, but IMO it's a pain to remember the order, and `update!` is now in a weird place where it does guarantee to mutate the `state`, but not the model.