diff --git a/src/deprecations.jl b/src/deprecations.jl
index c878c5192b..782efde473 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -89,22 +89,14 @@ Base.@deprecate_binding ADADelta AdaDelta
   # Valid methods in Train, new explict style, are:
   train!(loss, model, data, opt)
   train!(loss, model, data, opt::Optimisers.AbstractRule)
-  # ... and 3-arg:
-  train!(loss, model, opt)
-  train!(loss, model, opt::Optimisers.AbstractRule)
 
   # Provide friendly errors for what happens if you mix these up:
 =#
 import .Optimise: train!
 
 train!(loss, ps::Params, data, opt) = error("can't mix implict Params with explict state")
-train!(loss, ps::Params, opt) = error("can't mix implict Params with explict state")
 train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule")
-train!(loss, ps::Params, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule")
 train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
-train!(loss, model, opt::Optimise.AbstractOptimiser) = train!(loss, model, _old_to_new(opt))
-
-train!(loss, ps::Params, opt::Optimise.AbstractOptimiser; cb=0) = error("3-arg train does not exist for implicit mode")
 
 # train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
 #   """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
diff --git a/src/train.jl b/src/train.jl
index 1c35da1b0a..525912e784 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -47,9 +47,6 @@ function setup(rule::Optimisers.AbstractRule, model)
   state
 end
 
-# opt = Flux.setup(Adam(), model); train!(model, opt) do m ...
-setup(model, rule::Optimisers.AbstractRule) = setup(rule, model)
-
 """
     train!(loss, model, data, opt)
 
@@ -112,56 +109,10 @@ function train!(loss, model, data, opt)
   return losses  # Not entirely sure returning losses is a good idea
 end
 
-"""
-    train!(loss, model, opt)
-
-Uses a `loss` function improve the `model`'s parameters.
-
-While the 4-argument method of `train!` iterates over a dataset,
-this 3-argument method is for a single datapoint, and calls `gradient` just once.
-It expects a function `loss` which takes just one argument, the model.
-For example:
-```
-opt = Flux.setup(Adam(), model)  # explicit setup
-train!(model, opt) do m          # the model is passed to the function as `m`
-  Flux.crossentropy(m(x1), y1)   # but the data point `(x1, y1)` is closed over.
-end
-```
-This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`.
-(The `do` block is another syntax for this anonymous function.)
-Then it updates the parameters contained within `model` according to `opt`.
-Finally it returns the value of the loss function.
-
-To iterate over a dataset, writing a loop allows more control than
-calling 4-argument `train!`. For example, this adds printing and an early stop:
-```
-data = Flux.DataLoader((Xtrain, Ytrain), batchsize=32)
-opt = Flux.setup(Adam(), model)
-for (i, d) in enumerate(data)
-  x, y = d
-  ell = Flux.train!(m -> Flux.crossentropy(m(x), y), model, opt)
-  i%10==0 && println("on step \$i, the loss was \$ell")  # prints every 10th step
-  ell<0.1 && break                                       # stops training
-end
-```
-
-!!! note
-    This method has no implicit `Params` analog in Flux ≤ 0.13.
-"""
-function train!(loss, model, opt)
-  l, (g, _...) = explicit_withgradient(loss, model)
-  isfinite(l) || return l
-  _, model = Optimisers.update!(opt, model, g)
-  return l
-end
-
-# These methods let you use Optimisers.Descent() without setup, when there is no state
+# This method lets you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule)
   train!(loss, model, data, _rule_to_state(model, rule))
 end
-function train!(loss, model, rule::Optimisers.AbstractRule)
-  train!(loss, model, _rule_to_state(model, rule))
-end
 
 function _rule_to_state(model, rule::Optimisers.AbstractRule)
   state = setup(rule, model)
diff --git a/test/train.jl b/test/train.jl
index ce5a3c3ee2..607dc1e9a6 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -22,22 +22,6 @@ using Random
    @test loss(model, rand(10, 10)) < 0.01
   end
 
-  # Test 3-arg `Flux.train!` method:
-  @testset for rule in [Descent(0.1), Adam(), AdamW()]
-
-    loss(m) = let x = rand(10)
-      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-    end
-    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
-    @test loss(model) > 1
-
-    opt = Flux.setup(rule, model)
-    for i in 1:10^5
-      Flux.train!(loss, model, opt)
-    end
-    @test loss(model) < 0.01
-  end
-
  # Test direct use of Optimisers.jl rule, only really OK for `Descent`:
  @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()]
    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
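For context, and not part of the patch: a minimal sketch of how a single gradient step can be written once the 3-argument `train!` is gone, using only calls that survive this diff (`Flux.setup`, `Zygote.withgradient`, `Optimisers.update!`, and the 4-argument `train!`). The model, loss, and data point below are assumed placeholders, not anything defined in the patch.

```julia
using Flux                 # Zygote and Optimisers are dependencies of Flux;
import Optimisers, Zygote  # they must be available in the active environment

model = Dense(10 => 1)                           # placeholder model
x1, y1 = rand(Float32, 10), rand(Float32, 1)     # placeholder data point
opt = Flux.setup(Adam(), model)                  # explicit setup, unchanged by this patch

# One explicit step, roughly what the deleted `train!(loss, model, opt)` body did:
l, grads = Zygote.withgradient(m -> Flux.Losses.mse(m(x1), y1), model)
Optimisers.update!(opt, model, grads[1])

# The surviving 4-argument method covers the dataset case;
# the loss is called as loss(model, d...) for each d in data:
data = [(x1, y1)]
Flux.train!((m, x, y) -> Flux.Losses.mse(m(x), y), model, data, opt)
```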