Add explicit train!, unify update!, and auto-translate the two Adams #2082

Merged (18 commits) on Nov 20, 2022
Changes from 9 commits
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

## v0.13.7
* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)
* New method of `train!` using Zygote's "explicit" mode. Part of a move away from "implicit" `Params`.

## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
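For orientation, the explicit-mode `train!` mentioned in the v0.13.7 entry above is called roughly as follows. This is a minimal sketch, not part of the changelog: the model, data, and learning rate are placeholder choices.

```julia
using Flux

# Toy model and data; any Flux model and any iterator of tuples would do.
model = Dense(1 => 1)
data = [(rand(Float32, 1, 10), rand(Float32, 1, 10)) for _ in 1:3]

# Explicit style: the loss takes the model as its first argument,
# and the optimiser state comes from Flux.setup rather than Flux.params.
loss(m, x, y) = Flux.Losses.mse(m(x), y)
opt = Flux.setup(Descent(0.1), model)
Flux.train!(loss, model, data, opt)
```
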
2 changes: 1 addition & 1 deletion Project.toml
@@ -34,7 +34,7 @@ MacroTools = "0.5"
NNlib = "0.8.9"
NNlibCUDA = "0.2.4"
OneHotArrays = "0.1, 0.2"
Optimisers = "0.2.1"
Optimisers = "0.2.10"
ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
4 changes: 2 additions & 2 deletions docs/make.jl
@@ -1,10 +1,10 @@
using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore
using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Statistics


DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)

makedocs(
modules = [Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Base],
modules = [Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Base, Statistics],
doctest = false,
sitename = "Flux",
# strict = [:cross_references,],
16 changes: 8 additions & 8 deletions docs/src/models/overview.md
@@ -17,7 +17,7 @@ This example will predict the output of the function `4x + 2`. Making such predi

First, import `Flux` and define the function we want to simulate:

```jldoctest overview
```jldoctest overview setup = :(using Statistics)
julia> using Flux

julia> actual(x) = 4x + 2
@@ -77,13 +77,13 @@ julia> predict(x_train)
In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> loss(x, y) = Flux.Losses.mse(predict(x), y);
julia> loss(model, x, y) = Statistics.mean(abs2.(model(x) .- y));

julia> loss(x_train, y_train)
julia> loss(predict, x_train, y_train)
122.64734f0
```

More accurate predictions will yield a lower loss. You can write your own loss functions or rely on those already provided by Flux. This loss function is called [mean squared error](https://www.statisticshowto.com/probability-and-statistics/statistics-definitions/mean-squared-error/). Flux works by iteratively reducing the loss through *training*.
More accurate predictions will yield a lower loss. You can write your own loss functions or rely on those already provided by Flux. This loss function is called [mean squared error](https://www.statisticshowto.com/probability-and-statistics/statistics-definitions/mean-squared-error/) (and built-in as [`mse`](@ref Flux.Losses.mse)). Flux works by iteratively reducing the loss through *training*.

## 3. Improve the Prediction

@@ -131,7 +131,7 @@ The first parameter is the weight and the second is the bias. Flux will adjust p
This optimiser implements the classic gradient descent strategy. Now improve the parameters of the model with a call to [`Flux.train!`](@ref) like this:

```jldoctest overview
julia> train!(loss, parameters, data, opt)
julia> train!(loss, predict, data, opt)
```

And check the loss:
@@ -156,10 +156,10 @@ In the previous section, we made a single call to `train!` which iterates over t

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> for epoch in 1:200
train!(loss, parameters, data, opt)
train!(loss, predict, data, opt)
end

julia> loss(x_train, y_train)
julia> loss(predict, x_train, y_train)
0.00339581f0

julia> parameters
@@ -188,7 +188,7 @@ First, we gathered real-world data into the variables `x_train`, `y_train`, `x_t

Then, we built a single input, single output predictive model, `predict = Dense(1 => 1)`. The initial predictions weren't accurate, because we had not trained the model yet.

After building the model, we trained it with `train!(loss, parameters, data, opt)`. The loss function is first, followed by the `parameters` holding the weights and biases of the model, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process.
After building the model, we trained it with `train!(loss, predict, data, opt)`. The loss function is first, followed by the model itself, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process.

After we trained the model, we verified it with the test data to verify the results.

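As a side note, the hand-written loss in this guide is equivalent to calling the built-in `mse` on the model's output. A minimal sketch, assuming definitions in the spirit of the guide (`predict`, `x_train`, and `y_train` here are stand-ins rather than the exact doctest values):

```julia
using Flux

predict = Dense(1 => 1)                        # single-input, single-output model
x_train = hcat(0f0, 1f0, 2f0, 3f0, 4f0, 5f0)   # 1×6 matrix of inputs
y_train = 4 .* x_train .+ 2                    # targets from the function being learned

# Same as loss(model, x, y) = Statistics.mean(abs2.(model(x) .- y))
loss(model, x, y) = Flux.Losses.mse(model(x), y)

loss(predict, x_train, y_train)   # a Float32 scalar, large before training
```
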
4 changes: 4 additions & 0 deletions src/Flux.jl
@@ -34,6 +34,10 @@ export Descent, Adam, Momentum, Nesterov, RMSProp,
AdamW, RAdam, AdaBelief, InvDecay, ExpDecay,
WeightDecay, ClipValue, ClipNorm

include("train.jl")
using .Train
# using .Train: setup, @train_autodiff

using CUDA
const use_cuda = Ref{Union{Nothing,Bool}}(nothing)

65 changes: 65 additions & 0 deletions src/deprecations.jl
@@ -82,3 +82,68 @@ Base.@deprecate_binding ADAGrad AdaGrad
Base.@deprecate_binding ADADelta AdaDelta

@deprecate rng_from_array() default_rng_value()

#=
# Valid method in Optimise, old implicit style, is:
train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())

# Valid methods in Train, new explicit style, are:
train!(loss, model, data, opt) # preferred
train!(loss, model, data, opt::Optimisers.AbstractRule) # if you forget setup

# Provide friendly errors for what happens if you mix these up:
=#
import .Optimise: train!

train!(loss, ps::Params, data, opt) = error(
  """can't mix implicit Params with explicit state!
  To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
  But better to use the new explicit style, in which `m` itself is the 2nd argument.
  """)

train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error(
  """can't mix implicit Params with explicit rule from Optimisers.jl
  To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
  But better to use the new explicit style, in which `m` itself is the 2nd argument.
  """)

train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))

# Next, to use the new `setup` with the still-exported old-style `Adam` etc:
import .Train: setup
setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
# ... and allow accidental use of `Optimisers.setup` to do the same:
Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)

for T in [:Descent, :Adam, :Momentum, :Nesterov,
:AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
# :InvDecay, :ExpDecay,
]
@eval function _old_to_new(rule::$T)
args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T))
Optimisers.$T(args...)
end
end
_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now
_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
const ClipGrad = Optimise.ClipValue
_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred

_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")

# v0.14 deprecations

# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
# Base.@deprecate_binding Optimiser OptimiserChain
# Base.@deprecate_binding ClipValue ClipGrad

# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
# Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
# where `loss_mxy` accepts the model as its first argument.
# """
# ))
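
To illustrate what this translation layer enables, here is a sketch (the `Dense` model and learning rate are arbitrary): an old-style `Flux.Optimise.Adam` is passed to `Flux.setup`, which converts it to the corresponding Optimisers.jl rule via `_old_to_new`, so the explicit `update!` path works.

```julia
using Flux
import Optimisers

model = Dense(2 => 1)

# Old-style rule, translated inside setup to Optimisers.Adam:
opt = Flux.setup(Flux.Optimise.Adam(0.01), model)

# The resulting state tree works with the explicit update! from Optimisers.jl:
grad = Flux.gradient(m -> sum(abs2, m([0.5f0, -0.5f0])), model)[1]
Optimisers.update!(opt, model, grad)
```

Mixing the styles, e.g. passing `Flux.params(model)` together with this new-style `opt` to `train!`, hits the friendly errors defined above.
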
13 changes: 12 additions & 1 deletion src/optimise/train.jl
@@ -1,16 +1,23 @@
using ProgressLogging: @progress, @withprogress, @logprogress
import Zygote: Params, gradient, withgradient

# Add methods to Optimisers.jl's function, so that there is just one Flux.update!
# for both explicit and implicit parameters.
import Optimisers.update!

"""
update!(opt, p, g)
update!(opt, ps::Params, gs)

Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
according to optimizer `opt::AbstractOptimiser` and the gradients `gs` (the gradient `g`).

As a result, the parameters are mutated and the optimizer's internal state may change.
The gradient could be mutated as well.

!!! note
This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14.
The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain.
"""
function update!(opt::AbstractOptimiser, x, x̄)
x̄r = copyto!(similar(x̄), x̄) # Flux.Optimise assumes it can mutate the gradient. This is not
@@ -88,6 +95,10 @@ batchmemaybe(x::Tuple) = x
Uses a `loss` function and training `data` to improve the
model's parameters according to a particular optimisation rule `opt`.

!!! note
This method with implicit `Params` will be removed from Flux 0.14.
It should be replaced with the explicit method `train!(loss, model, data, opt)`.

For each `d in data`, first the gradient of the `loss` is computed like this:
```
gradient(() -> loss(d...), pars) # if d isa Tuple
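For comparison, the implicit-parameter pattern documented in this file (and slated for removal in Flux 0.14) looks roughly like this sketch; the model, data, and rule are placeholders:

```julia
using Flux

model = Dense(1 => 1)
x, y = rand(Float32, 1, 8), rand(Float32, 1, 8)

ps = Flux.params(model)            # implicit Zygote.Params
opt = Flux.Optimise.Descent(0.1)   # old-style AbstractOptimiser

# Gradient with respect to the implicit Params, then the Params method of update!:
gs = Flux.gradient(() -> Flux.Losses.mse(model(x), y), ps)
Flux.update!(opt, ps, gs)
```
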
142 changes: 142 additions & 0 deletions src/train.jl
@@ -0,0 +1,142 @@
module Train

using LinearAlgebra
using Optimisers: Optimisers
using Functors: fmap

import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions

export setup, train!

using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote, Params

"""
opt = setup(rule, model)

This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
It differs from `Optimisers.setup` in that it:
* has one extra check for mutability
* has methods which accept Flux's old optimisers, and convert them.

# Example
```jldoctest
julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32);

julia> opt = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ())

julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps:

julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y
sum(abs.(m(x) .- y)) * 100
end
2-element Vector{Float32}:
40.1
38.7

julia> model.bias # was zero, mutated by Flux.train!
1-element Vector{Float32}:
10.190001

julia> opt # mutated by Flux.train!
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ())
```
"""
function setup(rule::Optimisers.AbstractRule, model)
[Review thread]
Member:
  I am hesitant to create a function in the Flux namespace that clashes with Optimisers.jl. It is hard enough already to keep track of where "Flux functions" actually come from.
  Why not extend Optimisers.setup for Flux.Optimise.AbstractOptimiser and remove the mutability check? I am guessing this is to guard against immutable models, since train! does not return the model?
mcabbott (Member Author, Nov 10, 2022):
  I thought I did (in addition) extend Optimisers.setup that way, but in fact I made an error. Can change it.
  Yes, the guard against immutable models is the point. All Flux models are assumed mutable right now, and this just makes the check explicit.
  I don't love the collision, but neither name is exported, and the consequences of using the wrong one are (I think) slight. You lose the safety check, but any model which does work with Flux.setup will also work correctly with Optimisers.setup.
  We can of course make train! return the model. But this isn't enough, as you also have to re-do your code to keep, not discard, the returned model. It's a bit awkward.
mcabbott (Member Author):
  BTW, I picture Flux 0.14 deleting Flux.Optimise and exporting Adam etc. from Optimisers.jl.
  Code that goes `using Flux; opt = Flux.setup(Adam(), model); train!(loss, model, data, opt)` will work equally well on 0.13 and 0.14. You don't have to load Optimisers.jl yourself at all, and all will be safe.
  If you do load Optimisers.jl yourself and use its functions, then you have opted into the `model, _ = update!(opt, model, grad)` thing, where you are supposed to get back the new model.
Member:
  > Code that goes `using Flux; opt = Flux.setup(Adam(), model); train!(loss, model, data, opt)` will work equally well on 0.13 and 0.14.
  I guess the question is: should this be the case? I do think there should be a patch release of v0.13.X that accepts Optimisers.Adam, etc. and upgrades Flux.Optimise.Adam with a warning. This will allow train! to work like quoted above too. But in v0.14, I was expecting that we force people to start using model = train!(...). Previously, train! and update! worked similarly (mutating optimizers and model), and we could say train! is "just" a loop. Diverging how they work seems worse than a minor code refactor on a breaking release. Especially given people will get warnings from before.
mcabbott (Member Author, Nov 14, 2022):
  > I do think there should be a patch release of v0.13.X that accepts Optimisers.Adam, etc. and upgrades Flux.Optimise.Adam with a warning
  Yes. That's what this PR makes.
  > I was expecting that we force people to start using model = train!(...)
  > Especially given people will get warnings from before
  But how? You want train! not to mutate, so that everyone will wonder why their model isn't training, and why it's called train!? Or worse, to make it return a copy and write NaN into the old model to trash it? These seem awful to me: deliberate breakage for which we gain nothing.
Member:
  Okay, let's forget my suggestions about the warnings, as I agree about the orthogonality.
  One option here is to return the model from train!, which would allow immutable models to work. Mutable models still don't need to capture the return value to work. So we don't force people to do model = train!(...). And we still have Flux.setup here to work in the reverse direction: warn if any leaf is immutable.
Member:
  But I agree that if we are adding Flux.setup, then this seems like something that can be revisited later too.
mcabbott (Member Author, Nov 14, 2022):
  One thing we should consider is changing Optimisers. Maybe its present function should be called update!!, as that's something of a convention for "tries to mutate but may fail".
  Then in [email protected], we can introduce a new function update! which demands mutability and fails on immutable parameters. And that's the one we identify with Flux's function.
  That's now FluxML/Optimisers.jl#116
Member:
  I think train! should guarantee mutation; this is the widespread Julian convention. We can have a train and a train!! for non-mutating and mutate-if-possible versions.
  In that case, whether it returns the model or not hasn't great relevance. Base functions such as replace! and map! return the mutated input. Maybe just for REPL usage convenience? In our case returning the model in the REPL would just be an annoyance, I guess.
mcabbott (Member Author, Nov 18, 2022):
  Yes, the present one returns nothing, which as you say means you don't get a screenful of stuff, and also serves as a reminder that it mutates the model.
  I think I'd be happiest if update! did the same. I mean, Flux.update! does now, but after unifying with Optimisers.update! too.
  I understand the attraction of `state, model = update(state, model, grad)`, but IMO it's a pain to remember the order, and update! is now in a weird place where it does guarantee to mutate the state, but not the model.

state = Optimisers.setup(rule, model)
fmap(model, exclude = Optimisers.isnumeric) do x
Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`.
If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
end
state
end
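
To make the two call patterns discussed in the thread above concrete, a minimal sketch (the model, loss, and data are placeholders):

```julia
using Flux
import Optimisers

model = Dense(2 => 1)
data = [(rand(Float32, 2, 4), rand(Float32, 1, 4))]
loss(m, x, y) = Flux.Losses.mse(m(x), y)

# Pattern intended to work the same on 0.13 and 0.14, with no direct use of Optimisers.jl:
opt = Flux.setup(Adam(), model)
Flux.train!(loss, model, data, opt)

# Opting into Optimisers.jl directly: update! hands back the (possibly new) model.
grad = Flux.gradient(m -> loss(m, data[1]...), model)[1]
opt, model = Optimisers.update!(opt, model, grad)
```
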

"""
train!(loss, model, data, opt)

Uses a `loss` function and training `data` to improve the `model`'s parameters
according to a particular optimisation rule `opt`. Iterates through `data` once,
evaluating `loss(model, d...)` for each `d` in data.

For example, with these definitions...
```
data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple

loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument

opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta
```
...calling `Flux.train!(loss3, model, data, opt)` runs a loop much like this,
using Zygote's "explicit" mode for the gradient:
```
for d in data
∂L∂m = gradient(loss3, model, d...)[1]
update!(opt, model, ∂L∂m) # method for "explicit" gradient
end
```
You can also write this loop yourself, if you need more flexibility.
For this reason `train!` is not highly extensible.
It adds only a few features to the loop above:

* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.

* Return a vector containing the value of the loss function at each datapoint.

* Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl).

Note that the built-in loss functions accept 3 arguments, allowing for instance
`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.

!!! note
This method has significant changes from the one in Flux ≤ 0.13:
* It now takes the `model` itself, not the result of [`Flux.params`](@ref).
(This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
* Instead of `loss` being a function which typically accepts two arguments
(the input `x` and expected output `y` from each element of `data`)
now it should typically accept three, the first of which is the `model` itself.
[Review thread]
Member:
  We never restricted specifically to 2 arguments (and we don't seem to restrict to 3 now either). I think the change is:
  • <= v0.13: loss accepts as many arguments as length(first(data))
  • > v0.13: loss accepts at least 1 argument, the model, and can accept N additional arguments, where N = length(first(data))
  I think the distinction is important, since for things like language models, length(first(data)) == 1 (conceivably), and the loss handles taking the single data argument and turning it into a supervised problem.
mcabbott (Member Author, Oct 20, 2022):
  Yea, I don't know how to word this. I think all the doc examples have 2 & need 3, so wrote "typically". Maybe it needs more explanation.
  > <= v0.13: loss accepts as many arguments as length(first(data))
  It's weirder than that, because if first(data) isn't a tuple, then it isn't splatted. I made the new code simply demand a tuple, else error, so that (at least for now) there is just one path.
mcabbott (Member Author):
  I've removed the "3 arguments" bit, to say just "Instead of loss being a function which accepts only the data, now it must also accept the model itself, as the first argument."
(A sketch of the single-data-argument case follows the `train!` definition below.)

* `data` must iterate tuples, otherwise you get an error.
(Previously non-tuple types were not splatted into the loss.
Pass in `((d,) for d in data)` to simulate this.)
* `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
  such as `Adam()` without this step should give you a warning.

[Review thread]
Member:
  Does Documenter need Flux.Train.setup instead?
mcabbott (Member Author):
  Maybe. I just tried, and I think neither this new train! nor setup appear in the docs at present. That section needs to be re-worked for explicit parameters.
mcabbott (Member Author):
  Doc update is #2114, will need to be rebased on this & then checked.

* Callback functions are not supported.
But any code can be included in the above `for` loop.
"""
function train!(loss, model, data, opt; cb = nothing)
isnothing(cb) || error("""train! does not support callback functions.
For more control use a loop with `gradient` and `update!`.""")
losses = Float32[]
@withprogress for (i,d) in enumerate(data)
d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
# l, (g, _...) = explicit_withgradient(loss, model, d...) # BTW this un-thunks gradient w.r.t. data. Could avoid that
l, (g, _...) = explicit_withgradient(m -> loss(m, d...), model)
isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
opt, model = Optimisers.update!(opt, model, g)
push!(losses, l)
@logprogress Base.haslength(data) ? i/length(data) : nothing
end
return losses # Not entirely sure returning losses is a good idea, as it may conflict with later returning immutable models alla Optimisers.jl
end
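
Following the arity discussion in the thread above, a sketch of the single-data-argument case: each element of `data` is a 1-tuple, and the loss builds its own target from that one argument. The autoencoder-style loss and sizes here are hypothetical.

```julia
using Flux

model = Dense(3 => 3)
data = [(rand(Float32, 3, 5),) for _ in 1:4]   # each element is a 1-tuple

# The loss receives the model plus whatever the tuple contains; here it
# constructs a self-supervised target from the single data argument.
loss(m, x) = Flux.Losses.mse(m(x), x)

opt = Flux.setup(Adam(), model)
Flux.train!(loss, model, data, opt)
```
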

# This method lets you use Optimisers.Descent() without setup, when there is no state
function train!(loss, model, data, rule::Optimisers.AbstractRule)
train!(loss, model, data, _rule_to_state(model, rule))
end

function _rule_to_state(model, rule::Optimisers.AbstractRule)
state = setup(rule, model)
@gensym warn_id
name = typeof(rule).name.name
fmap(state, exclude = x -> x isa Optimisers.Leaf) do leaf
leaf.state isa Nothing || @warn """Optimiser $name has state which will be discarded after `train!` finishes.
Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id
leaf
end
state
end

explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor

end # module
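
Since the `train!` docstring notes that you can write the loop yourself, here is a hand-rolled version of roughly what it does, minus the progress bar and the non-finite-loss check. The model, data, and learning rate are placeholders.

```julia
using Flux

model = Dense(2 => 1)
data = [(rand(Float32, 2, 4), rand(Float32, 1, 4)) for _ in 1:3]
loss(m, x, y) = Flux.Losses.mse(m(x), y)
opt = Flux.setup(Adam(0.01), model)

losses = Float32[]
for d in data
    g = Flux.gradient(m -> loss(m, d...), model)[1]   # explicit gradient w.r.t. the model
    Flux.update!(opt, model, g)                       # unified update!, mutates model and opt
    push!(losses, loss(model, d...))                  # note: train! records the loss before the update
end
```
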
3 changes: 2 additions & 1 deletion test/runtests.jl
@@ -16,8 +16,9 @@ Random.seed!(0)
include("utils.jl")
end

@testset "Optimise" begin
@testset "Optimise / Train" begin
include("optimise.jl")
include("train.jl")
end

@testset "Data" begin
39 changes: 39 additions & 0 deletions test/train.jl
@@ -0,0 +1,39 @@
using Flux
# using Flux.Train
import Optimisers

using Test
using Random

@testset "Explicit Flux.train! with Zygote" begin
Random.seed!(84)
w = randn(10, 10)
w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
@testset for rule in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(),
NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(),
Nesterov(), RMSProp(), Momentum()]

loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
@test loss(model, rand(10, 10)) > 1

opt = Flux.setup(rule, model)
Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
@test loss(model, rand(10, 10)) < 0.01
end

# Test direct use of Optimisers.jl rule, only really OK for `Descent`:
@testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()]
loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
@test loss(model, rand(10, 10)) > 1
Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
@test loss(model, rand(10, 10)) < 0.01
end
end

@testset "Explicit Flux.train! features" begin
# Test errors from wrong kind of iterator
# Test NaN / Inf early stop
# Test that loss is returned
end
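
One way the placeholder tests above might be filled in, based on the behaviour implemented in src/train.jl in this PR; these are illustrative guesses, not the tests that were eventually merged.

```julia
using Flux, Test

@testset "Explicit Flux.train! features (sketch)" begin
    model = (weight = ones(2, 2), bias = zeros(2))
    loss_ok(m, x) = sum(abs2, m.weight * x .+ m.bias)
    opt = Flux.setup(Descent(0.1), model)

    # Wrong kind of iterator: elements must be tuples.
    @test_throws ErrorException Flux.train!(loss_ok, model, [rand(2) for _ in 1:3], opt)

    # A non-finite loss throws a DomainError and stops training.
    loss_nan(m, x) = NaN
    @test_throws DomainError Flux.train!(loss_nan, model, [(rand(2),)], opt)

    # The loss at each datapoint is returned as a vector.
    losses = Flux.train!(loss_ok, model, [(rand(2),) for _ in 1:3], opt)
    @test losses isa Vector && length(losses) == 3
end
```
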