tidy up
mcabbott committed Nov 10, 2022
1 parent 333e26d commit 3291917
Showing 2 changed files with 43 additions and 34 deletions.
6 changes: 1 addition & 5 deletions Project.toml
@@ -39,8 +39,6 @@ ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
StatsBase = "0.33"
Tracker = "0.2.22"
Yota = "0.8.1"
Zygote = "0.6.34"
julia = "1.6"

@@ -50,9 +48,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Yota = "cd998857-8626-517d-b929-70ad188a48f0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker", "Yota"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
71 changes: 42 additions & 29 deletions src/train.jl
@@ -6,36 +6,41 @@ using Functors: fmap

import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions

export setup, @train_autodiff
export setup, train!

using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote, Params

"""
opt = setup(rule, model)
This is a version of `Optimisers.setup`, and is the first step before using `train!`.
This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
It differs from `Optimisers.setup` in that it:
* has one extra check for mutability
* has methods which accept Flux's old optimisers, and convert them.
# Example
```jldoctest
julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32);
julia> opt = Flux.setup(Momentum(0.11), model)
(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0]), σ = ())
julia> opt = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ())
julia> Flux.train!(model, opt) do m # 3-arg train!, for one data point (x = [0.2, -0.3], y = [0.4])
sum(m([0.2, -0.3]) .- [0.4]) * 100
julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps:
julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y
sum(abs.(m(x) .- y)) * 100
end
-40.1
2-element Vector{Float32}:
40.1
38.7
julia> model.bias # was zero, mutated by Flux.train!
1-element Vector{Float32}:
-0.11
10.190001
julia> opt # mutated by Flux.train!
(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.022 -0.033]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.11]), σ = ())
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ())
```
"""
function setup(rule::Optimisers.AbstractRule, model)
@@ -51,18 +56,8 @@ end
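The `setup` docstring above notes that it also accepts Flux's old optimisers and converts them. A minimal sketch of that path, assuming the old-style `Flux.Optimise.Momentum` rule; the model and learning rate here are purely illustrative:

```julia
using Flux

model = Dense(2 => 1)                    # small illustrative model
old_rule = Flux.Optimise.Momentum(0.1)   # old-style optimiser from Flux.Optimise
state = Flux.setup(old_rule, model)      # converted to the equivalent new rule,
                                         # then the per-parameter state tree is built
```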
train!(loss, model, data, opt)
Uses a `loss` function and training `data` to improve the `model`'s parameters
according to a particular optimisation rule `opt`.
!!! note
This method has significant changes from the one in Flux ≤ 0.13:
* It now takes the `model` itself, not the result of [`Flux.params`](@ref).
(This is to move away from Zygote's implicit parameter handling.)
* Instead of `loss` being a function which typically accepts two arguments
(the input `x` and expected output `y` from each element of `data`)
now it should typically accept three, the first of which is the `model` itself.
* `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`.
* `opt` should be the result of [`Flux.setup`](@ref), it will warn you if not.
* Callback functions are not supported.
according to a particular optimisation rule `opt`. Iterates through `data` once,
evaluating `loss(model, d...)` for each `d` in data.
For example, with these definitions...
```
@@ -72,15 +67,17 @@ loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument
opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta
```
...calling `train!(loss3, model, data, opt)` runs a loop much like this:
...calling `Flux.train!(loss3, model, data, opt)` runs a loop much like this,
using Zygote's "explicit" mode for the gradient:
```
for d in data
∂L∂m = Zygote.gradient(loss3, model, d...)[1]
Optimisers.update!(opt, model, ∂L∂m)
∂L∂m = gradient(loss3, model, d...)[1]
update!(opt, model, ∂L∂m) # method for "explicit" gradient
end
```
You can also write this loop yourself, if you need more flexibility.
Besides the loop, `train!` will:
For this reason `train!` is not highly extensible.
It adds only a few features to the loop above:
* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
@@ -91,20 +88,36 @@ Besides the loop, `train!` will:
Note that the built-in loss functions accept 3 arguments, allowing for instance
`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
Note that callback functions are not supported. But arbitrary code can be inserted into the loop.
!!! note
This method has significant changes from the one in Flux ≤ 0.13:
* It now takes the `model` itself, not the result of [`Flux.params`](@ref).
(This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
* Instead of `loss` being a function which typically accepts two arguments
(the input `x` and expected output `y` from each element of `data`)
now it should typically accept three, the first of which is the `model` itself.
* `data` must iterate tuples, otherwise you get an error.
(Previously non-tuple types were not splatted into the loss.
Pass in `((d,) for d in data)` to simulate this.)
* `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
such as `Adam()` without this step should give you a warning.
* Callback functions are not supported.
But any code can be included in the above `for` loop.
"""
function train!(loss, model, data, opt)
function train!(loss, model, data, opt; cb = nothing)
isnothing(cb) || error("""train! does not support callback functions.
For more control use a loop with `gradient` and `update!`.""")
losses = Float32[]
@withprogress for (i,d) in enumerate(data)
d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
l, (g, _...) = explicit_withgradient(loss, model, d...)
# l, (g, _...) = explicit_withgradient(loss, model, d...) # BTW this un-thunks gradient w.r.t. data. Could avoid that
l, (g, _...) = explicit_withgradient(m -> loss(m, d...), model)
isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
opt, model = Optimisers.update!(opt, model, g)
push!(losses, l)
@logprogress Base.haslength(data) ? i/length(data) : nothing
end
return losses # Not entirely sure returning losses is a good idea
return losses # Not entirely sure returning losses is a good idea, as it may conflict with later returning immutable models à la Optimisers.jl
end
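The tuple requirement and the `((d,) for d in data)` workaround mentioned in the docstring can be sketched like this; the model, loss, and data below are made up for illustration:

```julia
using Flux

model = Dense(3 => 3)
loss1(m, x) = sum(abs2, m(x) .- x)        # a loss taking a single data argument

xs = [rand(Float32, 3) for _ in 1:4]      # data that does NOT iterate tuples
opt = Flux.setup(Descent(0.1), model)

# Wrapping each element in a 1-tuple satisfies the `d isa Tuple` check,
# and `loss(model, d...)` then splats back to `loss1(model, x)`:
Flux.train!(loss1, model, ((x,) for x in xs), opt)
```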

# This method lets you use Optimisers.Descent() without setup, when there is no state
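Putting the pieces together, a minimal end-to-end sketch of the workflow these docstrings describe, both via `train!` and as the hand-written loop it expands to; the model, data, and learning rate are made up, and `Optimisers` is assumed to be loaded alongside Flux:

```julia
using Flux, Optimisers

# Illustrative model and data: a vector of (input, target) tuples.
model = Dense(2 => 1)
data  = [([0.2f0, -0.3f0], [0.4f0]), ([0.1f0, 0.5f0], [0.6f0])]

loss3(m, x, y) = Flux.Losses.mse(m(x), y)   # the model is the first argument

opt = Flux.setup(Adam(0.01), model)         # encodes the optimiser and its state
losses = Flux.train!(loss3, model, data, opt)

# The same updates, written out as an explicit loop for more control:
for (x, y) in data
    grad = Flux.gradient(m -> loss3(m, x, y), model)[1]
    Optimisers.update!(opt, model, grad)
end
```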
