Commit

make it stricter, to avoid batchmaybe weirdness
mcabbott committed Nov 10, 2022
1 parent 0cc649d commit 14df718
Showing 2 changed files with 21 additions and 18 deletions.
src/train.jl: 37 changes (20 additions & 17 deletions)
@@ -63,17 +63,17 @@ according to a particular optimisation rule `opt`.
 * Instead of `loss` being a function which typically accepts two arguments
   (the input `x` and expected output `y` from each element of `data`)
   now it should typically accept three, the first of which is the `model` itself.
-* `data` should iterate tuples or NamedTuples
-* `opt` should be the result of [`Flux.setup`](@ref).
+* `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`.
+* `opt` should be the result of [`Flux.setup`](@ref), it will warn you if not.
 * Callback functions are not supported.
 For example, with these definitions...
 ```
-data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple (or NamedTuple)
-loss3(m, x, y) = norm(m(x) .- y)        # the model is the first argument
-opt = Flux.setup(Adam(), model)         # explicit setup of optimiser momenta
+data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple
+loss3(m, x, y) = norm(m(x) .- y)  # the model is the first argument
+opt = Flux.setup(Adam(), model)   # explicit setup of optimiser momenta
 ```
 ...calling `train!(loss3, model, data, opt)` runs a loop much like this:
 ```
@@ -82,19 +82,28 @@ for d in data
     Optimisers.update!(opt, model, ∂L∂m)
 end
 ```
-Stops with a `DomainError` if the loss is infinite or `NaN` at any point.
-Returns a vector containing the value of the loss function at each datapoint.
-The built-in loss functions accept 3 arguments, allowing for instance `train!(Flux.Losses.mse, model, data, opt)`.
-Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an
-easy way to construct more complicated training loops.
+You can also write this loop yourself, if you need more flexibility.
+Besides the loop, `train!` will:
+
+* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
+
+* Return a vector containing the value of the loss function at each datapoint.
+
+* Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl).
+
+Note that the built-in loss functions accept 3 arguments, allowing for instance
+`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
+Note that callback functions are not supported. But arbitrary code can be inserted into the loop.
 """
 function train!(loss, model, data, opt)
   Base.issingletontype(typeof(loss)) || error("""train! with explicit parameter expects a pure loss function.
     It must not close over the model, like loss(x,y) = mse(model(x), y). """)
   losses = Float32[]
   @withprogress for (i,d) in enumerate(data)
-    l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...)
+    d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
+      Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
+    l, (g, _...) = explicit_withgradient(loss, model, d...)
     isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
     opt, model = Optimisers.update!(opt, model, g)
     push!(losses, l)
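The new `d isa Tuple` check fails fast where the `data_splat` dispatch removed below would previously wrap bare arrays. A sketch of the migration the error message suggests, reusing names from the sketch above; `loss1` is hypothetical:

```julia
loss1(m, x) = norm(m(x))                  # loss taking a single data argument
xs = [randn(Float32, 2, 8) for _ in 1:3]  # bare arrays, no longer accepted
# Flux.train!(loss1, model, xs, opt)      # would now error: expects tuples
Flux.train!(loss1, model, ((x,) for x in xs), opt)  # wrap each element in a tuple
```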
@@ -103,12 +112,6 @@ function train!(loss, model, data, opt)
   return losses  # Not entirely sure returning losses is a good idea
 end
 
-data_splat(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T
-  To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""")
-data_splat(x::Tuple) = x
-data_splat(x::NamedTuple) = x
-data_splat(x::AbstractArray{<:Number}) = (x,)
-
 """
     train!(loss, model, opt)
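Per the docstring's note that the built-in loss functions accept 3 arguments, the `loss3` definition above can be skipped entirely. A sketch assuming `Flux.Losses.mse` has the `(model, x, y)` method the docstring describes:

```julia
# mse(model, x, y) is applied as mse(model(x), y) to each datapoint
losses = Flux.train!(Flux.Losses.mse, model, data, opt)
```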
test/train.jl: 2 changes (1 addition & 1 deletion)
@@ -49,7 +49,7 @@ using Random
 end
 
 @testset "Explicit Flux.train! features" begin
-  # Test that splat accepts NamedTuple
+  # Test errors from wrong kind of iterator
   # Test NaN / Inf early stop
   # Test that loss is returned
 end
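Something like the following would exercise the renamed stub's error path. A sketch only, not the test code eventually committed:

```julia
using Test, Flux

@testset "wrong kind of iterator" begin
    model = Dense(2 => 1)
    opt = Flux.setup(Adam(), model)
    bad = [randn(Float32, 2, 8) for _ in 1:3]  # bare arrays, not tuples
    @test_throws ErrorException Flux.train!((m, x) -> sum(m(x)), model, bad, opt)
end
```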
