diff --git a/NEWS.md b/NEWS.md index d83e76d62a..6d48d6380c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,7 @@ ## v0.13.7 * Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078) +* New method of `train!` using Zygote's "explicit" mode. Part of a move away from "implicit" `Params`. ## v0.13.4 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983) diff --git a/Project.toml b/Project.toml index ad77b5167e..84e20d8e9c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,9 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.13.6" +version = "0.13.8" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -27,16 +26,15 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Adapt = "3.0" -ArrayInterface = "3.1, 4, 5, 6" CUDA = "3" ChainRulesCore = "1.12" Functors = "0.3" -MLUtils = "0.2" +MLUtils = "0.2, 0.3.1" MacroTools = "0.5" NNlib = "0.8.9" NNlibCUDA = "0.2.4" -OneHotArrays = "0.1" -Optimisers = "0.2.1" +OneHotArrays = "0.1, 0.2" +Optimisers = "0.2.10" ProgressLogging = "0.1" Reexport = "0.2, 1.0" SpecialFunctions = "1.8.2, 2.1.2" diff --git a/docs/src/models/overview.md b/docs/src/models/overview.md index c6d52ae083..5700e02b25 100644 --- a/docs/src/models/overview.md +++ b/docs/src/models/overview.md @@ -77,13 +77,15 @@ julia> predict(x_train) In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions. ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> loss(x, y) = Flux.Losses.mse(predict(x), y); +julia> using Statistics -julia> loss(x_train, y_train) +julia> loss(model, x, y) = mean(abs2.(model(x) .- y)); + +julia> loss(predict, x_train, y_train) 122.64734f0 ``` -More accurate predictions will yield a lower loss. You can write your own loss functions or rely on those already provided by Flux. This loss function is called [mean squared error](https://www.statisticshowto.com/probability-and-statistics/statistics-definitions/mean-squared-error/). Flux works by iteratively reducing the loss through *training*. +More accurate predictions will yield a lower loss. You can write your own loss functions or rely on those already provided by Flux. This loss function is called [mean squared error](https://www.statisticshowto.com/probability-and-statistics/statistics-definitions/mean-squared-error/) (and built-in as [`mse`](@ref Flux.Losses.mse)). Flux works by iteratively reducing the loss through *training*. ## 3. Improve the Prediction @@ -112,40 +114,28 @@ julia> predict.bias 0.0 ``` -The dimensions of these model parameters depend on the number of inputs and outputs. Since models can have hundreds of inputs and several layers, it helps to have a function to collect the parameters into the data structure Flux expects: - -```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> parameters = Flux.params(predict) -Params([Float32[0.9066542], Float32[0.0]]) -``` - -These are the parameters Flux will change, one step at a time, to improve predictions. At each step, the contents of this `Params` object changes too, since it is just a collection of references to the mutable arrays inside the model: - -```jldoctest overview -julia> predict.weight in parameters, predict.bias in parameters -(true, true) -``` +The dimensions of these model parameters depend on the number of inputs and outputs. -The first parameter is the weight and the second is the bias. Flux will adjust predictions by iteratively changing these parameters according to the optimizer. +Flux will adjust predictions by iteratively changing these parameters according to the optimizer. This optimiser implements the classic gradient descent strategy. Now improve the parameters of the model with a call to [`Flux.train!`](@ref) like this: ```jldoctest overview -julia> train!(loss, parameters, data, opt) +julia> train!(loss, predict, data, opt) ``` And check the loss: ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> loss(x_train, y_train) +julia> loss(predict, x_train, y_train) 116.38745f0 ``` It went down. Why? ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> parameters -Params([Float32[7.5777884], Float32[1.9466728]]) +julia> predict.weight, predict.bias +(Float32[7.5777884], Float32[1.9466728]) ``` The parameters have changed. This single step is the essence of machine learning. @@ -156,14 +146,14 @@ In the previous section, we made a single call to `train!` which iterates over t ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" julia> for epoch in 1:200 - train!(loss, parameters, data, opt) + train!(loss, predict, data, opt) end -julia> loss(x_train, y_train) +julia> loss(predict, x_train, y_train) 0.00339581f0 -julia> parameters -Params([Float32[4.0178537], Float32[2.0050256]]) +julia> predict.weight, predict.bias +(Float32[4.0178537], Float32[2.0050256]) ``` After 200 training steps, the loss went down, and the parameters are getting close to those in the function the model is built to predict. @@ -188,7 +178,7 @@ First, we gathered real-world data into the variables `x_train`, `y_train`, `x_t Then, we built a single input, single output predictive model, `predict = Dense(1 => 1)`. The initial predictions weren't accurate, because we had not trained the model yet. -After building the model, we trained it with `train!(loss, parameters, data, opt)`. The loss function is first, followed by the `parameters` holding the weights and biases of the model, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process. +After building the model, we trained it with `train!(loss, predict, data, opt)`. The loss function is first, followed by the model itself, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process. After we trained the model, we verified it with the test data to verify the results. diff --git a/src/Flux.jl b/src/Flux.jl index fcb473ba2c..80eb87df1e 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -34,6 +34,10 @@ export Descent, Adam, Momentum, Nesterov, RMSProp, AdamW, RAdam, AdaBelief, InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm +include("train.jl") +using .Train +# using .Train: setup, @train_autodiff + using CUDA const use_cuda = Ref{Union{Nothing,Bool}}(nothing) diff --git a/src/deprecations.jl b/src/deprecations.jl index 8c3bc963a4..1cac2a4b86 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -82,3 +82,105 @@ Base.@deprecate_binding ADAGrad AdaGrad Base.@deprecate_binding ADADelta AdaDelta @deprecate rng_from_array() default_rng_value() + +#= + # Valid method in Optimise, old implicit style, is: + train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ()) + + # Valid methods in Train, new explict style, are: + train!(loss, model, data, opt) # preferred + train!(loss, model, data, opt::Optimisers.AbstractRule) # if you forget setup + + # Provide friendly errors for what happens if you mix these up: +=# +import .Optimise: train! + +train!(loss, ps::Params, data, opt) = error( + """can't mix implict Params with explict state! + To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module. + But better to use the new explicit style, in which `m` itself is the 2nd argument. + """) + +train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error( + """can't mix implict Params with explict rule from Optimisers.jl + To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module. + But better to use the new explicit style, in which `m` itself is the 2nd argument. + """) + +train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt)) + +# Next, to use the new `setup` with the still-exported old-style `Adam` etc: +import .Train: setup +setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model) +# ... and allow accidental use of `Optimisers.setup` to do the same: +Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model) + +for T in [:Descent, :Adam, :Momentum, :Nesterov, + :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief, + # :InvDecay, :ExpDecay, + ] + @eval function _old_to_new(rule::$T) + args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T)) + Optimisers.$T(args...) + end +end +_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...) +const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too. +_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now +_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thesh) # called omega, and there are more fields +_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thesh) # called delta now, and struct name differs +const ClipGrad = Optimise.ClipValue +_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred + +_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule") + +# Since `update!` should be called in a loop, it makes less sense to call `setup` for you if you forgot. +# But let's make sure that such uses give a helpful error: +import .Optimise: update! + +function update!(opt::Optimise.AbstractOptimiser, model, grad) + # This error method requires narrowing the main worker method of Flux.Optimise + # to accept only arrays. Remove if this causes problems! + # update!(opt::Flux.Optimise.AbstractOptimiser, x::AbstractArray, x̄) + error("""Invalid input to `update!`. + * For the implicit style, this needs `update(::AbstractOptimiser, ::Params, ::Grads)` + * For the explicit style, `update(state, model, grad)` needs `state = Flux.setup(opt, model)`. + """) +end + +# An easy error to make is to pass result of explicit gradient(...), not gradient(...)[1] +# Can't catch every case, but can catch many simple Flux models: + +function update!(opt, model::Chain, grads::Tuple) + # Zygote will make a NamedTuple{(:layers,)} for the gradient of Chain, Diffractor a Tangent + @warn """explicit `update!(opt, model, grad)` wants the gradient for the model alone, + not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`.""" + update!(opt, model, grads[1]) +end + +function update!(opt::Optimise.AbstractOptimiser, model::Chain, grads::Tuple) # ambiguity + update!(opt, model, grads[1]) # calls error case "Invalid input" just above +end + +# One more easy error to catch is using explicit gradient with `params(m)`: + +function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple, NamedTuple}) + error("""can't mix implicit Params with explicit gradients! + * For the implicit style, this needs `update(::AbstractOptimiser, ::Params, ::Grads)` with implicit gradient. + * For the explicit style, `update(state, model, grad)` needs the model itself, and `state = Flux.setup(opt, model)`. + """) +end + +# v0.14 deprecations + +# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc: +# Base.@deprecate_binding Optimiser OptimiserChain +# Base.@deprecate_binding ClipValue ClipGrad + +# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( +# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`. +# Instead of `train!(loss_xy, Flux.params(model), data, Adam())` +# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)` +# where `loss_mxy` accepts the model as its first argument. +# """ +# )) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index e691ce0170..48f660ffdb 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,6 @@ module Optimise using LinearAlgebra -import ArrayInterface export train!, update!, Descent, Adam, Momentum, Nesterov, RMSProp, diff --git a/src/optimise/train.jl b/src/optimise/train.jl index a1c3e9a7aa..d0de78e01a 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,20 +1,27 @@ using ProgressLogging: @progress, @withprogress, @logprogress import Zygote: Params, gradient, withgradient +# Add methods to Optimisers.jl's function, so that there is just one Flux.update! +# for both explicit and implicit parameters. +import Optimisers.update! """ update!(opt, p, g) update!(opt, ps::Params, gs) Perform an update step of the parameters `ps` (or the single parameter `p`) -according to optimizer `opt` and the gradients `gs` (the gradient `g`). +according to optimizer `opt::AbstractOptimiser` and the gradients `gs` (the gradient `g`). As a result, the parameters are mutated and the optimizer's internal state may change. The gradient could be mutated as well. + +!!! note + This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14. + The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain. """ -function update!(opt::AbstractOptimiser, x, x̄) - x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's - # output are not mutable, see #1510 +function update!(opt::AbstractOptimiser, x::AbstractArray, x̄) + x̄r = copyto!(similar(x̄), x̄) # Flux.Optimise assumes it can mutate the gradient. This is not + # safe due to aliasing, nor guaranteed to be possible, e.g. Fill. x .-= apply!(opt, x, x̄r) end @@ -88,6 +95,10 @@ batchmemaybe(x::Tuple) = x Uses a `loss` function and training `data` to improve the model's parameters according to a particular optimisation rule `opt`. +!!! note + This method with implicit `Params` will be removed from Flux 0.14. + It should be replaced with the explicit method `train!(loss, model, data, opt)`. + For each `d in data`, first the gradient of the `loss` is computed like this: ``` gradient(() -> loss(d...), pars) # if d isa Tuple diff --git a/src/train.jl b/src/train.jl new file mode 100644 index 0000000000..919821b710 --- /dev/null +++ b/src/train.jl @@ -0,0 +1,131 @@ +module Train + +using LinearAlgebra +using Optimisers: Optimisers +using Functors: fmap + +import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions + +export setup, train! + +using ProgressLogging: @progress, @withprogress, @logprogress +using Zygote: Zygote, Params + +""" + opt = setup(rule, model) + +This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!). +It differs from `Optimisers.setup` in that it: +* has one extra check for mutability (since Flux expects to mutate the model in-place, + while Optimisers.jl is designed to return an updated model) +* has methods which accept Flux's old optimisers, and convert them. + (The old `Flux.Optimise.Adam` and new `Optimisers.Adam` are distinct types.) + +# Example +```jldoctest +julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32); + +julia> opt = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state +(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ()) + +julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps: + +julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y + sum(abs.(m(x) .- y)) * 100 + end + +julia> model.bias # was zero, mutated by Flux.train! +1-element Vector{Float32}: + 10.190001 + +julia> opt # mutated by Flux.train! +(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ()) +``` +""" +function setup(rule::Optimisers.AbstractRule, model) + state = Optimisers.setup(rule, model) + fmap(model, exclude = Optimisers.isnumeric) do x + Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`. + If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""") + end + state +end + +""" + train!(loss, model, data, opt) + +Uses a `loss` function and training `data` to improve the `model`'s parameters +according to a particular optimisation rule `opt`. Iterates through `data` once, +evaluating `loss(model, d...)` for each `d` in data. + +For example, with these definitions... +``` +data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple + +loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument + +opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta +``` +...calling `Flux.train!(loss3, model, data, opt)` runs a loop much like this, +using Zygote's "explicit" mode for the gradient: +``` +for d in data + ∂L∂m = gradient(loss3, model, d...)[1] + update!(opt, model, ∂L∂m) # method for "explicit" gradient +end +``` +You can also write this loop yourself, if you need more flexibility. +For this reason `train!` is not highly extensible. +It adds only a few featurs to the loop above: + +* Stop with a `DomainError` if the loss is infinite or `NaN` at any point. + +* Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl). + +!!! note + This method has significant changes from the one in Flux ≤ 0.13: + * It now takes the `model` itself, not the result of [`Flux.params`](@ref). + (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) + * Instead of `loss` being a function which accepts only the data, + now it must also accept the `model` itself, as the first argument. + * `data` must iterate tuples, otherwise you get an error. + (Previously non-tuple types were not splatted into the loss. + Pass in `((d,) for d in data)` to simulate this.) + * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser + such as `Adam()` without this step should give you a warning. + * Callback functions are not supported. + But any code can be included in the above `for` loop. +""" +function train!(loss, model, data, opt; cb = nothing) + isnothing(cb) || error("""train! does not support callback functions. + For more control use a loop with `gradient` and `update!`.""") + @withprogress for (i,d) in enumerate(data) + d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)). + Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""") + l, gs = Zygote.withgradient(m -> loss(m, d...), model) + if !isfinite(l) + throw(DomainError("Loss is $l on data item $i, stopping training")) + end + opt, model = Optimisers.update!(opt, model, gs[1]) + @logprogress Base.haslength(data) ? i/length(data) : nothing + end +end + +# This method let you use Optimisers.Descent() without setup, when there is no state +function train!(loss, model, data, rule::Optimisers.AbstractRule) + train!(loss, model, data, _rule_to_state(model, rule)) +end + +function _rule_to_state(model, rule::Optimisers.AbstractRule) + state = setup(rule, model) + @gensym warn_id + name = typeof(rule).name.name + fmap(state, exclude = x -> x isa Optimisers.Leaf) do leaf + leaf.state isa Nothing || @warn """Optimiser $name has state which will be discarded after `train!` finishes. + Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id + leaf + end + state +end + +end # module Train diff --git a/test/runtests.jl b/test/runtests.jl index 9027b114fc..29b2bad311 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,8 +16,9 @@ Random.seed!(0) include("utils.jl") end - @testset "Optimise" begin + @testset "Optimise / Train" begin include("optimise.jl") + include("train.jl") end @testset "Data" begin diff --git a/test/train.jl b/test/train.jl new file mode 100644 index 0000000000..49ecf9c751 --- /dev/null +++ b/test/train.jl @@ -0,0 +1,93 @@ +using Flux +# using Flux.Train +import Optimisers + +using Test +using Random + +@testset "Explicit Flux.train! with Zygote" begin + Random.seed!(84) + w = randn(10, 10) + w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. + @testset for rule in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(), + NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(), + Nesterov(), RMSProp(), Momentum()] + + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + + opt = Flux.setup(rule, model) + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end + + # Test direct use of Optimisers.jl rule, only really OK for `Descent`: + @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()] + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end +end + +@testset "Explicit Flux.train! features" begin + @testset "Stop on NaN" begin + m1 = Dense(1 => 1) + m1.weight .= 0 + CNT = 0 + @test_throws DomainError Flux.train!(m1, tuple.(1:100), Descent(0.1)) do m, i + CNT += 1 + (i == 51 ? NaN32 : 1f0) * sum(m([1.0])) + end + @test CNT == 51 # stopped early + @test m1.weight[1] ≈ -5 # did not corrupt weights + end + @testset "data must give tuples" begin + m1 = Dense(1 => 1) + @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(x=1, y=2) for _ in 1:3], Descent(0.1)) + end + @testset "callbacks give helpful error" begin + m1 = Dense(1 => 1) + cb = () -> println("this should not be printed") + @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) + end +end + +@testset "Explicit Flux.update! features" begin + m = Chain(Dense(2=>3, tanh), Dense(3=>1), only) + x = rand(2) + y1 = m(x) # before + + # Implicit gradient + gold = gradient(() -> m(x), Flux.params(m)) + @test gold isa Flux.Zygote.Grads + @test_throws ErrorException Flux.update!(Flux.Adam(), m, gold) # friendly + Flux.update!(Flux.Adam(), Flux.params(m), gold) + y2 = m(x) + @test y2 < y1 + + # Explicit gradient + gs = gradient(marg -> marg(x), m) + @test gs isa Tuple + @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs) # friendly + @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1]) # friendly + @test_throws ErrorException Flux.update!(Flux.Adam(), m, gs) # friendly + @test_throws ErrorException Flux.update!(Flux.Adam(), m, gs[1]) # friendly + s = Flux.setup(Adam(), m) + @info "ignore this warning, just testing an upgrade path:" + Flux.update!(s, m, gs) # Chain + Tuple can be unambiguously sorted out + y3 = m(x) + @test y3 < y2 + Flux.update!(s, m, gs[1]) # finally, this is the correct thing + y4 = m(x) + @test y4 < y3 + + # Also check that if you import the new Adam, then Flux.setup does still work! + s2 = Flux.setup(Optimisers.Adam(), m) + Flux.update!(s2, m, gs[1]) + y5 = m(x) + @test y5 < y4 +end +