From a5e55465859dcf6489673f3d4b0c34c157b77ce3 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 24 Nov 2022 15:52:01 +0100 Subject: [PATCH] allow non-tuple data in the new train! (#2119) * allow non-tuple data * cl/batchme * add tests * test multiple callback * cleanup notes * cleanup * cleanup * remove callbacks * cleanup * Update src/train.jl Co-authored-by: Kyle Daruwalla Co-authored-by: Kyle Daruwalla --- Project.toml | 1 - src/train.jl | 19 ++++++++----------- test/train.jl | 11 ++++++++--- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 84e20d8e9c..a01cab0f9f 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] diff --git a/src/train.jl b/src/train.jl index 919821b710..d548e0ac02 100644 --- a/src/train.jl +++ b/src/train.jl @@ -56,11 +56,12 @@ end Uses a `loss` function and training `data` to improve the `model`'s parameters according to a particular optimisation rule `opt`. Iterates through `data` once, -evaluating `loss(model, d...)` for each `d` in data. +evaluating for each `d in data` either `loss(model, d...)` if `d isa Tuple`, +or else `loss(model, d)` for other `d`. For example, with these definitions... ``` -data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple +data = [(x1, y1), (x2, y2), (x3, y3)] loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument @@ -76,7 +77,7 @@ end ``` You can also write this loop yourself, if you need more flexibility. For this reason `train!` is not highly extensible. -It adds only a few featurs to the loop above: +It adds only a few features to the loop above: * Stop with a `DomainError` if the loss is infinite or `NaN` at any point. @@ -88,9 +89,6 @@ It adds only a few featurs to the loop above: (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) * Instead of `loss` being a function which accepts only the data, now it must also accept the `model` itself, as the first argument. - * `data` must iterate tuples, otherwise you get an error. - (Previously non-tuple types were not splatted into the loss. - Pass in `((d,) for d in data)` to simulate this.) * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser such as `Adam()` without this step should give you a warning. * Callback functions are not supported. @@ -100,9 +98,8 @@ function train!(loss, model, data, opt; cb = nothing) isnothing(cb) || error("""train! does not support callback functions. For more control use a loop with `gradient` and `update!`.""") @withprogress for (i,d) in enumerate(data) - d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)). - Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""") - l, gs = Zygote.withgradient(m -> loss(m, d...), model) + d_splat = d isa Tuple ? d : (d,) + l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model) if !isfinite(l) throw(DomainError("Loss is $l on data item $i, stopping training")) end @@ -112,8 +109,8 @@ function train!(loss, model, data, opt; cb = nothing) end # This method let you use Optimisers.Descent() without setup, when there is no state -function train!(loss, model, data, rule::Optimisers.AbstractRule) - train!(loss, model, data, _rule_to_state(model, rule)) +function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing) + train!(loss, model, data, _rule_to_state(model, rule); cb) end function _rule_to_state(model, rule::Optimisers.AbstractRule) diff --git a/test/train.jl b/test/train.jl index 49ecf9c751..cfadde7d9b 100644 --- a/test/train.jl +++ b/test/train.jl @@ -44,10 +44,15 @@ end @test CNT == 51 # stopped early @test m1.weight[1] ≈ -5 # did not corrupt weights end - @testset "data must give tuples" begin - m1 = Dense(1 => 1) - @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(x=1, y=2) for _ in 1:3], Descent(0.1)) + + @testset "non-tuple data" begin + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10)) + opt = Flux.setup(AdamW(), model) + Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 end + @testset "callbacks give helpful error" begin m1 = Dense(1 => 1) cb = () -> println("this should not be printed")