From 14df7187973c457ede772651dc615e65a31aea69 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 16 Oct 2022 15:46:57 -0400
Subject: [PATCH] make it stricter, to avoid batchmaybe weirdness

---
 src/train.jl  | 37 ++++++++++++++++++++-----------------
 test/train.jl |  2 +-
 2 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/src/train.jl b/src/train.jl
index 0a7433ac6d..1c35da1b0a 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -63,17 +63,17 @@ according to a particular optimisation rule `opt`.
   * Instead of `loss` being a function which typically accepts two arguments
     (the input `x` and expected output `y` from each element of `data`)
     now it should typically accept three, the first of which is the `model` itself.
-  * `data` should iterate tuples or NamedTuples
-  * `opt` should be the result of [`Flux.setup`](@ref).
+  * `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`.
+  * `opt` should be the result of [`Flux.setup`](@ref), it will warn you if not.
   * Callback functions are not supported.
 
 For example, with these definitions...
 ```
-data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple (or NamedTuple)
+data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple
 
-loss3(m, x, y) = norm(m(x) .- y)  # the model is the first argument
+loss3(m, x, y) = norm(m(x) .- y)        # the model is the first argument
 
-opt = Flux.setup(Adam(), model)   # explicit setup of optimiser momenta
+opt = Flux.setup(Adam(), model)         # explicit setup of optimiser momenta
 ```
 ...calling `train!(loss3, model, data, opt)` runs a loop much like this:
 ```
@@ -82,19 +82,28 @@ for d in data
     Optimisers.update!(opt, model, ∂L∂m)
 end
 ```
-Stops with a `DomainError` if the loss is infinite or `NaN` at any point.
+You can also write this loop yourself, if you need more flexibility.
+Besides the loop, `train!` will:
 
-Returns a vector containing the value of the loss function at each datapoint.
+* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
 
-The built-in loss functions accept 3 arguments, allowing for instance `train!(Flux.Losses.mse, model, data, opt)`.
+* Return a vector containing the value of the loss function at each datapoint.
 
-Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an
-easy way to construct more complicated training loops.
+* Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl).
+
+Note that the built-in loss functions accept 3 arguments, allowing for instance
+`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
+
+Note that callback functions are not supported. But arbitrary code can be inserted into the loop.
 """
 function train!(loss, model, data, opt)
+  Base.issingletontype(typeof(loss)) || error("""train! with explicit parameter expects a pure loss function.
+  It must not close over the model, like loss(x,y) = mse(model(x), y). """)
   losses = Float32[]
   @withprogress for (i,d) in enumerate(data)
-    l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...)
+    d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
+    Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
+    l, (g, _...) = explicit_withgradient(loss, model, d...)
     isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
     opt, model = Optimisers.update!(opt, model, g)
     push!(losses, l)
@@ -103,12 +112,6 @@ function train!(loss, model, data, opt)
   return losses  # Not entirely sure returning losses is a good idea
 end
 
-data_splat(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T
-  To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""")
-data_splat(x::Tuple) = x
-data_splat(x::NamedTuple) = x
-data_splat(x::AbstractArray{<:Number}) = (x,)
-
 """
     train!(loss, model, opt)
 
diff --git a/test/train.jl b/test/train.jl
index 443a39dd75..ce5a3c3ee2 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -49,7 +49,7 @@ using Random
 end
 
 @testset "Explicit Flux.train! features" begin
-  # Test that splat accepts NamedTuple
+  # Test errors from wrong kind of iterator
   # Test NaN / Inf early stop
   # Test that loss is returned
 end
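
For reference, below is a minimal usage sketch of the stricter contract this patch enforces, assuming a Flux version where the explicit-style `train!` and `Flux.setup` from this work are available. The model, data, and `loss3` definitions are illustrative placeholders, not part of the commit; the point is that `data` must yield tuples, the loss takes the model first, and `opt` comes from `Flux.setup`.

```julia
using Flux

model = Dense(2 => 1)                  # illustrative toy model, not from the patch
opt   = Flux.setup(Adam(0.01), model)  # explicit optimiser state, as train! now expects

# Every element of `data` must be a tuple; each `d` is splatted as `loss3(model, d...)`.
data = [(randn(Float32, 2, 8), randn(Float32, 1, 8)) for _ in 1:3]

loss3(m, x, y) = Flux.Losses.mse(m(x), y)  # the model is the first argument; no closure over `model`

Flux.train!(loss3, model, data, opt)

# An iterator of bare arrays now errors; wrap each element in a tuple instead,
# e.g. `((d,) for d in data)` as the new error message suggests, and give the
# loss a matching method such as `loss2(m, x) = sum(abs2, m(x))`.
```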