Commit 2b77843
remove 3-argument train! since this requires an impure loss function, and you can just use update! instead really.
mcabbott committed Nov 10, 2022
1 parent 14df718 commit 2b77843
Showing 3 changed files with 1 addition and 74 deletions.
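The replacement the commit message points to, a single explicit gradient step fed to `Optimisers.update!`, looks roughly like the sketch below. This is not part of the diff; the names `model`, `x1`, and `y1` are assumed to be defined elsewhere.

```julia
using Flux
import Optimisers, Zygote

# Assumed to exist already: `model` and one data point `(x1, y1)`.
opt_state = Flux.setup(Optimisers.Adam(), model)   # explicit optimiser state

# One gradient step, in place of the removed 3-arg `train!(loss, model, opt)`:
loss_val, grads = Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)
opt_state, model = Optimisers.update!(opt_state, model, grads[1])
```

Because the data point has to be closed over inside the loss closure, the 3-argument `train!` forced an impure loss function, which is the reason the commit gives for dropping it in favour of this plain `withgradient` + `update!` step.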
8 changes: 0 additions & 8 deletions src/deprecations.jl
@@ -89,22 +89,14 @@ Base.@deprecate_binding ADADelta AdaDelta
# Valid methods in Train, new explict style, are:
train!(loss, model, data, opt)
train!(loss, model, data, opt::Optimisers.AbstractRule)
# ... and 3-arg:
train!(loss, model, opt)
train!(loss, model, opt::Optimisers.AbstractRule)
# Provide friendly errors for what happens if you mix these up:
=#
import .Optimise: train!
train!(loss, ps::Params, data, opt) = error("can't mix implict Params with explict state")
train!(loss, ps::Params, opt) = error("can't mix implict Params with explict state")

train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule")
train!(loss, ps::Params, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule")

train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
train!(loss, model, opt::Optimise.AbstractOptimiser) = train!(loss, model, _old_to_new(opt))

train!(loss, ps::Params, opt::Optimise.AbstractOptimiser; cb=0) = error("3-arg train does not exist for implicit mode")

# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
51 changes: 1 addition & 50 deletions src/train.jl
@@ -47,9 +47,6 @@ function setup(rule::Optimisers.AbstractRule, model)
state
end

# opt = Flux.setup(Adam(), model); train!(model, opt) do m ...
setup(model, rule::Optimisers.AbstractRule) = setup(rule, model)

"""
train!(loss, model, data, opt)
@@ -112,56 +109,10 @@ function train!(loss, model, data, opt)
return losses # Not entirely sure returning losses is a good idea
end

"""
train!(loss, model, opt)
Uses a `loss` function improve the `model`'s parameters.
While the 4-argument method of `train!` iterates over a dataset,
this 3-argument method is for a single datapoint, and calls `gradient` just once.
It expects a function `loss` which takes just one argument, the model.
For example:
```
opt = Flux.setup(Adam(), model) # explicit setup
train!(model, opt) do m # the model is passed to the function as `m`
Flux.crossentropy(m(x1), y1) # but the data point `(x1, y1)` is closed over.
end
```
This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`.
(The `do` block is another syntax for this anonymous function.)
Then it updates the parameters contained within `model` according to `opt`.
Finally it returns the value of the loss function.
To iterate over a dataset, writing a loop allows more control than
calling 4-argument `train!`. For example, this adds printing and an early stop:
```
data = Flux.DataLoader((Xtrain, Ytrain), batchsize=32)
opt = Flux.setup(Adam(), model)
for (i, d) in enumerate(data)
x, y = d
ell = Flux.train!(m -> Flux.crossentropy(m(x), y), model, opt)
i%10==0 && println("on step \$i, the loss was \$ell") # prints every 10th step
ell<0.1 && break # stops training
end
```
!!! note
This method has no implicit `Params` analog in Flux ≤ 0.13.
"""
function train!(loss, model, opt)
l, (g, _...) = explicit_withgradient(loss, model)
isfinite(l) || return l
_, model = Optimisers.update!(opt, model, g)
return l
end

# These methods let you use Optimisers.Descent() without setup, when there is no state
# This method let you use Optimisers.Descent() without setup, when there is no state
function train!(loss, model, data, rule::Optimisers.AbstractRule)
train!(loss, model, data, _rule_to_state(model, rule))
end
function train!(loss, model, rule::Optimisers.AbstractRule)
train!(loss, model, _rule_to_state(model, rule))
end

function _rule_to_state(model, rule::Optimisers.AbstractRule)
state = setup(rule, model)
16 changes: 0 additions & 16 deletions test/train.jl
@@ -22,22 +22,6 @@ using Random
@test loss(model, rand(10, 10)) < 0.01
end

# Test 3-arg `Flux.train!` method:
@testset for rule in [Descent(0.1), Adam(), AdamW()]

loss(m) = let x = rand(10)
Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
end
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
@test loss(model) > 1

opt = Flux.setup(rule, model)
for i in 1:10^5
Flux.train!(loss, model, opt)
end
@test loss(model) < 0.01
end

# Test direct use of Optimisers.jl rule, only really OK for `Descent`:
@testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()]
loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
