diff --git a/NEWS.md b/NEWS.md index d83e76d62a..032819e1d1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,7 @@ ## v0.13.7 * Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078) +* New method of `train!` using Zygote's "explicit" mode, allows changing AD back-end. ## v0.13.4 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983) diff --git a/Project.toml b/Project.toml index ad77b5167e..e0d1318029 100644 --- a/Project.toml +++ b/Project.toml @@ -36,11 +36,13 @@ MacroTools = "0.5" NNlib = "0.8.9" NNlibCUDA = "0.2.4" OneHotArrays = "0.1" -Optimisers = "0.2.1" +Optimisers = "0.2.10" ProgressLogging = "0.1" Reexport = "0.2, 1.0" SpecialFunctions = "1.8.2, 2.1.2" StatsBase = "0.33" +Tracker = "0.2.22" +Yota = "0.8.1" Zygote = "0.6.34" julia = "1.6" @@ -50,7 +52,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Yota = "cd998857-8626-517d-b929-70ad188a48f0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker", "Yota"] diff --git a/docs/src/models/overview.md b/docs/src/models/overview.md index c6d52ae083..630da338cb 100644 --- a/docs/src/models/overview.md +++ b/docs/src/models/overview.md @@ -77,9 +77,9 @@ julia> predict(x_train) In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions. ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" 
-julia> loss(x, y) = Flux.Losses.mse(predict(x), y); +julia> loss(model, x, y) = mean(abs2.(model(x) .- y)); -julia> loss(x_train, y_train) +julia> loss(predict, x_train, y_train) 122.64734f0 ``` @@ -131,7 +131,7 @@ The first parameter is the weight and the second is the bias. Flux will adjust p This optimiser implements the classic gradient descent strategy. Now improve the parameters of the model with a call to [`Flux.train!`](@ref) like this: ```jldoctest overview -julia> train!(loss, parameters, data, opt) +julia> train!(loss, predict, data, opt) ``` And check the loss: @@ -156,10 +156,10 @@ In the previous section, we made a single call to `train!` which iterates over t ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" julia> for epoch in 1:200 - train!(loss, parameters, data, opt) + train!(loss, predict, data, opt) end -julia> loss(x_train, y_train) +julia> loss(predict, x_train, y_train) 0.00339581f0 julia> parameters @@ -188,7 +188,7 @@ First, we gathered real-world data into the variables `x_train`, `y_train`, `x_t Then, we built a single input, single output predictive model, `predict = Dense(1 => 1)`. The initial predictions weren't accurate, because we had not trained the model yet. -After building the model, we trained it with `train!(loss, parameters, data, opt)`. The loss function is first, followed by the `parameters` holding the weights and biases of the model, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process. +After building the model, we trained it with `train!(loss, predict, data, opt)`. The loss function is first, followed by the model itself, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. 
Then, we ran the `train!` many times to finish the training process. After we trained the model, we verified it with the test data to verify the results. diff --git a/src/Flux.jl b/src/Flux.jl index fcb473ba2c..80eb87df1e 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -34,6 +34,10 @@ export Descent, Adam, Momentum, Nesterov, RMSProp, AdamW, RAdam, AdaBelief, InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm +include("train.jl") +using .Train +# using .Train: setup, @train_autodiff + using CUDA const use_cuda = Ref{Union{Nothing,Bool}}(nothing) diff --git a/src/deprecations.jl b/src/deprecations.jl index 8c3bc963a4..862ebaa105 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -82,3 +82,65 @@ Base.@deprecate_binding ADAGrad AdaGrad Base.@deprecate_binding ADADelta AdaDelta @deprecate rng_from_array() default_rng_value() + +#= + # Valid method in Optimise, old implicit style, is: + train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ()) + # Valid methods in Train, new explicit style, are: + train!(loss, model, data, opt) + train!(loss, model, data, opt::Optimisers.AbstractRule) + # ... and 3-arg: + train!(loss, model, opt) + train!(loss, model, opt::Optimisers.AbstractRule) + # Provide friendly errors for what happens if you mix these up: +=# +import .Optimise: train! 
+train!(loss, ps::Params, data, opt) = error("can't mix implicit Params with explicit state") +train!(loss, ps::Params, opt) = error("can't mix implicit Params with explicit state") + +train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implicit Params with explicit rule") +train!(loss, ps::Params, opt::Optimisers.AbstractRule) = error("can't mix implicit Params with explicit rule") + +train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt)) +train!(loss, model, opt::Optimise.AbstractOptimiser) = train!(loss, model, _old_to_new(opt)) + +train!(loss, ps::Params, opt::Optimise.AbstractOptimiser; cb=0) = error("3-arg train does not exist for implicit mode") + +# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( +# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`. +# Instead of `train!(loss_xy, Flux.params(model), data, Adam())` +# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)` +# where `loss_mxy` accepts the model as its first argument. +# """ +# )) + +# Next, to use the new `setup` with the still-exported old-style Adam etc: +import .Train: setup +setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model) + +for T in [:Descent, :Adam, :Momentum, :Nesterov, + :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief, + # :InvDecay, :ExpDecay, are not translated; :RMSProp has its own method because its fields differ + ] + @eval function _old_to_new(rule::$T) + args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T)) + Optimisers.$T(args...) + end +end +_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...) +const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too. 
+_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now +_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields +_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs +const ClipGrad = Optimise.ClipValue +_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # old RMSProp has no field centred + +_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule") + +Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = error("please use Flux.setup not Optimisers.setup, it may be able to translate this rule") + +# v0.14 deprecations + +# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc: +# Base.@deprecate_binding Optimiser OptimiserChain +# Base.@deprecate_binding ClipValue ClipGrad diff --git a/src/optimise/train.jl b/src/optimise/train.jl index b6d6986285..a960826835 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,16 +1,23 @@ using ProgressLogging: @progress, @withprogress, @logprogress import Zygote: Params, gradient +# Add methods to Optimisers.jl's function, so that there is just one Flux.update! +# for both explicit and implicit parameters. +import Optimisers.update! """ update!(opt, p, g) update!(opt, ps::Params, gs) Perform an update step of the parameters `ps` (or the single parameter `p`) -according to optimizer `opt` and the gradients `gs` (the gradient `g`). +according to optimizer `opt::AbstractOptimiser` and the gradients `gs` (the gradient `g`). As a result, the parameters are mutated and the optimizer's internal state may change. The gradient could be mutated as well. + +!!! note + This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14. + The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain. 
""" function update!(opt::AbstractOptimiser, x, x̄) x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's @@ -88,6 +95,10 @@ batchmemaybe(x::Tuple) = x Uses a `loss` function and training `data` to improve the model's parameters according to a particular optimisation rule `opt`. +!!! note + This method with implicit `Params` will be removed from Flux 0.14. + It should be replaced with the explicit method `train!(loss, model, data, opt)`. + For each `d in data`, first the gradient of the `loss` is computed like this: ``` gradient(() -> loss(d...), pars) # if d isa Tuple diff --git a/src/train.jl b/src/train.jl new file mode 100644 index 0000000000..884c3578c0 --- /dev/null +++ b/src/train.jl @@ -0,0 +1,233 @@ +module Train + +using LinearAlgebra +using Optimisers: Optimisers +using Functors: fmap + +import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions + +export setup, @train_autodiff + +using ProgressLogging: @progress, @withprogress, @logprogress # TODO add progress logging again +using Zygote: Zygote, Params + +""" + opt = setup(rule, model) + +This is a version of `Optimisers.setup`, and is the first step before using `train!`. +It differs from `Optimisers.setup` in that it: +* has one extra check for mutability +* has methods which accept Flux's old optimisers, and convert them. + +```jldoctest +julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32); + +julia> opt = Flux.setup(Momentum(0.11), model) +(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0]), σ = ()) + +julia> Flux.train!(model, opt) do m # 3-arg train!, for one data point (x = [0.2, -0.3], y = [0.4]) + sum(m([0.2, -0.3]) .- [0.4]) * 100 + end +-40.1 + +julia> model.bias # was zero, mutated by Flux.train! +1-element Vector{Float32}: + -0.11 + +julia> opt # mutated by Flux.train! 
+(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.022 -0.033]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.11]), σ = ()) +``` +""" +function setup(rule::Optimisers.AbstractRule, model) + state = Optimisers.setup(rule, model) + fmap(model, exclude = Optimisers.isnumeric) do x + Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`. + If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""") + end + state +end + +# opt = Flux.setup(Adam(), model); train!(model, opt) do m ... +setup(model, rule::Optimisers.AbstractRule) = setup(rule, model) + +""" + train!(loss, model, data, opt) + +Uses a `loss` function and training `data` to improve the `model`'s parameters +according to a particular optimisation rule `opt`. + +!!! note + This method has significant changes from the one in Flux ≤ 0.13: + * It now takes the `model` itself, not the result of [`Flux.params`](@ref). + (This is to move away from Zygote's implicit parameter handling.) + * Instead of `loss` being a function which typically accepts two arguments + (the input `x` and expected output `y` from each element of `data`) + now it should typically accept three, the first of which is the `model` itself. + * `data` should iterate tuples or NamedTuples + * `opt` should be the result of [`Flux.setup`](@ref). + * Callback functions are not supported. + +For example, with these definitions... +``` +data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple (or NamedTuple) + +loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument + +opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta +``` +...calling `train!(loss3, model, data, opt)` runs a loop much like this: +``` +for d in data + ∂L∂m = Zygote.gradient(loss3, model, d...)[1] + Optimisers.update!(opt, model, ∂L∂m) +end +``` +Stops with a `DomainError` if the loss is infinite or `NaN` at any point. 
+ +Returns a vector containing the value of the loss function at each datapoint. + +The built-in loss functions accept 3 arguments, allowing for instance `train!(Flux.Losses.mse, model, data, opt)`. + +Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an +easy way to construct more complicated training loops. + +To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref). +""" +function train!(loss, model, data, opt) + losses = Float32[] + @withprogress for (i,d) in enumerate(data) + l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...) + isfinite(l) || throw(DomainError(l, "loss function returned $l, stopping training")) + opt, model = Optimisers.update!(opt, model, g) + push!(losses, l) + @logprogress Base.haslength(data) ? i/length(data) : nothing + end + return losses # Not entirely sure returning losses is a good idea +end + +data_splat(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T + To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""") +data_splat(x::Tuple) = x +data_splat(x::NamedTuple) = x +data_splat(x::AbstractArray{<:Number}) = (x,) + +""" + train!(loss, model, opt) + +Uses a `loss` function to improve the `model`'s parameters. + +While the 4-argument method of `train!` iterates over a dataset, +this 3-argument method is for a single datapoint, and calls `gradient` just once. +It expects a function `loss` which takes just one argument, the model. +For example: +``` +opt = Flux.setup(Adam(), model) # explicit setup +train!(model, opt) do m # the model is passed to the function as `m` + Flux.crossentropy(m(x1), y1) # but the data point `(x1, y1)` is closed over. +end +``` +This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`. +(The `do` block is another syntax for this anonymous function.) +Then it updates the parameters contained within `model` according to `opt`. 
+Finally it returns the value of the loss function. + +To iterate over a dataset, writing a loop allows more control than +calling 4-argument `train!`. For example, this adds printing and an early stop: +``` +data = Flux.DataLoader((Xtrain, Ytrain), batchsize=32) +opt = Flux.setup(Adam(), model) +for (i, d) in enumerate(data) + x, y = d + ell = Flux.train!(m -> Flux.crossentropy(m(x), y), model, opt) + i%10==0 && println("on step \$i, the loss was \$ell") # prints every 10th step + ell<0.1 && break # stops training +end +``` + +To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref). + +!!! note + This method has no implicit `Params` analog in Flux ≤ 0.13. +""" +function train!(loss, model, opt) + l, (g, _...) = explicit_withgradient(loss, model) + isfinite(l) || return l + _, model = Optimisers.update!(opt, model, g) + return l +end + +# These methods let you use Optimisers.Descent() without setup, when there is no state +function train!(loss, model, data, rule::Optimisers.AbstractRule) + train!(loss, model, data, _rule_to_state(model, rule)) +end +function train!(loss, model, rule::Optimisers.AbstractRule) + train!(loss, model, _rule_to_state(model, rule)) +end + +function _rule_to_state(model, rule::Optimisers.AbstractRule) + state = setup(rule, model) + @gensym warn_id + name = typeof(rule).name.name + fmap(state, exclude = x -> x isa Optimisers.Leaf) do leaf + leaf.state isa Nothing || @warn """Optimiser $name has state which will be discarded after `train!` finishes. + Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id + leaf + end + state +end + +explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. 
Yota / Diffractor + +""" + Flux.@train_autodiff Tracker + Flux.@train_autodiff Zygote + Flux.@train_autodiff Yota + Flux.@train_autodiff Diffractor + +This macro allows the use of `train!` with various automatic differentiation (AD) packages, +instead of the default Zygote.jl. + +You should load the AD package, and then call this macro with the chosen name. +The macro overwrites a method within Flux, thus is a global setting, lasting until you re-start Julia. + +Only works with [Yota.jl](https://github.com/dfdx/Yota.jl), +[Tracker.jl](https://github.com/FluxML/Tracker.jl) (Flux's old AD), +[Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (which is not yet registered), +and with the default [Zygote.jl](https://github.com/FluxML/Zygote.jl). + +!!! note + This mechanism is experimental! And there are known bugs, in particular Tracker will not automatically switch to training mode for `Dropout` etc. +""" +macro train_autodiff(pkg) + if pkg == :Diffractor + return quote + Diffractor.gradient(sin, 0.0)[1] ≈ 1.0 # ensures an error if not loaded + function Flux.Train.explicit_withgradient(f, args...) + y, back = Diffractor.∂⃖¹(f, args...) + dy1 = Flux.Zygote.sensitivity(y) # Zygote is loaded, and this gives nice errors + return (; value = y, gradient = Base.tail(back(dy1))) + end + end |> esc + elseif pkg == :Yota + return quote + Yota.grad(sin, 0.0) # [2][1] ≈ 1.0 + function Flux.Train.explicit_withgradient(f, args...) + value, (_, gradient...) = Yota.grad(f, args...) + return (; value, gradient) + end + end |> esc + elseif pkg == :Tracker + return quote + Tracker.withgradient(sum, [1.0]).val == 1.0 # ensures an error if too-old version + Flux.Train.explicit_withgradient(f, args...) = Tracker.withgradient(f, args...) + end |> esc + elseif pkg == :Zygote + return quote + Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...) + end |> esc + else + throw(ArgumentError("@train_autodiff expects one of Tracker, Zygote, Yota, or Diffractor. 
No other arguments are understood.")) + end +end + +end # module diff --git a/test/runtests.jl b/test/runtests.jl index 9027b114fc..29b2bad311 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,8 +16,9 @@ Random.seed!(0) include("utils.jl") end - @testset "Optimise" begin + @testset "Optimise / Train" begin include("optimise.jl") + include("train.jl") end @testset "Data" begin diff --git a/test/train.jl b/test/train.jl new file mode 100644 index 0000000000..81ffa2f3db --- /dev/null +++ b/test/train.jl @@ -0,0 +1,131 @@ +using Flux +# using Flux.Train +import Optimisers + +using Test +using Random + +@testset "Explicit Flux.train! with Zygote" begin + Random.seed!(84) + w = randn(10, 10) + w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. + @testset for rule in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(), + NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(), + Nesterov(), RMSProp(), Momentum()] + + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + + opt = Flux.setup(rule, model) + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end + + # Test 3-arg `Flux.train!` method: + @testset for rule in [Descent(0.1), Adam(), AdamW()] + + loss(m) = let x = rand(10) + Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + end + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model) > 1 + + opt = Flux.setup(rule, model) + for i in 1:10^5 + Flux.train!(loss, model, opt) + end + @test loss(model) < 0.01 + end + + # Test direct use of Optimisers.jl rule, only really OK for `Descent`: + @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()] + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, 
rand(10, 10)) > 1 + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end +end + +@testset "Explicit Flux.train! features" begin + # Test that splat accepts NamedTuple + # Test NaN / Inf early stop + # Test that loss is returned +end + +import Tracker +Flux.@train_autodiff Tracker + +@testset "Explicit Flux.train! with Tracker" begin + Random.seed!(84) + w = randn(10, 10) + w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. + @testset for rule in [Descent(0.1), Adam(), AdamW()] + + loss(m, x) = begin + Flux.istraining() && error("This test is not in fact using Tracker!") + Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + end + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + + opt = Flux.setup(rule, model) + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end + + # Test 3-arg `Flux.train!` method: + @testset for rule in [Descent(0.1), Adam()] + + loss(m) = let x = rand(10) + Flux.istraining() && error("This test is not in fact using Tracker!") + Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + end + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model) > 1 + + opt = Flux.setup(rule, model) + for i in 1:10^5 + Flux.train!(loss, model, opt) + end + @test loss(model) < 0.01 + end +end + +import Yota +Flux.@train_autodiff Yota + +@testset "Explicit Flux.train! with Yota" begin + Random.seed!(84) + w = randn(10, 10) + w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. 
+ @testset for rule in [Descent(0.1), Adam(), AdamW()] + + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + + opt = Flux.setup(rule, model) + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end + + # Test 3-arg `Flux.train!` method: + @testset for rule in [Descent(0.1), Adam()] + + loss(m) = let x = rand(10) + Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + end + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model) > 1 + + opt = Flux.setup(rule, model) + for i in 1:10^5 + Flux.train!(loss, model, opt) + end + @test loss(model) < 0.01 + end +end + +Flux.@train_autodiff Zygote