diff --git a/Project.toml b/Project.toml
index 59d91fe..a44f650 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,12 +6,14 @@ version = "0.1.0"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 Flux = "0.13.7"
 NNlib = "0.8.10"
 Optimisers = "0.2.10"
+ProgressMeter = "1.7.2"
 Zygote = "0.6.49"
 julia = "1.6"
diff --git a/README.md b/README.md
index 6374cc5..3b7d68f 100644
--- a/README.md
+++ b/README.md
@@ -34,4 +34,4 @@ As will any features which migrate to Flux itself.
 ## Current Features
 
 * Layers `Split` and `Join`
-
+* A more advanced `train!`
diff --git a/src/Fluxperimental.jl b/src/Fluxperimental.jl
index 047310a..26026d3 100644
--- a/src/Fluxperimental.jl
+++ b/src/Fluxperimental.jl
@@ -5,4 +5,7 @@ using Flux
 include("split_join.jl")
 export Split, Join
 
+include("train.jl")
+export shinkansen!
+
 end # module Fluxperimental
diff --git a/src/train.jl b/src/train.jl
new file mode 100644
index 0000000..095f923
--- /dev/null
+++ b/src/train.jl
@@ -0,0 +1,78 @@
+using Flux: withgradient, DataLoader
+using Optimisers: Optimisers
+using ProgressMeter: ProgressMeter, Progress, next!
+
+#=
+
+This grew out of the explicit-mode upgrade here:
+https://github.com/FluxML/Flux.jl/pull/2082
+
+=#
+
+"""
+    shinkansen!(loss, model, data...; state, epochs=1, [batchsize, keywords...])
+
+This is a re-design of `train!`:
+
+* The loss function must accept the remaining arguments: `loss(model, data...)`
+* The optimiser state from `setup` must be passed to the keyword `state`.
+
+By default it calls `gradient(loss, model, data...)`, with the arguments in exactly that order.
+If you specify `epochs = 100`, then it will do this 100 times.
+
+But if you specify `batchsize = 32`, then it first makes `DataLoader(data; batchsize)`,
+and uses that to generate smaller arrays to feed to `gradient`.
+All other keywords are passed to `DataLoader`, e.g. to shuffle batches.
+
+Returns the loss from every call: a vector, or a matrix with one column per epoch when `batchsize` is set.
+
+# Example
+```
+X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
+Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)
+
+model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2))
+state = Flux.setup(Adam(0.1, (0.7, 0.95)), model)
+# state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)  # for now
+
+shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
+    Flux.logitcrossentropy(m(x), y)
+end
+
+all((softmax(model(X)) .> 0.5) .== Y)
+```
+"""
+function shinkansen!(loss::Function, model, data...; state, epochs=1, batchsize=nothing, kw...)
+    if batchsize !== nothing
+        loader = DataLoader(data; batchsize, kw...)
+        losses = Vector{Float32}[]
+        prog = Progress(length(loader) * epochs)
+
+        for e in 1:epochs
+            eplosses = Float32[]
+            for (i, d) in enumerate(loader)
+                l, (g, _...) = withgradient(loss, model, d...)
+                isfinite(l) || error("loss is $l, on batch $i, epoch $e")
+                Optimisers.update!(state, model, g)
+                push!(eplosses, l)
+                next!(prog; showvalues=[(:epoch, e), (:loss, l)])
+            end
+            push!(losses, eplosses)
+        end
+
+        return allequal(size.(losses)) ? reduce(hcat, losses) : losses
+    else
+        losses = Float32[]
+        prog = Progress(epochs)
+
+        for e in 1:epochs
+            l, (g, _...) = withgradient(loss, model, data...)
+            isfinite(l) || error("loss is $l, on epoch $e")
+            Optimisers.update!(state, model, g)
+            push!(losses, l)
+            next!(prog; showvalues=[(:epoch, e), (:loss, l)])
+        end
+
+        return losses
+    end
+end
diff --git a/test/train.jl b/test/train.jl
new file mode 100644
index 0000000..8c0a710
--- /dev/null
+++ b/test/train.jl
@@ -0,0 +1,27 @@
+import Flux, Fluxperimental, Optimisers
+
+@testset "shinkansen!" begin
+
+    X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
+    Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)
+
+    model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
+    state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)
+
+    Fluxperimental.shinkansen!(model, X, Y; state, epochs=100) do m, x, y
+        Flux.logitcrossentropy(m(x), y)
+    end
+
+    @test all((Flux.softmax(model(X)) .> 0.5) .== Y)
+
+    model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
+    state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)
+
+    Fluxperimental.shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
+        Flux.logitcrossentropy(m(x), y)
+    end
+
+    @test all((Flux.softmax(model(X)) .> 0.5) .== Y)
+
+end
+
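
The docstring example only exercises the batched path. For reference, here is a minimal sketch of the full-batch path (no `batchsize`), reusing the same toy XOR data as the docstring and tests above; in this mode `shinkansen!` takes one gradient step per epoch and returns a plain vector of losses. Names and values below simply mirror the example above, not any additional API.

```julia
using Flux, Fluxperimental, Optimisers

# Toy XOR data, as in the docstring example: 2×128 inputs and one-hot targets.
X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)

model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2))
state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

# No `batchsize`, so each epoch is one `withgradient` call on the full dataset,
# and the return value is a Vector{Float32} of length `epochs`.
losses = shinkansen!(model, X, Y; state, epochs=100) do m, x, y
    Flux.logitcrossentropy(m(x), y)
end

length(losses) == 100                      # one loss per epoch
all((softmax(model(X)) .> 0.5) .== Y)      # same check as the docstring example
```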