From c20fb9eae8ef5891c8556c2da189580ff3c0ba40 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 13 Oct 2022 09:27:16 -0400 Subject: [PATCH 01/18] explicit train, take 2 --- NEWS.md | 1 + Project.toml | 8 +- docs/src/models/overview.md | 12 +- src/Flux.jl | 4 + src/deprecations.jl | 62 ++++++++++ src/optimise/train.jl | 13 +- src/train.jl | 233 ++++++++++++++++++++++++++++++++++++ test/runtests.jl | 3 +- test/train.jl | 131 ++++++++++++++++++++ 9 files changed, 457 insertions(+), 10 deletions(-) create mode 100644 src/train.jl create mode 100644 test/train.jl diff --git a/NEWS.md b/NEWS.md index d83e76d62a..032819e1d1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,7 @@ ## v0.13.7 * Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078) +* New method of `train!` using Zygote's "explicit" mode, allows changing AD back-end. ## v0.13.4 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983) diff --git a/Project.toml b/Project.toml index 1ecb639db4..e33a4e3dba 100644 --- a/Project.toml +++ b/Project.toml @@ -34,11 +34,13 @@ MacroTools = "0.5" NNlib = "0.8.9" NNlibCUDA = "0.2.4" OneHotArrays = "0.1, 0.2" -Optimisers = "0.2.1" +Optimisers = "0.2.10" ProgressLogging = "0.1" Reexport = "0.2, 1.0" SpecialFunctions = "1.8.2, 2.1.2" StatsBase = "0.33" +Tracker = "0.2.22" +Yota = "0.8.1" Zygote = "0.6.34" julia = "1.6" @@ -48,7 +50,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Yota = "cd998857-8626-517d-b929-70ad188a48f0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker", "Yota"] diff --git a/docs/src/models/overview.md b/docs/src/models/overview.md index c6d52ae083..630da338cb 100644 --- a/docs/src/models/overview.md +++ b/docs/src/models/overview.md @@ -77,9 +77,9 @@ julia> predict(x_train) In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions. ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> loss(x, y) = Flux.Losses.mse(predict(x), y); +julia> loss(model, x, y) = mean(abs2.(model(x) .- y)); -julia> loss(x_train, y_train) +julia> loss(predict, x_train, y_train) 122.64734f0 ``` @@ -131,7 +131,7 @@ The first parameter is the weight and the second is the bias. Flux will adjust p This optimiser implements the classic gradient descent strategy. Now improve the parameters of the model with a call to [`Flux.train!`](@ref) like this: ```jldoctest overview -julia> train!(loss, parameters, data, opt) +julia> train!(loss, predict, data, opt) ``` And check the loss: @@ -156,10 +156,10 @@ In the previous section, we made a single call to `train!` which iterates over t ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" 
julia> for epoch in 1:200 - train!(loss, parameters, data, opt) + train!(loss, predict, data, opt) end -julia> loss(x_train, y_train) +julia> loss(predict, x_train, y_train) 0.00339581f0 julia> parameters @@ -188,7 +188,7 @@ First, we gathered real-world data into the variables `x_train`, `y_train`, `x_t Then, we built a single input, single output predictive model, `predict = Dense(1 => 1)`. The initial predictions weren't accurate, because we had not trained the model yet. -After building the model, we trained it with `train!(loss, parameters, data, opt)`. The loss function is first, followed by the `parameters` holding the weights and biases of the model, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process. +After building the model, we trained it with `train!(loss, predict, data, opt)`. The loss function is first, followed by the model itself, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process. After we trained the model, we verified it with the test data to verify the results. diff --git a/src/Flux.jl b/src/Flux.jl index fcb473ba2c..80eb87df1e 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -34,6 +34,10 @@ export Descent, Adam, Momentum, Nesterov, RMSProp, AdamW, RAdam, AdaBelief, InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm +include("train.jl") +using .Train +# using .Train: setup, @train_autodiff + using CUDA const use_cuda = Ref{Union{Nothing,Bool}}(nothing) diff --git a/src/deprecations.jl b/src/deprecations.jl index 8c3bc963a4..c878c5192b 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -82,3 +82,65 @@ Base.@deprecate_binding ADAGrad AdaGrad Base.@deprecate_binding ADADelta AdaDelta @deprecate rng_from_array() default_rng_value() + +#= + # Valid method in Optimise, old implicit style, is: + train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ()) + # Valid methods in Train, new explict style, are: + train!(loss, model, data, opt) + train!(loss, model, data, opt::Optimisers.AbstractRule) + # ... and 3-arg: + train!(loss, model, opt) + train!(loss, model, opt::Optimisers.AbstractRule) + # Provide friendly errors for what happens if you mix these up: +=# +import .Optimise: train! +train!(loss, ps::Params, data, opt) = error("can't mix implict Params with explict state") +train!(loss, ps::Params, opt) = error("can't mix implict Params with explict state") + +train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule") +train!(loss, ps::Params, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule") + +train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt)) +train!(loss, model, opt::Optimise.AbstractOptimiser) = train!(loss, model, _old_to_new(opt)) + +train!(loss, ps::Params, opt::Optimise.AbstractOptimiser; cb=0) = error("3-arg train does not exist for implicit mode") + +# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( +# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`. 
+# Instead of `train!(loss_xy, Flux.params(model), data, Adam())` +# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)` +# where `loss_mxy` accepts the model as its first argument. +# """ +# )) + +# Next, to use the new `setup` with the still-exported old-style Adam etc: +import .Train: setup +setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model) + +for T in [:Descent, :Adam, :Momentum, :Nesterov, + :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief, + # :InvDecay, :ExpDecay, + ] + @eval function _old_to_new(rule::$T) + args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T)) + Optimisers.$T(args...) + end +end +_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...) +const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too. +_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now +_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thesh) # called omega, and there are more fields +_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thesh) # called delta now, and struct name differs +const ClipGrad = Optimise.ClipValue +_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred + +_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule") + +Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = error("please use Flux.setup not Optimisers.setup, it may be able to translate this rule") + +# v0.14 deprecations + +# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc: +# Base.@deprecate_binding Optimiser OptimiserChain +# Base.@deprecate_binding ClipValue ClipGrad diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 158fd8b585..74987f2a4e 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,16 +1,23 @@ using ProgressLogging: @progress, @withprogress, @logprogress import Zygote: Params, gradient, withgradient +# Add methods to Optimisers.jl's function, so that there is just one Flux.update! +# for both explicit and implicit parameters. +import Optimisers.update! """ update!(opt, p, g) update!(opt, ps::Params, gs) Perform an update step of the parameters `ps` (or the single parameter `p`) -according to optimizer `opt` and the gradients `gs` (the gradient `g`). +according to optimizer `opt::AbstractOptimiser` and the gradients `gs` (the gradient `g`). As a result, the parameters are mutated and the optimizer's internal state may change. The gradient could be mutated as well. + +!!! note + This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14. + The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain. """ function update!(opt::AbstractOptimiser, x, x̄) x̄r = copyto!(similar(x̄), x̄) # Flux.Optimise assumes it can mutate the gradient. This is not @@ -88,6 +95,10 @@ batchmemaybe(x::Tuple) = x Uses a `loss` function and training `data` to improve the model's parameters according to a particular optimisation rule `opt`. +!!! note + This method with implicit `Params` will be removed from Flux 0.14. + It should be replaced with the explicit method `train!(loss, model, data, opt)`. 
+ For each `d in data`, first the gradient of the `loss` is computed like this: ``` gradient(() -> loss(d...), pars) # if d isa Tuple diff --git a/src/train.jl b/src/train.jl new file mode 100644 index 0000000000..884c3578c0 --- /dev/null +++ b/src/train.jl @@ -0,0 +1,233 @@ +module Train + +using LinearAlgebra +using Optimisers: Optimisers +using Functors: fmap + +import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions + +export setup, @train_autodiff + +using ProgressLogging: @progress, @withprogress, @logprogress # TODO add progress logging again +using Zygote: Zygote, Params + +""" + opt = setup(rule, model) + +This is a version of `Optimisers.setup`, and is the first step before using `train!`. +It differs from `Optimisers.setup` in that it: +* has one extra check for mutability +* has methods which accept Flux's old optimisers, and convert them. + +```jldoctest +julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32); + +julia> opt = Flux.setup(Momentum(0.11), model) +(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0]), σ = ()) + +julia> Flux.train!(model, opt) do m # 3-arg train!, for one data point (x = [0.2, -0.3], y = [0.4]) + sum(m([0.2, -0.3]) .- [0.4]) * 100 + end +-40.1 + +julia> model.bias # was zero, mutated by Flux.train! +1-element Vector{Float32}: + -0.11 + +julia> opt # mutated by Flux.train! +(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.022 -0.033]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.11]), σ = ()) +``` +""" +function setup(rule::Optimisers.AbstractRule, model) + state = Optimisers.setup(rule, model) + fmap(model, exclude = Optimisers.isnumeric) do x + Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`. + If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""") + end + state +end + +# opt = Flux.setup(Adam(), model); train!(model, opt) do m ... +setup(model, rule::Optimisers.AbstractRule) = setup(rule, model) + +""" + train!(loss, model, data, opt) + +Uses a `loss` function and training `data` to improve the `model`'s parameters +according to a particular optimisation rule `opt`. + +!!! note + This method has significant changes from the one in Flux ≤ 0.13: + * It now takes the `model` itself, not the result of [`Flux.params`](@ref). + (This is to move away from Zygote's implicit parameter handling.) + * Instead of `loss` being a function which typically accepts two arguments + (the input `x` and expected output `y` from each element of `data`) + now it should typically accept three, the first of which is the `model` itself. + * `data` should iterate tuples or NamedTuples + * `opt` should be the result of [`Flux.setup`](@ref). + * Callback functions are not supported. + +For example, with these definitions... +``` +data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple (or NamedTuple) + +loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument + +opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta +``` +...calling `train!(loss3, model, data, opt)` runs a loop much like this: +``` +for d in data + ∂L∂m = Zygote.gradient(loss3, model, d...)[1] + Optimisers.update!(opt, model, ∂L∂m) +end +``` +Stops with a `DomainError` if the loss is infinite or `NaN` at any point. + +Returns a vector containing the value of the loss function at each datapoint. 
+ +The built-in loss functions accept 3 arguments, allowing for instance `train!(Flux.Losses.mse, model, data, opt)`. + +Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an +easy way to construct more complicated training loops. + +To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref). +""" +function train!(loss, model, data, opt) + losses = Float32[] + @withprogress for (i,d) in enumerate(data) + l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...) + isfinite(l) || throw(DomainError("loss function returned $l, stopping training")) + opt, model = Optimisers.update!(opt, model, g) + push!(losses, l) + @logprogress Base.haslength(data) ? i/length(data) : nothing + end + return losses # Not entirely sure returning losses is a good idea +end + +data_splat(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T + To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""") +data_splat(x::Tuple) = x +data_splat(x::NamedTuple) = x +data_splat(x::AbstractArray{<:Number}) = (x,) + +""" + train!(loss, model, opt) + +Uses a `loss` function improve the `model`'s parameters. + +While the 4-argument method of `train!` iterates over a dataset, +this 3-argument method is for a single datapoint, and calls `gradient` just once. +It expects a function `loss` which takes just one argument, the model. +For example: +``` +opt = Flux.setup(Adam(), model) # explicit setup +train!(model, opt) do m # the model is passed to the function as `m` + Flux.crossentropy(m(x1), y1) # but the data point `(x1, y1)` is closed over. +end +``` +This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`. +(The `do` block is another syntax for this anonymous function.) +Then it updates the parameters contained within `model` according to `opt`. +Finally it returns the value of the loss function. + +To iterate over a dataset, writing a loop allows more control than +calling 4-argument `train!`. For example, this adds printing and an early stop: +``` +data = Flux.DataLoader((Xtrain, Ytrain), batchsize=32) +opt = Flux.setup(Adam(), model) +for (i, d) in enumerate(data) + x, y = d + ell = Flux.train!(m -> Flux.crossentropy(m(x), y), model, opt) + i%10==0 && println("on step \$i, the loss was \$ell") # prints every 10th step + ell<0.1 && break # stops training +end +``` + +To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref). + +!!! note + This method has no implicit `Params` analog in Flux ≤ 0.13. +""" +function train!(loss, model, opt) + l, (g, _...) = explicit_withgradient(loss, model) + isfinite(l) || return l + _, model = Optimisers.update!(opt, model, g) + return l +end + +# These methods let you use Optimisers.Descent() without setup, when there is no state +function train!(loss, model, data, rule::Optimisers.AbstractRule) + train!(loss, model, data, _rule_to_state(model, rule)) +end +function train!(loss, model, rule::Optimisers.AbstractRule) + train!(loss, model, _rule_to_state(model, rule)) +end + +function _rule_to_state(model, rule::Optimisers.AbstractRule) + state = setup(rule, model) + @gensym warn_id + name = typeof(rule).name.name + fmap(state, exclude = x -> x isa Optimisers.Leaf) do leaf + leaf.state isa Nothing || @warn """Optimiser $name has state which will be discarded after `train!` finishes. 
+ Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id + leaf + end + state +end + +explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor + +""" + Flux.@train_autodiff Tracker + Flux.@train_autodiff Zygote + Flux.@train_autodiff Yota + Flux.@train_autodiff Diffractor + +This macro allows the use of `train!` with various automatic differentiation (AD) packages, +instead of the default Zygote.jl. + +You should load AD package, and then call this macro with the chosen name. +The macro overwrites a method withing Flux, thus is a global setting, lasting until you re-start Julia. + +Only works with [Yota.jl](https://github.com/dfdx/Yota.jl), +[Tracker.jl](https://github.com/FluxML/Tracker.jl) (Flux's old AD), +[Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (which is not yet registered), +and with the default [Zygote.jl](https://github.com/FluxML/Zygote.jl). + +!!! note + This is mechanism is experimental! And there are known bugs, in particular Tracker will not automatically switch to training mode for `Dropout` etc. +""" +macro train_autodiff(pkg) + if pkg == :Diffractor + return quote + Diffractor.gradient(sin, 0.0)[1] ≈ 1.0 # ensures an error if not loaded + function Flux.Train.explicit_withgradient(f, args...) + y, back = Diffractor.∂⃖¹(f, args...) + dy1 = Flux.Zygote.sensitivity(y) # Zygote is loaded, and this gives nice errors + return (; value = y, gradient = Base.tail(back(dy1))) + end + end |> esc + elseif pkg == :Yota + return quote + Yota.grad(sin, 0.0) # [2][1] ≈ 1.0 + function Flux.Train.explicit_withgradient(f, args...) + value, (_, gradient...) = Yota.grad(f, args...) + return (; value, gradient) + end + end |> esc + elseif pkg == :Tracker + return quote + Tracker.withgradient(sum, [1.0]).val == 1.0 # ensures an error if too-old version + Flux.Train.explicit_withgradient(f, args...) = Tracker.withgradient(f, args...) + end |> esc + elseif pkg == :Zygote + return quote + Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...) + end |> esc + else + throw("@train_autodiff expects one of Tracker, Zygote, Yota, or Diffractor. No other arguments are understood.") + end +end + +end # module diff --git a/test/runtests.jl b/test/runtests.jl index 9027b114fc..29b2bad311 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,8 +16,9 @@ Random.seed!(0) include("utils.jl") end - @testset "Optimise" begin + @testset "Optimise / Train" begin include("optimise.jl") + include("train.jl") end @testset "Data" begin diff --git a/test/train.jl b/test/train.jl new file mode 100644 index 0000000000..81ffa2f3db --- /dev/null +++ b/test/train.jl @@ -0,0 +1,131 @@ +using Flux +# using Flux.Train +import Optimisers + +using Test +using Random + +@testset "Explicit Flux.train! with Zygote" begin + Random.seed!(84) + w = randn(10, 10) + w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. 
+ @testset for rule in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(), + NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(), + Nesterov(), RMSProp(), Momentum()] + + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + + opt = Flux.setup(rule, model) + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end + + # Test 3-arg `Flux.train!` method: + @testset for rule in [Descent(0.1), Adam(), AdamW()] + + loss(m) = let x = rand(10) + Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + end + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model) > 1 + + opt = Flux.setup(rule, model) + for i in 1:10^5 + Flux.train!(loss, model, opt) + end + @test loss(model) < 0.01 + end + + # Test direct use of Optimisers.jl rule, only really OK for `Descent`: + @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()] + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end +end + +@testset "Explicit Flux.train! features" begin + # Test that splat accepts NamedTuple + # Test NaN / Inf early stop + # Test that loss is returned +end + +import Tracker +Flux.@train_autodiff Tracker + +@testset "Explicit Flux.train! with Tracker" begin + Random.seed!(84) + w = randn(10, 10) + w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. + @testset for rule in [Descent(0.1), Adam(), AdamW()] + + loss(m, x) = begin + Flux.istraining() && error("This test is not in fact using Tracker!") + Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + end + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + + opt = Flux.setup(rule, model) + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end + + # Test 3-arg `Flux.train!` method: + @testset for rule in [Descent(0.1), Adam()] + + loss(m) = let x = rand(10) + Flux.istraining() && error("This test is not in fact using Tracker!") + Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + end + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model) > 1 + + opt = Flux.setup(rule, model) + for i in 1:10^5 + Flux.train!(loss, model, opt) + end + @test loss(model) < 0.01 + end +end + +import Yota +Flux.@train_autodiff Yota + +@testset "Explicit Flux.train! with Yota" begin + Random.seed!(84) + w = randn(10, 10) + w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. 
+ @testset for rule in [Descent(0.1), Adam(), AdamW()] + + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model, rand(10, 10)) > 1 + + opt = Flux.setup(rule, model) + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end + + # Test 3-arg `Flux.train!` method: + @testset for rule in [Descent(0.1), Adam()] + + loss(m) = let x = rand(10) + Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + end + model = (weight=copy(w2), bias=zeros(10), ignore=nothing) + @test loss(model) > 1 + + opt = Flux.setup(rule, model) + for i in 1:10^5 + Flux.train!(loss, model, opt) + end + @test loss(model) < 0.01 + end +end + +Flux.@train_autodiff Zygote From 9c22c11a47ddc88eb616ed243b1247b16e23b82c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 16 Oct 2022 13:43:24 -0400 Subject: [PATCH 02/18] remove train_autodiff macro --- src/train.jl | 56 ------------------------------------- test/train.jl | 76 --------------------------------------------------- 2 files changed, 132 deletions(-) diff --git a/src/train.jl b/src/train.jl index 884c3578c0..0a7433ac6d 100644 --- a/src/train.jl +++ b/src/train.jl @@ -90,8 +90,6 @@ The built-in loss functions accept 3 arguments, allowing for instance `train!(Fl Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an easy way to construct more complicated training loops. - -To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref). """ function train!(loss, model, data, opt) losses = Float32[] @@ -144,8 +142,6 @@ for (i, d) in enumerate(data) end ``` -To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref). - !!! note This method has no implicit `Params` analog in Flux ≤ 0.13. """ @@ -178,56 +174,4 @@ end explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor -""" - Flux.@train_autodiff Tracker - Flux.@train_autodiff Zygote - Flux.@train_autodiff Yota - Flux.@train_autodiff Diffractor - -This macro allows the use of `train!` with various automatic differentiation (AD) packages, -instead of the default Zygote.jl. - -You should load AD package, and then call this macro with the chosen name. -The macro overwrites a method withing Flux, thus is a global setting, lasting until you re-start Julia. - -Only works with [Yota.jl](https://github.com/dfdx/Yota.jl), -[Tracker.jl](https://github.com/FluxML/Tracker.jl) (Flux's old AD), -[Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (which is not yet registered), -and with the default [Zygote.jl](https://github.com/FluxML/Zygote.jl). - -!!! note - This is mechanism is experimental! And there are known bugs, in particular Tracker will not automatically switch to training mode for `Dropout` etc. -""" -macro train_autodiff(pkg) - if pkg == :Diffractor - return quote - Diffractor.gradient(sin, 0.0)[1] ≈ 1.0 # ensures an error if not loaded - function Flux.Train.explicit_withgradient(f, args...) - y, back = Diffractor.∂⃖¹(f, args...) - dy1 = Flux.Zygote.sensitivity(y) # Zygote is loaded, and this gives nice errors - return (; value = y, gradient = Base.tail(back(dy1))) - end - end |> esc - elseif pkg == :Yota - return quote - Yota.grad(sin, 0.0) # [2][1] ≈ 1.0 - function Flux.Train.explicit_withgradient(f, args...) - value, (_, gradient...) = Yota.grad(f, args...) 
- return (; value, gradient) - end - end |> esc - elseif pkg == :Tracker - return quote - Tracker.withgradient(sum, [1.0]).val == 1.0 # ensures an error if too-old version - Flux.Train.explicit_withgradient(f, args...) = Tracker.withgradient(f, args...) - end |> esc - elseif pkg == :Zygote - return quote - Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...) - end |> esc - else - throw("@train_autodiff expects one of Tracker, Zygote, Yota, or Diffractor. No other arguments are understood.") - end -end - end # module diff --git a/test/train.jl b/test/train.jl index 81ffa2f3db..443a39dd75 100644 --- a/test/train.jl +++ b/test/train.jl @@ -53,79 +53,3 @@ end # Test NaN / Inf early stop # Test that loss is returned end - -import Tracker -Flux.@train_autodiff Tracker - -@testset "Explicit Flux.train! with Tracker" begin - Random.seed!(84) - w = randn(10, 10) - w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. - @testset for rule in [Descent(0.1), Adam(), AdamW()] - - loss(m, x) = begin - Flux.istraining() && error("This test is not in fact using Tracker!") - Flux.Losses.mse(w*x, m.weight*x .+ m.bias) - end - model = (weight=copy(w2), bias=zeros(10), ignore=nothing) - @test loss(model, rand(10, 10)) > 1 - - opt = Flux.setup(rule, model) - Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) - @test loss(model, rand(10, 10)) < 0.01 - end - - # Test 3-arg `Flux.train!` method: - @testset for rule in [Descent(0.1), Adam()] - - loss(m) = let x = rand(10) - Flux.istraining() && error("This test is not in fact using Tracker!") - Flux.Losses.mse(w*x, m.weight*x .+ m.bias) - end - model = (weight=copy(w2), bias=zeros(10), ignore=nothing) - @test loss(model) > 1 - - opt = Flux.setup(rule, model) - for i in 1:10^5 - Flux.train!(loss, model, opt) - end - @test loss(model) < 0.01 - end -end - -import Yota -Flux.@train_autodiff Yota - -@testset "Explicit Flux.train! with Yota" begin - Random.seed!(84) - w = randn(10, 10) - w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. - @testset for rule in [Descent(0.1), Adam(), AdamW()] - - loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) - model = (weight=copy(w2), bias=zeros(10), ignore=nothing) - @test loss(model, rand(10, 10)) > 1 - - opt = Flux.setup(rule, model) - Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) - @test loss(model, rand(10, 10)) < 0.01 - end - - # Test 3-arg `Flux.train!` method: - @testset for rule in [Descent(0.1), Adam()] - - loss(m) = let x = rand(10) - Flux.Losses.mse(w*x, m.weight*x .+ m.bias) - end - model = (weight=copy(w2), bias=zeros(10), ignore=nothing) - @test loss(model) > 1 - - opt = Flux.setup(rule, model) - for i in 1:10^5 - Flux.train!(loss, model, opt) - end - @test loss(model) < 0.01 - end -end - -Flux.@train_autodiff Zygote From fa022b3a9c7b384c4a93b975006b360c345f4859 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 16 Oct 2022 15:46:57 -0400 Subject: [PATCH 03/18] make it stricter, to avoid batchmaybe weirdness --- src/train.jl | 37 ++++++++++++++++++++----------------- test/train.jl | 2 +- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/src/train.jl b/src/train.jl index 0a7433ac6d..1c35da1b0a 100644 --- a/src/train.jl +++ b/src/train.jl @@ -63,17 +63,17 @@ according to a particular optimisation rule `opt`. 
* Instead of `loss` being a function which typically accepts two arguments (the input `x` and expected output `y` from each element of `data`) now it should typically accept three, the first of which is the `model` itself. - * `data` should iterate tuples or NamedTuples - * `opt` should be the result of [`Flux.setup`](@ref). + * `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`. + * `opt` should be the result of [`Flux.setup`](@ref), it will warn you if not. * Callback functions are not supported. For example, with these definitions... ``` -data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple (or NamedTuple) +data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple -loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument +loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument -opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta +opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta ``` ...calling `train!(loss3, model, data, opt)` runs a loop much like this: ``` @@ -82,19 +82,28 @@ for d in data Optimisers.update!(opt, model, ∂L∂m) end ``` -Stops with a `DomainError` if the loss is infinite or `NaN` at any point. +You can also write this loop yourself, if you need more flexibility. +Besides the loop, `train!` will: -Returns a vector containing the value of the loss function at each datapoint. +* Stop with a `DomainError` if the loss is infinite or `NaN` at any point. -The built-in loss functions accept 3 arguments, allowing for instance `train!(Flux.Losses.mse, model, data, opt)`. +* Return a vector containing the value of the loss function at each datapoint. -Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an -easy way to construct more complicated training loops. +* Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl). + +Note that the built-in loss functions accept 3 arguments, allowing for instance +`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above. + +Note that callback functions are not supported. But arbitrary code can be inserted into the loop. """ function train!(loss, model, data, opt) + Base.issingletontype(typeof(loss)) || error("""train! with explicit parameter expects a pure loss function. + It must not close over the model, like loss(x,y) = mse(model(x), y). """) losses = Float32[] @withprogress for (i,d) in enumerate(data) - l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...) + d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)). + Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""") + l, (g, _...) = explicit_withgradient(loss, model, d...) isfinite(l) || throw(DomainError("loss function returned $l, stopping training")) opt, model = Optimisers.update!(opt, model, g) push!(losses, l) @@ -103,12 +112,6 @@ function train!(loss, model, data, opt) return losses # Not entirely sure returning losses is a good idea end -data_splat(x::T) where T = error("""train! 
expects every d in data be a Tuple or a NamedTuple, got $T - To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""") -data_splat(x::Tuple) = x -data_splat(x::NamedTuple) = x -data_splat(x::AbstractArray{<:Number}) = (x,) - """ train!(loss, model, opt) diff --git a/test/train.jl b/test/train.jl index 443a39dd75..ce5a3c3ee2 100644 --- a/test/train.jl +++ b/test/train.jl @@ -49,7 +49,7 @@ using Random end @testset "Explicit Flux.train! features" begin - # Test that splat accepts NamedTuple + # Test errors from wrong kind of iterator # Test NaN / Inf early stop # Test that loss is returned end From 4e937dfb315fa76318dbe7dd11d63b1e96929146 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 16 Oct 2022 15:48:54 -0400 Subject: [PATCH 04/18] remove 3-argument train! since this requires impure loss function, and you can just use update! instead really. --- src/deprecations.jl | 8 ------- src/train.jl | 51 +-------------------------------------------- test/train.jl | 16 -------------- 3 files changed, 1 insertion(+), 74 deletions(-) diff --git a/src/deprecations.jl b/src/deprecations.jl index c878c5192b..782efde473 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -89,22 +89,14 @@ Base.@deprecate_binding ADADelta AdaDelta # Valid methods in Train, new explict style, are: train!(loss, model, data, opt) train!(loss, model, data, opt::Optimisers.AbstractRule) - # ... and 3-arg: - train!(loss, model, opt) - train!(loss, model, opt::Optimisers.AbstractRule) # Provide friendly errors for what happens if you mix these up: =# import .Optimise: train! train!(loss, ps::Params, data, opt) = error("can't mix implict Params with explict state") -train!(loss, ps::Params, opt) = error("can't mix implict Params with explict state") train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule") -train!(loss, ps::Params, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule") train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt)) -train!(loss, model, opt::Optimise.AbstractOptimiser) = train!(loss, model, _old_to_new(opt)) - -train!(loss, ps::Params, opt::Optimise.AbstractOptimiser; cb=0) = error("3-arg train does not exist for implicit mode") # train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( # """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`. diff --git a/src/train.jl b/src/train.jl index 1c35da1b0a..525912e784 100644 --- a/src/train.jl +++ b/src/train.jl @@ -47,9 +47,6 @@ function setup(rule::Optimisers.AbstractRule, model) state end -# opt = Flux.setup(Adam(), model); train!(model, opt) do m ... -setup(model, rule::Optimisers.AbstractRule) = setup(rule, model) - """ train!(loss, model, data, opt) @@ -112,56 +109,10 @@ function train!(loss, model, data, opt) return losses # Not entirely sure returning losses is a good idea end -""" - train!(loss, model, opt) - -Uses a `loss` function improve the `model`'s parameters. - -While the 4-argument method of `train!` iterates over a dataset, -this 3-argument method is for a single datapoint, and calls `gradient` just once. -It expects a function `loss` which takes just one argument, the model. -For example: -``` -opt = Flux.setup(Adam(), model) # explicit setup -train!(model, opt) do m # the model is passed to the function as `m` - Flux.crossentropy(m(x1), y1) # but the data point `(x1, y1)` is closed over. 
-end -``` -This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`. -(The `do` block is another syntax for this anonymous function.) -Then it updates the parameters contained within `model` according to `opt`. -Finally it returns the value of the loss function. - -To iterate over a dataset, writing a loop allows more control than -calling 4-argument `train!`. For example, this adds printing and an early stop: -``` -data = Flux.DataLoader((Xtrain, Ytrain), batchsize=32) -opt = Flux.setup(Adam(), model) -for (i, d) in enumerate(data) - x, y = d - ell = Flux.train!(m -> Flux.crossentropy(m(x), y), model, opt) - i%10==0 && println("on step \$i, the loss was \$ell") # prints every 10th step - ell<0.1 && break # stops training -end -``` - -!!! note - This method has no implicit `Params` analog in Flux ≤ 0.13. -""" -function train!(loss, model, opt) - l, (g, _...) = explicit_withgradient(loss, model) - isfinite(l) || return l - _, model = Optimisers.update!(opt, model, g) - return l -end - -# These methods let you use Optimisers.Descent() without setup, when there is no state +# This method let you use Optimisers.Descent() without setup, when there is no state function train!(loss, model, data, rule::Optimisers.AbstractRule) train!(loss, model, data, _rule_to_state(model, rule)) end -function train!(loss, model, rule::Optimisers.AbstractRule) - train!(loss, model, _rule_to_state(model, rule)) -end function _rule_to_state(model, rule::Optimisers.AbstractRule) state = setup(rule, model) diff --git a/test/train.jl b/test/train.jl index ce5a3c3ee2..607dc1e9a6 100644 --- a/test/train.jl +++ b/test/train.jl @@ -22,22 +22,6 @@ using Random @test loss(model, rand(10, 10)) < 0.01 end - # Test 3-arg `Flux.train!` method: - @testset for rule in [Descent(0.1), Adam(), AdamW()] - - loss(m) = let x = rand(10) - Flux.Losses.mse(w*x, m.weight*x .+ m.bias) - end - model = (weight=copy(w2), bias=zeros(10), ignore=nothing) - @test loss(model) > 1 - - opt = Flux.setup(rule, model) - for i in 1:10^5 - Flux.train!(loss, model, opt) - end - @test loss(model) < 0.01 - end - # Test direct use of Optimisers.jl rule, only really OK for `Descent`: @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()] loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) From 36a2bb5d45ed332987c53958882f5bacd206b5bd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 16 Oct 2022 20:01:30 -0400 Subject: [PATCH 05/18] remove issingletontype purity check, too strict --- src/train.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/train.jl b/src/train.jl index 525912e784..783536755b 100644 --- a/src/train.jl +++ b/src/train.jl @@ -8,7 +8,7 @@ import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the ol export setup, @train_autodiff -using ProgressLogging: @progress, @withprogress, @logprogress # TODO add progress logging again +using ProgressLogging: @progress, @withprogress, @logprogress using Zygote: Zygote, Params """ @@ -94,8 +94,6 @@ Note that the built-in loss functions accept 3 arguments, allowing for instance Note that callback functions are not supported. But arbitrary code can be inserted into the loop. """ function train!(loss, model, data, opt) - Base.issingletontype(typeof(loss)) || error("""train! with explicit parameter expects a pure loss function. - It must not close over the model, like loss(x,y) = mse(model(x), y). 
""") losses = Float32[] @withprogress for (i,d) in enumerate(data) d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)). From 28a7718ff3ed24081db38942c5370e63ccf6ca75 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 16 Oct 2022 21:06:34 -0400 Subject: [PATCH 06/18] tidy up --- Project.toml | 6 +---- src/train.jl | 71 +++++++++++++++++++++++++++++++--------------------- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/Project.toml b/Project.toml index e33a4e3dba..84e20d8e9c 100644 --- a/Project.toml +++ b/Project.toml @@ -39,8 +39,6 @@ ProgressLogging = "0.1" Reexport = "0.2, 1.0" SpecialFunctions = "1.8.2, 2.1.2" StatsBase = "0.33" -Tracker = "0.2.22" -Yota = "0.8.1" Zygote = "0.6.34" julia = "1.6" @@ -50,9 +48,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Yota = "cd998857-8626-517d-b929-70ad188a48f0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker", "Yota"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] diff --git a/src/train.jl b/src/train.jl index 783536755b..046126328a 100644 --- a/src/train.jl +++ b/src/train.jl @@ -6,7 +6,7 @@ using Functors: fmap import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions -export setup, @train_autodiff +export setup, train! using ProgressLogging: @progress, @withprogress, @logprogress using Zygote: Zygote, Params @@ -14,28 +14,33 @@ using Zygote: Zygote, Params """ opt = setup(rule, model) -This is a version of `Optimisers.setup`, and is the first step before using `train!`. +This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!). It differs from `Optimisers.setup` in that it: * has one extra check for mutability * has methods which accept Flux's old optimisers, and convert them. +# Example ```jldoctest julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32); -julia> opt = Flux.setup(Momentum(0.11), model) -(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0]), σ = ()) +julia> opt = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state +(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ()) -julia> Flux.train!(model, opt) do m # 3-arg train!, for one data point (x = [0.2, -0.3], y = [0.4]) - sum(m([0.2, -0.3]) .- [0.4]) * 100 +julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps: + +julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y + sum(abs.(m(x) .- y)) * 100 end --40.1 +2-element Vector{Float32}: + 40.1 + 38.7 julia> model.bias # was zero, mutated by Flux.train! 1-element Vector{Float32}: - -0.11 + 10.190001 julia> opt # mutated by Flux.train! 
-(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.022 -0.033]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.11]), σ = ()) +(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ()) ``` """ function setup(rule::Optimisers.AbstractRule, model) @@ -51,18 +56,8 @@ end train!(loss, model, data, opt) Uses a `loss` function and training `data` to improve the `model`'s parameters -according to a particular optimisation rule `opt`. - -!!! note - This method has significant changes from the one in Flux ≤ 0.13: - * It now takes the `model` itself, not the result of [`Flux.params`](@ref). - (This is to move away from Zygote's implicit parameter handling.) - * Instead of `loss` being a function which typically accepts two arguments - (the input `x` and expected output `y` from each element of `data`) - now it should typically accept three, the first of which is the `model` itself. - * `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`. - * `opt` should be the result of [`Flux.setup`](@ref), it will warn you if not. - * Callback functions are not supported. +according to a particular optimisation rule `opt`. Iterates through `data` once, +evaluating `loss(model, d...)` for each `d` in data. For example, with these definitions... ``` @@ -72,15 +67,17 @@ loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta ``` -...calling `train!(loss3, model, data, opt)` runs a loop much like this: +...calling `Flux.train!(loss3, model, data, opt)` runs a loop much like this, +using Zygote's "explicit" mode for the gradient: ``` for d in data - ∂L∂m = Zygote.gradient(loss3, model, d...)[1] - Optimisers.update!(opt, model, ∂L∂m) + ∂L∂m = gradient(loss3, model, d...)[1] + update!(opt, model, ∂L∂m) # method for "explicit" gradient end ``` You can also write this loop yourself, if you need more flexibility. -Besides the loop, `train!` will: +For this reason `train!` is not highly extensible. +It adds only a few featurs to the loop above: * Stop with a `DomainError` if the loss is infinite or `NaN` at any point. @@ -91,20 +88,36 @@ Besides the loop, `train!` will: Note that the built-in loss functions accept 3 arguments, allowing for instance `train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above. -Note that callback functions are not supported. But arbitrary code can be inserted into the loop. +!!! note + This method has significant changes from the one in Flux ≤ 0.13: + * It now takes the `model` itself, not the result of [`Flux.params`](@ref). + (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) + * Instead of `loss` being a function which typically accepts two arguments + (the input `x` and expected output `y` from each element of `data`) + now it should typically accept three, the first of which is the `model` itself. + * `data` must iterate tuples, otherwise you get an error. + (Previously non-tuple types were not splatted into the loss. + Pass in `((d,) for d in data)` to simulate this.) + * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser + such as `Adam()` without this step should give you a warning. + * Callback functions are not supported. + But any code can be included in the above `for` loop. """ -function train!(loss, model, data, opt) +function train!(loss, model, data, opt; cb = nothing) + isnothing(cb) || error("""train! 
does not support callback functions. + For more control use a loop with `gradient` and `update!`.""") losses = Float32[] @withprogress for (i,d) in enumerate(data) d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)). Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""") - l, (g, _...) = explicit_withgradient(loss, model, d...) + # l, (g, _...) = explicit_withgradient(loss, model, d...) # BTW this un-thunks gradient w.r.t. data. Could avoid that + l, (g, _...) = explicit_withgradient(m -> loss(m, d...), model) isfinite(l) || throw(DomainError("loss function returned $l, stopping training")) opt, model = Optimisers.update!(opt, model, g) push!(losses, l) @logprogress Base.haslength(data) ? i/length(data) : nothing end - return losses # Not entirely sure returning losses is a good idea + return losses # Not entirely sure returning losses is a good idea, as it may conflict with later returning immutable models alla Optimisers.jl end # This method let you use Optimisers.Descent() without setup, when there is no state From 5d74b04c1f2e1d14fe1b29b0e4b76e0853cdeda2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 16 Oct 2022 22:26:28 -0400 Subject: [PATCH 07/18] fix tests --- NEWS.md | 2 +- docs/make.jl | 4 ++-- docs/src/models/overview.md | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/NEWS.md b/NEWS.md index 032819e1d1..6d48d6380c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,7 +2,7 @@ ## v0.13.7 * Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078) -* New method of `train!` using Zygote's "explicit" mode, allows changing AD back-end. +* New method of `train!` using Zygote's "explicit" mode. Part of a move away from "implicit" `Params`. ## v0.13.4 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983) diff --git a/docs/make.jl b/docs/make.jl index 40d6033637..4094d11607 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,10 +1,10 @@ -using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore +using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Statistics DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true) makedocs( - modules = [Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Base], + modules = [Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Base, Statistics], doctest = false, sitename = "Flux", # strict = [:cross_references,], diff --git a/docs/src/models/overview.md b/docs/src/models/overview.md index 630da338cb..b187a2de26 100644 --- a/docs/src/models/overview.md +++ b/docs/src/models/overview.md @@ -17,7 +17,7 @@ This example will predict the output of the function `4x + 2`. Making such predi First, import `Flux` and define the function we want to simulate: -```jldoctest overview +```jldoctest overview setup = :(using Statistics) julia> using Flux julia> actual(x) = 4x + 2 @@ -77,13 +77,13 @@ julia> predict(x_train) In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions. ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" 
-julia> loss(model, x, y) = mean(abs2.(model(x) .- y)); +julia> loss(model, x, y) = Statistics.mean(abs2.(model(x) .- y)); julia> loss(predict, x_train, y_train) 122.64734f0 ``` -More accurate predictions will yield a lower loss. You can write your own loss functions or rely on those already provided by Flux. This loss function is called [mean squared error](https://www.statisticshowto.com/probability-and-statistics/statistics-definitions/mean-squared-error/). Flux works by iteratively reducing the loss through *training*. +More accurate predictions will yield a lower loss. You can write your own loss functions or rely on those already provided by Flux. This loss function is called [mean squared error](https://www.statisticshowto.com/probability-and-statistics/statistics-definitions/mean-squared-error/) (and built-in as [`mse`](@ref Flux.Losses.mse)). Flux works by iteratively reducing the loss through *training*. ## 3. Improve the Prediction From 7eaf3ea10dc95817de9aee23f47a44fcf299a1c4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 10 Nov 2022 08:22:39 -0500 Subject: [PATCH 08/18] use _old_to_new in Optimisers.setup too --- src/deprecations.jl | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/src/deprecations.jl b/src/deprecations.jl index 782efde473..b686e68f99 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -86,29 +86,34 @@ Base.@deprecate_binding ADADelta AdaDelta #= # Valid method in Optimise, old implicit style, is: train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ()) + # Valid methods in Train, new explict style, are: - train!(loss, model, data, opt) - train!(loss, model, data, opt::Optimisers.AbstractRule) + train!(loss, model, data, opt) # preferred + train!(loss, model, data, opt::Optimisers.AbstractRule) # if you forget setup + # Provide friendly errors for what happens if you mix these up: =# import .Optimise: train! -train!(loss, ps::Params, data, opt) = error("can't mix implict Params with explict state") -train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule") +train!(loss, ps::Params, data, opt) = error( + """can't mix implict Params with explict state! + To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module. + But better to use the new explicit style, in which `m` itself is the 2nd argument. + """) -train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt)) +train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error( + """can't mix implict Params with explict rule from Optimisers.jl + To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-mo$ + But better to use the new explicit style, in which `m` itself is the 2nd argument. + """) -# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( -# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`. -# Instead of `train!(loss_xy, Flux.params(model), data, Adam())` -# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)` -# where `loss_mxy` accepts the model as its first argument. 
-# """ -# )) +train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt)) -# Next, to use the new `setup` with the still-exported old-style Adam etc: +# Next, to use the new `setup` with the still-exported old-style `Adam` etc: import .Train: setup setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model) +# ... and allow accidental use of `Optimisers.setup` to do the same: +Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model) for T in [:Descent, :Adam, :Momentum, :Nesterov, :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief, @@ -129,10 +134,16 @@ _old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon _old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule") -Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = error("please use Flux.setup not Optimisers.setup, it may be able to translate this rule") - # v0.14 deprecations # Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc: # Base.@deprecate_binding Optimiser OptimiserChain # Base.@deprecate_binding ClipValue ClipGrad + +# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( +# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`. +# Instead of `train!(loss_xy, Flux.params(model), data, Adam())` +# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)` +# where `loss_mxy` accepts the model as its first argument. +# """ +# )) From f3e1559c0a43221b0e145974a6f2028037203b78 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 10 Nov 2022 11:52:04 -0500 Subject: [PATCH 09/18] oops --- src/deprecations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/deprecations.jl b/src/deprecations.jl index b686e68f99..6c5846c8c1 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -103,7 +103,7 @@ train!(loss, ps::Params, data, opt) = error( train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error( """can't mix implict Params with explict rule from Optimisers.jl - To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-mo$ + To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module. But better to use the new explicit style, in which `m` itself is the 2nd argument. """) From 63ad5430958d0c796b90e7b657c162a647b9e3b1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 17 Nov 2022 23:41:37 -0500 Subject: [PATCH 10/18] return nothing --- src/train.jl | 6 ------ test/train.jl | 1 - 2 files changed, 7 deletions(-) diff --git a/src/train.jl b/src/train.jl index 046126328a..41169c4044 100644 --- a/src/train.jl +++ b/src/train.jl @@ -31,9 +31,6 @@ julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps: julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y sum(abs.(m(x) .- y)) * 100 end -2-element Vector{Float32}: - 40.1 - 38.7 julia> model.bias # was zero, mutated by Flux.train! 1-element Vector{Float32}: @@ -81,8 +78,6 @@ It adds only a few featurs to the loop above: * Stop with a `DomainError` if the loss is infinite or `NaN` at any point. -* Return a vector containing the value of the loss function at each datapoint. 
- * Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl). Note that the built-in loss functions accept 3 arguments, allowing for instance @@ -117,7 +112,6 @@ function train!(loss, model, data, opt; cb = nothing) push!(losses, l) @logprogress Base.haslength(data) ? i/length(data) : nothing end - return losses # Not entirely sure returning losses is a good idea, as it may conflict with later returning immutable models alla Optimisers.jl end # This method let you use Optimisers.Descent() without setup, when there is no state diff --git a/test/train.jl b/test/train.jl index 607dc1e9a6..2b4510565e 100644 --- a/test/train.jl +++ b/test/train.jl @@ -35,5 +35,4 @@ end @testset "Explicit Flux.train! features" begin # Test errors from wrong kind of iterator # Test NaN / Inf early stop - # Test that loss is returned end From 2bd0dad964d8cd836f44e736b3830c3a707ddd78 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 18 Nov 2022 00:01:25 -0500 Subject: [PATCH 11/18] test NaN + error, tidy up --- src/train.jl | 13 +++++-------- test/train.jl | 17 +++++++++++++++-- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/train.jl b/src/train.jl index 41169c4044..b43b84c4fe 100644 --- a/src/train.jl +++ b/src/train.jl @@ -101,15 +101,14 @@ Note that the built-in loss functions accept 3 arguments, allowing for instance function train!(loss, model, data, opt; cb = nothing) isnothing(cb) || error("""train! does not support callback functions. For more control use a loop with `gradient` and `update!`.""") - losses = Float32[] @withprogress for (i,d) in enumerate(data) d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)). Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""") - # l, (g, _...) = explicit_withgradient(loss, model, d...) # BTW this un-thunks gradient w.r.t. data. Could avoid that - l, (g, _...) = explicit_withgradient(m -> loss(m, d...), model) - isfinite(l) || throw(DomainError("loss function returned $l, stopping training")) - opt, model = Optimisers.update!(opt, model, g) - push!(losses, l) + l, gs = Zygote.withgradient(m -> loss(m, d...), model) + if !isfinite(l) + throw(DomainError("Loss is $l on data item $i, stopping training")) + end + opt, model = Optimisers.update!(opt, model, gs[1]) @logprogress Base.haslength(data) ? i/length(data) : nothing end end @@ -131,6 +130,4 @@ function _rule_to_state(model, rule::Optimisers.AbstractRule) state end -explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor - end # module diff --git a/test/train.jl b/test/train.jl index 2b4510565e..e0e99e15d4 100644 --- a/test/train.jl +++ b/test/train.jl @@ -33,6 +33,19 @@ using Random end @testset "Explicit Flux.train! features" begin - # Test errors from wrong kind of iterator - # Test NaN / Inf early stop + @testset "Stop on NaN" begin + m = Dense(1 => 1) + m.weight .= 0 + CNT = 0 + @test_throws DomainError Flux.train!(m, tuple.(1:100), Descent(0.1)) do (i,) + CNT += 1 + (i == 51 ? 
NaN32 : 1f0) * sum(m([1.0])) + end + @test CNT == 51 # stopped early + @test m.weight[1] ≈ -5 # did not corrupt weights + end + @testset "data must give tuples" begin + m = Dense(1 => 1) + @test_throws ErrorException Flux.train!((args...,) -> 1, m, [(x=1, y=2) for _ in 1:3], Descent(0.1)) + end end From 20326ea09b979a7690a2f6acdaf09272d0f3edb9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Nov 2022 11:46:28 -0500 Subject: [PATCH 12/18] fix test --- test/train.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/train.jl b/test/train.jl index e0e99e15d4..92f2b125ba 100644 --- a/test/train.jl +++ b/test/train.jl @@ -34,18 +34,23 @@ end @testset "Explicit Flux.train! features" begin @testset "Stop on NaN" begin - m = Dense(1 => 1) - m.weight .= 0 + m1 = Dense(1 => 1) + m1.weight .= 0 CNT = 0 - @test_throws DomainError Flux.train!(m, tuple.(1:100), Descent(0.1)) do (i,) + @test_throws DomainError Flux.train!(m1, tuple.(1:100), Descent(0.1)) do m, i CNT += 1 (i == 51 ? NaN32 : 1f0) * sum(m([1.0])) end @test CNT == 51 # stopped early - @test m.weight[1] ≈ -5 # did not corrupt weights + @test m1.weight[1] ≈ -5 # did not corrupt weights end @testset "data must give tuples" begin - m = Dense(1 => 1) - @test_throws ErrorException Flux.train!((args...,) -> 1, m, [(x=1, y=2) for _ in 1:3], Descent(0.1)) + m1 = Dense(1 => 1) + @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(x=1, y=2) for _ in 1:3], Descent(0.1)) + end + @testset "callbacks give helpful error" begin + m1 = Dense(1 => 1) + cb = () -> println("this should not be printed") + @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) end end From db2a9b90584c8fe5b4b511e2478ea350f4b9b2bf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Nov 2022 11:52:42 -0500 Subject: [PATCH 13/18] remove 2 vs 3 argument comment from docstring --- src/train.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/train.jl b/src/train.jl index b43b84c4fe..919821b710 100644 --- a/src/train.jl +++ b/src/train.jl @@ -16,8 +16,10 @@ using Zygote: Zygote, Params This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!). It differs from `Optimisers.setup` in that it: -* has one extra check for mutability +* has one extra check for mutability (since Flux expects to mutate the model in-place, + while Optimisers.jl is designed to return an updated model) * has methods which accept Flux's old optimisers, and convert them. + (The old `Flux.Optimise.Adam` and new `Optimisers.Adam` are distinct types.) # Example ```jldoctest @@ -80,16 +82,12 @@ It adds only a few featurs to the loop above: * Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl). -Note that the built-in loss functions accept 3 arguments, allowing for instance -`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above. - !!! note This method has significant changes from the one in Flux ≤ 0.13: * It now takes the `model` itself, not the result of [`Flux.params`](@ref). (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) - * Instead of `loss` being a function which typically accepts two arguments - (the input `x` and expected output `y` from each element of `data`) - now it should typically accept three, the first of which is the `model` itself. 
+ * Instead of `loss` being a function which accepts only the data, + now it must also accept the `model` itself, as the first argument. * `data` must iterate tuples, otherwise you get an error. (Previously non-tuple types were not splatted into the loss. Pass in `((d,) for d in data)` to simulate this.) @@ -130,4 +128,4 @@ function _rule_to_state(model, rule::Optimisers.AbstractRule) state end -end # module +end # module Train From c617807e8496fe471d6b3cbff025a0a44b21aa87 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Nov 2022 12:49:52 -0500 Subject: [PATCH 14/18] nice errors for update! with mixed-up input --- src/deprecations.jl | 37 +++++++++++++++++++++++++++++++++++++ src/optimise/train.jl | 2 +- test/train.jl | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 1 deletion(-) diff --git a/src/deprecations.jl b/src/deprecations.jl index 6c5846c8c1..1cac2a4b86 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -134,6 +134,43 @@ _old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon _old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule") +# Since `update!` should be called in a loop, it makes less sense to call `setup` for you if you forgot. +# But let's make sure that such uses give a helpful error: +import .Optimise: update! + +function update!(opt::Optimise.AbstractOptimiser, model, grad) + # This error method requires narrowing the main worker method of Flux.Optimise + # to accept only arrays. Remove if this causes problems! + # update!(opt::Flux.Optimise.AbstractOptimiser, x::AbstractArray, x̄) + error("""Invalid input to `update!`. + * For the implicit style, this needs `update(::AbstractOptimiser, ::Params, ::Grads)` + * For the explicit style, `update(state, model, grad)` needs `state = Flux.setup(opt, model)`. + """) +end + +# An easy error to make is to pass result of explicit gradient(...), not gradient(...)[1] +# Can't catch every case, but can catch many simple Flux models: + +function update!(opt, model::Chain, grads::Tuple) + # Zygote will make a NamedTuple{(:layers,)} for the gradient of Chain, Diffractor a Tangent + @warn """explicit `update!(opt, model, grad)` wants the gradient for the model alone, + not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`.""" + update!(opt, model, grads[1]) +end + +function update!(opt::Optimise.AbstractOptimiser, model::Chain, grads::Tuple) # ambiguity + update!(opt, model, grads[1]) # calls error case "Invalid input" just above +end + +# One more easy error to catch is using explicit gradient with `params(m)`: + +function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple, NamedTuple}) + error("""can't mix implicit Params with explicit gradients! + * For the implicit style, this needs `update(::AbstractOptimiser, ::Params, ::Grads)` with implicit gradient. + * For the explicit style, `update(state, model, grad)` needs the model itself, and `state = Flux.setup(opt, model)`. + """) +end + # v0.14 deprecations # Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc: diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 74987f2a4e..d0de78e01a 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -19,7 +19,7 @@ The gradient could be mutated as well. 
This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14. The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain. """ -function update!(opt::AbstractOptimiser, x, x̄) +function update!(opt::AbstractOptimiser, x::AbstractArray, x̄) x̄r = copyto!(similar(x̄), x̄) # Flux.Optimise assumes it can mutate the gradient. This is not # safe due to aliasing, nor guaranteed to be possible, e.g. Fill. x .-= apply!(opt, x, x̄r) diff --git a/test/train.jl b/test/train.jl index 92f2b125ba..49ecf9c751 100644 --- a/test/train.jl +++ b/test/train.jl @@ -54,3 +54,40 @@ end @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) end end + +@testset "Explicit Flux.update! features" begin + m = Chain(Dense(2=>3, tanh), Dense(3=>1), only) + x = rand(2) + y1 = m(x) # before + + # Implicit gradient + gold = gradient(() -> m(x), Flux.params(m)) + @test gold isa Flux.Zygote.Grads + @test_throws ErrorException Flux.update!(Flux.Adam(), m, gold) # friendly + Flux.update!(Flux.Adam(), Flux.params(m), gold) + y2 = m(x) + @test y2 < y1 + + # Explicit gradient + gs = gradient(marg -> marg(x), m) + @test gs isa Tuple + @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs) # friendly + @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1]) # friendly + @test_throws ErrorException Flux.update!(Flux.Adam(), m, gs) # friendly + @test_throws ErrorException Flux.update!(Flux.Adam(), m, gs[1]) # friendly + s = Flux.setup(Adam(), m) + @info "ignore this warning, just testing an upgrade path:" + Flux.update!(s, m, gs) # Chain + Tuple can be unambiguously sorted out + y3 = m(x) + @test y3 < y2 + Flux.update!(s, m, gs[1]) # finally, this is the correct thing + y4 = m(x) + @test y4 < y3 + + # Also check that if you import the new Adam, then Flux.setup does still work! + s2 = Flux.setup(Optimisers.Adam(), m) + Flux.update!(s2, m, gs[1]) + y5 = m(x) + @test y5 < y4 +end + From 0389de34e36e81a1c4b00a5da9d965b26f5bc7ed Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Nov 2022 13:09:33 -0500 Subject: [PATCH 15/18] fix doctest by making "using Statistics" explicit --- docs/src/models/overview.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/src/models/overview.md b/docs/src/models/overview.md index b187a2de26..c0e9291240 100644 --- a/docs/src/models/overview.md +++ b/docs/src/models/overview.md @@ -17,7 +17,7 @@ This example will predict the output of the function `4x + 2`. Making such predi First, import `Flux` and define the function we want to simulate: -```jldoctest overview setup = :(using Statistics) +```jldoctest overview julia> using Flux julia> actual(x) = 4x + 2 @@ -77,7 +77,9 @@ julia> predict(x_train) In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions. ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" 
-julia> loss(model, x, y) = Statistics.mean(abs2.(model(x) .- y)); +julia> using Statistics + +julia> loss(model, x, y) = mean(abs2.(model(x) .- y)); julia> loss(predict, x_train, y_train) 122.64734f0 From 732fa131036c25ccff6a6157868af2f2b818fd08 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Nov 2022 13:10:00 -0500 Subject: [PATCH 16/18] also delete Flux.params() from the example completely --- docs/src/models/overview.md | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/docs/src/models/overview.md b/docs/src/models/overview.md index c0e9291240..707fa79b1a 100644 --- a/docs/src/models/overview.md +++ b/docs/src/models/overview.md @@ -114,21 +114,9 @@ julia> predict.bias 0.0 ``` -The dimensions of these model parameters depend on the number of inputs and outputs. Since models can have hundreds of inputs and several layers, it helps to have a function to collect the parameters into the data structure Flux expects: +The dimensions of these model parameters depend on the number of inputs and outputs. -```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> parameters = Flux.params(predict) -Params([Float32[0.9066542], Float32[0.0]]) -``` - -These are the parameters Flux will change, one step at a time, to improve predictions. At each step, the contents of this `Params` object changes too, since it is just a collection of references to the mutable arrays inside the model: - -```jldoctest overview -julia> predict.weight in parameters, predict.bias in parameters -(true, true) -``` - -The first parameter is the weight and the second is the bias. Flux will adjust predictions by iteratively changing these parameters according to the optimizer. +Flux will adjust predictions by iteratively changing these parameters according to the optimizer. This optimiser implements the classic gradient descent strategy. Now improve the parameters of the model with a call to [`Flux.train!`](@ref) like this: @@ -146,8 +134,8 @@ julia> loss(x_train, y_train) It went down. Why? ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> parameters -Params([Float32[7.5777884], Float32[1.9466728]]) +julia> predict.weight, predict.bias +(Float32[7.5777884], Float32[1.9466728]) ``` The parameters have changed. This single step is the essence of machine learning. @@ -165,7 +153,7 @@ julia> loss(predict, x_train, y_train) 0.00339581f0 julia> parameters -Params([Float32[4.0178537], Float32[2.0050256]]) +(Float32[4.0178537], Float32[2.0050256]) ``` After 200 training steps, the loss went down, and the parameters are getting close to those in the function the model is built to predict. From d9699c016b593fe2f874d77a4b7df7c1ef74e5ed Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Nov 2022 13:15:37 -0500 Subject: [PATCH 17/18] fix --- docs/src/models/overview.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/models/overview.md b/docs/src/models/overview.md index 707fa79b1a..c2428d6d24 100644 --- a/docs/src/models/overview.md +++ b/docs/src/models/overview.md @@ -127,7 +127,7 @@ julia> train!(loss, predict, data, opt) And check the loss: ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" 
-julia> loss(x_train, y_train) +julia> loss(predict, x_train, y_train) 116.38745f0 ``` From db7ad43dfdf34c1fe01dc61e71261b67c9ddda21 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Nov 2022 13:44:21 -0500 Subject: [PATCH 18/18] fix --- docs/src/models/overview.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/models/overview.md b/docs/src/models/overview.md index c2428d6d24..5700e02b25 100644 --- a/docs/src/models/overview.md +++ b/docs/src/models/overview.md @@ -152,7 +152,7 @@ julia> for epoch in 1:200 julia> loss(predict, x_train, y_train) 0.00339581f0 -julia> parameters +julia> predict.weight, predict.bias (Float32[4.0178537], Float32[2.0050256]) ```
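Taken together, the commits above converge on the explicit-style workflow sketched below. This is an illustration only, not part of any diff in this series: the toy `Dense(1 => 1)` model, the synthetic `data`, and names such as `opt_state` are invented for the example, while `Flux.setup`, `Flux.train!`, `Flux.gradient`, `Flux.update!` and `Flux.Losses.mse` are the pieces these patches add or re-route.

```julia
using Flux

# Illustrative toy problem (not taken from the diffs above):
model = Dense(1 => 1)
data  = [([x], [2x + 1]) for x in 1.0f0:5.0f0]   # train! insists on an iterator of tuples

# New-style loss: the model itself is the first argument.
loss(m, x, y) = Flux.Losses.mse(m(x), y)

# setup translates the old-style Adam into an Optimisers.jl rule and builds the state tree.
opt_state = Flux.setup(Adam(), model)

# One pass over the data: returns nothing, mutates `model` and `opt_state` in place.
Flux.train!(loss, model, data, opt_state)

# Roughly the same pass written by hand, for more control:
for (x, y) in data
    grads = Flux.gradient(m -> loss(m, x, y), model)
    Flux.update!(opt_state, model, grads[1])     # note grads[1], not the whole tuple
end
```

The same `opt_state` is reused for every later `train!` or `update!` call, so `setup` is called once, before the training loop; the deprecation methods above exist to translate an old-style rule either at `setup` time or, as a fallback, inside `train!` itself.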