A cleaner more powerful train! function (#3)
mcabbott authored Nov 27, 2022
1 parent 660ef33 commit 8650d54
Showing 5 changed files with 111 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
@@ -6,12 +6,14 @@ version = "0.1.0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Flux = "0.13.7"
NNlib = "0.8.10"
Optimisers = "0.2.10"
ProgressMeter = "1.7.2"
Zygote = "0.6.49"
julia = "1.6"

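The `ProgressMeter` dependency added above is what draws the progress bar in `src/train.jl` further down. As a minimal sketch of the two calls that file relies on, `Progress` and `next!` (the loop and the `sleep` stand-in here are invented for illustration):

```julia
using ProgressMeter: Progress, next!

prog = Progress(10)                       # a progress bar expecting 10 steps
for i in 1:10
    sleep(0.05)                           # stand-in for one gradient step
    next!(prog; showvalues=[(:step, i)])  # advance the bar and display extra values
end
```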
2 changes: 1 addition & 1 deletion README.md
@@ -34,4 +34,4 @@ As will any features which migrate to Flux itself.
## Current Features

* Layers `Split` and `Join`

* A more advanced `train!`
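As a quick illustration of that new README bullet, here is a minimal sketch of the redesigned training call this commit adds; the toy data, model, and hyperparameters are invented, while the `shinkansen!` name and its keyword interface come from `src/train.jl` below:

```julia
using Flux, Optimisers, Fluxperimental

# A toy regression problem, invented for this sketch.
x = rand(Float32, 2, 64)
y = sum(x; dims=1)

model = Dense(2 => 1)
state = Optimisers.setup(Optimisers.Adam(0.01), model)  # explicit optimiser state

# The loss closure receives the model and then each data argument, in that order.
losses = shinkansen!(model, x, y; state, epochs=5) do m, xb, yb
    Flux.mse(m(xb), yb)
end
```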
3 changes: 3 additions & 0 deletions src/Fluxperimental.jl
@@ -5,4 +5,7 @@ using Flux
include("split_join.jl")
export Split, Join

include("train.jl")
export shinkansen!

end # module Fluxperimental
78 changes: 78 additions & 0 deletions src/train.jl
@@ -0,0 +1,78 @@
using Flux: withgradient, DataLoader
using Optimisers: Optimisers
using ProgressMeter: ProgressMeter, Progress, next!

#=
This grew out of explicit-mode upgrade here:
https://github.com/FluxML/Flux.jl/pull/2082
=#

"""
shinkansen!(loss, model, data...; state, epochs=1, [batchsize, keywords...])
This is a re-design of `train!`:
* The loss function must accept the remaining arguments: `loss(model, data...)`
* The optimiser state from `setup` must be passed to the keyword `state`.
By default it calls `gradient(loss, model, data...)` just like that.
Same order as the arguments. If you specify `epochs = 100`, then it will do this 100 times.
But if you specify `batchsize = 32`, then it first makes `DataLoader(data...; batchsize)`,
and uses that to generate smaller arrays to feed to `gradient`.
All other keywords are passed to `DataLoader`, e.g. to shuffle batches.
Returns the loss from every call.
# Example
```
X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)
model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2))
state = Flux.setup(Adam(0.1, (0.7, 0.95)), model)
# state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model) # for now
shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
Flux.logitcrossentropy(m(x), y)
end
all((softmax(model(X)) .> 0.5) .== Y)
```
"""
function shinkansen!(loss::Function, model, data...; state, epochs=1, batchsize=nothing, kw...)
    if batchsize !== nothing
        loader = DataLoader(data; batchsize, kw...)
        losses = Vector{Float32}[]
        prog = Progress(length(loader) * epochs)

        for e in 1:epochs
            eplosses = Float32[]
            for (i, d) in enumerate(loader)
                l, (g, _...) = withgradient(loss, model, d...)
                isfinite(l) || error("loss is $l, on batch $i, epoch $e")
                Optimisers.update!(state, model, g)
                push!(eplosses, l)
                next!(prog; showvalues=[(:epoch, e), (:loss, l)])
            end
            push!(losses, eplosses)
        end

        return allequal(size.(losses)) ? reduce(hcat, losses) : losses
    else
        losses = Float32[]
        prog = Progress(epochs)

        for e in 1:epochs
            l, (g, _...) = withgradient(loss, model, data...)
            isfinite(l) || error("loss is $l, on epoch $e")
            Optimisers.update!(state, model, g)
            push!(losses, l)
            next!(prog; showvalues=[(:epoch, e), (:loss, l)])
        end

        return losses
    end
end
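To make the batched path concrete, here is a small sketch of how the keywords flow through to `DataLoader` and what comes back; the data and sizes are invented, and the return shape follows the code above (a matrix of per-batch losses when every epoch has the same number of batches, otherwise a vector of per-epoch vectors):

```julia
using Flux, Optimisers, Fluxperimental

x = rand(Float32, 4, 96)
y = rand(Float32, 1, 96)
model = Dense(4 => 1)
state = Optimisers.setup(Optimisers.Adam(0.01), model)

# batchsize and shuffle are forwarded to Flux's DataLoader;
# each epoch then takes 96 ÷ 32 = 3 gradient steps.
losses = shinkansen!(model, x, y; state, epochs=4, batchsize=32, shuffle=true) do m, xb, yb
    Flux.mse(m(xb), yb)
end

size(losses)  # (3, 4): one column of per-batch losses per epoch
```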
27 changes: 27 additions & 0 deletions test/train.jl
@@ -0,0 +1,27 @@
import Flux, Fluxperimental, Optimisers

@testset "shinkansen!" begin

    X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
    Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)

    model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
    state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

    Fluxperimental.shinkansen!(model, X, Y; state, epochs=100) do m, x, y
        Flux.logitcrossentropy(m(x), y)
    end

    @test all((Flux.softmax(model(X)) .> 0.5) .== Y)

    model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
    state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

    Fluxperimental.shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
        Flux.logitcrossentropy(m(x), y)
    end

    @test all((Flux.softmax(model(X)) .> 0.5) .== Y)

end
