A cleaner, more powerful train! function #3

Merged
merged 2 commits into from Nov 27, 2022
2 changes: 2 additions & 0 deletions Project.toml
@@ -6,12 +6,14 @@ version = "0.1.0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Flux = "0.13.7"
NNlib = "0.8.10"
Optimisers = "0.2.10"
ProgressMeter = "1.7.2"
Zygote = "0.6.49"
julia = "1.6"
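
The only new dependency is ProgressMeter. For context, a minimal sketch of the two calls `src/train.jl` below relies on, `Progress` and `next!`; the loop here is illustrative, not part of the PR:

```
using ProgressMeter: Progress, next!

prog = Progress(100)                       # a progress bar sized for 100 steps
for i in 1:100
    next!(prog; showvalues=[(:step, i)])   # advance the bar and display the current step
end
```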

2 changes: 1 addition & 1 deletion README.md
@@ -34,4 +34,4 @@ As will any features which migrate to Flux itself.
## Current Features

* Layers `Split` and `Join`

* A more advanced `train!`
3 changes: 3 additions & 0 deletions src/Fluxperimental.jl
@@ -5,4 +5,7 @@ using Flux
include("split_join.jl")
export Split, Join

include("train.jl")
export shinkansen!

end # module Fluxperimental
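
With the export in place, loading the package brings `shinkansen!` into scope. A minimal sketch of calling it with the loss passed as an anonymous function rather than a `do` block; the toy data, model, and optimiser settings here are illustrative assumptions:

```
using Flux, Fluxperimental, Optimisers

X = rand(Float32, 2, 64)                    # toy inputs (illustrative)
Y = Flux.onehotbatch(rand(0:1, 64), 0:1)    # toy targets (illustrative)
model = Chain(Dense(2 => 3, relu), Dense(3 => 2))
state = Optimisers.setup(Optimisers.Adam(0.01), model)

losses = shinkansen!((m, x, y) -> Flux.logitcrossentropy(m(x), y), model, X, Y;
                     state, epochs=5)
```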
78 changes: 78 additions & 0 deletions src/train.jl
@@ -0,0 +1,78 @@
using Flux: withgradient, DataLoader
using Optimisers: Optimisers
using ProgressMeter: ProgressMeter, Progress, next!

#=

This grew out of explicit-mode upgrade here:
https://github.com/FluxML/Flux.jl/pull/2082

=#

"""
shinkansen!(loss, model, data...; state, epochs=1, [batchsize, keywords...])

This is a re-design of `train!`:

* The loss function must accept the remaining arguments: `loss(model, data...)`
* The optimiser state from `setup` must be passed to the keyword `state`.

By default it calls `gradient(loss, model, data...)`, with the arguments in exactly that order.
If you specify `epochs = 100`, it will do this 100 times.

But if you specify `batchsize = 32`, then it first makes `DataLoader(data...; batchsize)`,
and uses that to generate smaller arrays to feed to `gradient`.
All other keywords are passed to `DataLoader`, e.g. to shuffle batches.

Returns the loss from every call: a vector with one entry per epoch, or, when `batchsize` is set, a matrix with one column of per-batch losses per epoch.

# Example
```
X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)

model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2))
state = Flux.setup(Adam(0.1, (0.7, 0.95)), model)
# state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model) # for now

shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
    Flux.logitcrossentropy(m(x), y)
end

all((softmax(model(X)) .> 0.5) .== Y)
```
"""
function shinkansen!(loss::Function, model, data...; state, epochs=1, batchsize=nothing, kw...)
    if batchsize !== nothing
        loader = DataLoader(data; batchsize, kw...)
        losses = Vector{Float32}[]
        prog = Progress(length(loader) * epochs)

        for e in 1:epochs
            eplosses = Float32[]
            for (i, d) in enumerate(loader)
                # The first gradient is with respect to the model; gradients w.r.t. the data are discarded.
                l, (g, _...) = withgradient(loss, model, d...)
                isfinite(l) || error("loss is $l, on batch $i, epoch $e")
                Optimisers.update!(state, model, g)
                push!(eplosses, l)
                next!(prog; showvalues=[(:epoch, e), (:loss, l)])
            end
            push!(losses, eplosses)
        end

        # Stack the per-epoch losses into a matrix when every epoch saw the same number of batches.
        return allequal(size.(losses)) ? reduce(hcat, losses) : losses
    else
        losses = Float32[]
        prog = Progress(epochs)

        for e in 1:epochs
            l, (g, _...) = withgradient(loss, model, data...)
            isfinite(l) || error("loss is $l, on epoch $e")
            Optimisers.update!(state, model, g)
            push!(losses, l)
            next!(prog; showvalues=[(:epoch, e), (:loss, l)])
        end

        return losses
    end
end
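
As a usage note (not part of the diff): a minimal sketch of what the batched branch returns, continuing the docstring example above; the variable names are illustrative, and the numbers in the comments follow from 128 data columns and `batchsize=16`:

```
losses = shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
    Flux.logitcrossentropy(m(x), y)
end

size(losses)   # (8, 100): 8 batches per epoch, one column of per-batch losses per epoch
mean_per_epoch = vec(sum(losses; dims=1)) ./ size(losses, 1)   # average loss for each epoch
```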
27 changes: 27 additions & 0 deletions test/train.jl
@@ -0,0 +1,27 @@
import Flux, Fluxperimental, Optimisers

@testset "shinkansen!" begin

    X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
    Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)

    model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
    state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

    Fluxperimental.shinkansen!(model, X, Y; state, epochs=100) do m, x, y
        Flux.logitcrossentropy(m(x), y)
    end

    @test all((Flux.softmax(model(X)) .> 0.5) .== Y)

    model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
    state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

    Fluxperimental.shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
        Flux.logitcrossentropy(m(x), y)
    end

    @test all((Flux.softmax(model(X)) .> 0.5) .== Y)

end
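
For context, one way this file might be wired into the package's test entry point; the `runtests.jl` contents and the `split_join.jl` filename are assumptions, not part of this PR:

```
# test/runtests.jl (hypothetical wiring)
using Test

@testset "Fluxperimental.jl" begin
    include("split_join.jl")   # assumed existing tests for Split and Join
    include("train.jl")        # the shinkansen! tests above (they rely on `using Test` from here)
end
```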