explicit train, take 2
mcabbott committed Oct 13, 2022
1 parent c7ed5fe commit d80bf53
Showing 9 changed files with 457 additions and 10 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

## v0.13.7
* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)
* New method of `train!` using Zygote's "explicit" mode, which allows changing the AD back-end.

## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
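For reference, a minimal sketch of the explicit-mode `train!` call that the new NEWS entry above describes; the model, loss, and data below are placeholders, not part of the commit, and the mechanism for switching AD back-ends is not shown here:

```julia
using Flux

model = Dense(1 => 1)
loss(m, x, y) = Flux.Losses.mse(m(x), y)              # the loss now takes the model as its first argument
data = [(rand(Float32, 1, 8), rand(Float32, 1, 8))]   # an iterable of (input, target) batches

opt = Flux.setup(Adam(), model)        # explicit optimiser state, instead of Flux.params(model)
Flux.train!(loss, model, data, opt)    # new method: the model itself is passed to train!
```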
8 changes: 6 additions & 2 deletions Project.toml
@@ -36,11 +36,13 @@ MacroTools = "0.5"
NNlib = "0.8.9"
NNlibCUDA = "0.2.4"
OneHotArrays = "0.1"
Optimisers = "0.2.1"
Optimisers = "0.2.10"
ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
StatsBase = "0.33"
Tracker = "0.2.22"
Yota = "0.8.1"
Zygote = "0.6.34"
julia = "1.6"

@@ -50,7 +52,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Yota = "cd998857-8626-517d-b929-70ad188a48f0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker", "Yota"]
12 changes: 6 additions & 6 deletions docs/src/models/overview.md
@@ -77,9 +77,9 @@ julia> predict(x_train)
In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> loss(x, y) = Flux.Losses.mse(predict(x), y);
julia> loss(model, x, y) = mean(abs2.(model(x) .- y));
julia> loss(x_train, y_train)
julia> loss(predict, x_train, y_train)
122.64734f0
```
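Because the loss now takes the model explicitly, its gradient can be taken directly with respect to the model. A hedged sketch of roughly what one training step computes, reusing the tutorial's `loss` and `predict` (this block is editorial, not part of the tutorial):

```julia
grads = Flux.gradient(m -> loss(m, x_train, y_train), predict)   # 1-tuple holding a NamedTuple
grads[1].weight   # ∂loss/∂weight for the Dense layer
grads[1].bias     # ∂loss/∂bias
```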

@@ -131,7 +131,7 @@ The first parameter is the weight and the second is the bias. Flux will adjust p
This optimiser implements the classic gradient descent strategy. Now improve the parameters of the model with a call to [`Flux.train!`](@ref) like this:

```jldoctest overview
julia> train!(loss, parameters, data, opt)
julia> train!(loss, predict, data, opt)
```

And check the loss:
@@ -156,10 +156,10 @@ In the previous section, we made a single call to `train!` which iterates over t

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> for epoch in 1:200
train!(loss, parameters, data, opt)
train!(loss, predict, data, opt)
end
julia> loss(x_train, y_train)
julia> loss(predict, x_train, y_train)
0.00339581f0
julia> parameters
@@ -188,7 +188,7 @@ First, we gathered real-world data into the variables `x_train`, `y_train`, `x_t

Then, we built a single input, single output predictive model, `predict = Dense(1 => 1)`. The initial predictions weren't accurate, because we had not trained the model yet.

After building the model, we trained it with `train!(loss, parameters, data, opt)`. The loss function is first, followed by the `parameters` holding the weights and biases of the model, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process.
After building the model, we trained it with `train!(loss, predict, data, opt)`. The loss function is first, followed by the model itself, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran `train!` many times to finish the training process.

After training the model, we checked it against the test data to verify the results.

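Pulling the revised tutorial steps together, a self-contained sketch of the explicit-style workflow; the data below is invented to resemble the tutorial's and is not taken from it:

```julia
using Flux, Statistics

x_train = Float32.(hcat(0:5...))               # 1×6 input matrix (invented data)
y_train = 4 .* x_train .+ 2                    # targets on a known line

predict = Dense(1 => 1)                        # single-input, single-output model
loss(model, x, y) = mean(abs2.(model(x) .- y)) # same form as the tutorial's loss

data = [(x_train, y_train)]                    # one batch per call to train!
opt = Flux.setup(Descent(), predict)           # explicit optimiser state

for epoch in 1:200
    Flux.train!(loss, predict, data, opt)      # the model is passed directly, no Flux.params
end

loss(predict, x_train, y_train)                # should be far smaller than before training
```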
4 changes: 4 additions & 0 deletions src/Flux.jl
@@ -34,6 +34,10 @@ export Descent, Adam, Momentum, Nesterov, RMSProp,
AdamW, RAdam, AdaBelief, InvDecay, ExpDecay,
WeightDecay, ClipValue, ClipNorm

include("train.jl")
using .Train
# using .Train: setup, @train_autodiff

using CUDA
const use_cuda = Ref{Union{Nothing,Bool}}(nothing)

62 changes: 62 additions & 0 deletions src/deprecations.jl
@@ -82,3 +82,65 @@ Base.@deprecate_binding ADAGrad AdaGrad
Base.@deprecate_binding ADADelta AdaDelta

@deprecate rng_from_array() default_rng_value()

#=
# Valid method in Optimise, old implicit style, is:
train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
# Valid methods in Train, new explicit style, are:
train!(loss, model, data, opt)
train!(loss, model, data, opt::Optimisers.AbstractRule)
# ... and 3-arg:
train!(loss, model, opt)
train!(loss, model, opt::Optimisers.AbstractRule)
# Provide friendly errors for what happens if you mix these up:
=#
import .Optimise: train!
train!(loss, ps::Params, data, opt) = error("can't mix implicit Params with explicit state")
train!(loss, ps::Params, opt) = error("can't mix implicit Params with explicit state")

train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implicit Params with explicit rule")
train!(loss, ps::Params, opt::Optimisers.AbstractRule) = error("can't mix implicit Params with explicit rule")

train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
train!(loss, model, opt::Optimise.AbstractOptimiser) = train!(loss, model, _old_to_new(opt))

train!(loss, ps::Params, opt::Optimise.AbstractOptimiser; cb=0) = error("3-arg train does not exist for implicit mode")

# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
# Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
# where `loss_mxy` accepts the model as its first argument.
# """
# ))

# Next, to use the new `setup` with the still-exported old-style Adam etc:
import .Train: setup
setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)

for T in [:Descent, :Adam, :Momentum, :Nesterov, :RMSProp,
:AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
# :InvDecay, :ExpDecay,
]
@eval function _old_to_new(rule::$T)
args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T))
Optimisers.$T(args...)
end
end
_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now
_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
const ClipGrad = Optimise.ClipValue
_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred

_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")

Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = error("please use Flux.setup not Optimisers.setup, it may be able to translate this rule")

# v0.14 deprecations

# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
# Base.@deprecate_binding Optimiser OptimiserChain
# Base.@deprecate_binding ClipValue ClipGrad
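For illustration, a hedged sketch of how these translations surface to a user who still constructs old-style rules (printed forms omitted, since exact field names and show methods may differ):

```julia
using Flux

model = Dense(3 => 2)

rule = Momentum(0.01, 0.9)        # old-style Flux.Optimise rule, still exported
state = Flux.setup(rule, model)   # converted via _old_to_new into Optimisers.jl state

# Chains of old rules become an Optimisers.OptimiserChain behind the scenes:
chain_state = Flux.setup(Flux.Optimise.Optimiser(ClipValue(1.0), Adam(1e-3)), model)

# Calling Optimisers.setup directly with an old-style rule errors, pointing back to Flux.setup.
```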
13 changes: 12 additions & 1 deletion src/optimise/train.jl
@@ -1,16 +1,23 @@
using ProgressLogging: @progress, @withprogress, @logprogress
import Zygote: Params, gradient

# Add methods to Optimisers.jl's function, so that there is just one Flux.update!
# for both explicit and implicit parameters.
import Optimisers.update!

"""
update!(opt, p, g)
update!(opt, ps::Params, gs)
Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
according to optimizer `opt::AbstractOptimiser` and the gradients `gs` (the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
The gradient could be mutated as well.
!!! note
This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14.
The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain.
"""
function update!(opt::AbstractOptimiser, x, x̄)
x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's
@@ -88,6 +95,10 @@ batchmemaybe(x::Tuple) = x
Uses a `loss` function and training `data` to improve the
model's parameters according to a particular optimisation rule `opt`.
!!! note
This method with implicit `Params` will be removed from Flux 0.14.
It should be replaced with the explicit method `train!(loss, model, data, opt)`.
For each `d in data`, first the gradient of the `loss` is computed like this:
```
gradient(() -> loss(d...), pars) # if d isa Tuple
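Finally, a hedged sketch of the explicit `update!` path that the note above says will remain, following Optimisers.jl's convention of returning the updated state and model:

```julia
using Flux, Optimisers

model = Dense(3 => 2)
x, y = rand(Float32, 3, 4), rand(Float32, 2, 4)

opt_state = Flux.setup(Descent(0.1), model)                     # explicit state tree for the model

grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)     # gradient w.r.t. the model itself
opt_state, model = Optimisers.update!(opt_state, model, grads[1])   # one optimisation step
```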