A cleaner, more powerful `train!` function (#3)
Showing 5 changed files with 111 additions and 1 deletion.
@@ -0,0 +1,78 @@
using Flux: withgradient, DataLoader
using Optimisers: Optimisers
using ProgressMeter: ProgressMeter, Progress, next!

#=
This grew out of the explicit-mode upgrade here:
https://github.com/FluxML/Flux.jl/pull/2082
=#

""" | ||
shinkansen!(loss, model, data...; state, epochs=1, [batchsize, keywords...]) | ||
This is a re-design of `train!`: | ||
* The loss function must accept the remaining arguments: `loss(model, data...)` | ||
* The optimiser state from `setup` must be passed to the keyword `state`. | ||
By default it calls `gradient(loss, model, data...)` just like that. | ||
Same order as the arguments. If you specify `epochs = 100`, then it will do this 100 times. | ||
But if you specify `batchsize = 32`, then it first makes `DataLoader(data...; batchsize)`, | ||
and uses that to generate smaller arrays to feed to `gradient`. | ||
All other keywords are passed to `DataLoader`, e.g. to shuffle batches. | ||
Returns the loss from every call. | ||
# Example | ||
``` | ||
X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32) | ||
Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1) | ||
model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2)) | ||
state = Flux.setup(Adam(0.1, (0.7, 0.95)), model) | ||
# state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model) # for now | ||
shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y | ||
Flux.logitcrossentropy(m(x), y) | ||
end | ||
all((softmax(model(X)) .> 0.5) .== Y) | ||
``` | ||
""" | ||
function shinkansen!(loss::Function, model, data...; state, epochs=1, batchsize=nothing, kw...)
    if batchsize !== nothing
        # Batched path: wrap the data in a DataLoader and take one step per mini-batch.
        loader = DataLoader(data; batchsize, kw...)
        losses = Vector{Float32}[]
        prog = Progress(length(loader) * epochs)

        for e in 1:epochs
            eplosses = Float32[]
            for (i, d) in enumerate(loader)
                l, (g, _...) = withgradient(loss, model, d...)
                isfinite(l) || error("loss is $l, on batch $i, epoch $e")
                Optimisers.update!(state, model, g)
                push!(eplosses, l)
                next!(prog; showvalues=[(:epoch, e), (:loss, l)])
            end
            push!(losses, eplosses)
        end

        return allequal(size.(losses)) ? reduce(hcat, losses) : losses
    else
        # Full-batch path: one gradient step per epoch on the whole dataset.
        losses = Float32[]
        prog = Progress(epochs)

        for e in 1:epochs
            l, (g, _...) = withgradient(loss, model, data...)
            isfinite(l) || error("loss is $l, on epoch $e")
            Optimisers.update!(state, model, g)
            push!(losses, l)
            next!(prog; showvalues=[(:epoch, e), (:loss, l)])
        end

        return losses
    end
end
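
Aside (not part of this diff): the two branches return differently shaped loss histories. The full-batch path returns a `Vector{Float32}` with one entry per epoch, while the batched path returns a `(batches, epochs)` matrix (falling back to a vector of vectors if the epochs ever differ in length). A minimal sketch of inspecting that, reusing the XOR example from the docstring; the name `history` is only for illustration:

```
import Flux, Fluxperimental, Optimisers

X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)   # 2×128 matrix of xor inputs
Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)

model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

# 128 columns split into batches of 16 gives 8 batches per epoch.
history = Fluxperimental.shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
    Flux.logitcrossentropy(m(x), y)
end

size(history)      # (8, 100): column e holds the batch losses of epoch e
history[end, end]  # loss of the last batch in the last epoch
```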
@@ -0,0 +1,27 @@
import Flux, Fluxperimental, Optimisers

@testset "shinkansen!" begin

    X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
    Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)

    model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
    state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

    Fluxperimental.shinkansen!(model, X, Y; state, epochs=100) do m, x, y
        Flux.logitcrossentropy(m(x), y)
    end

    @test all((Flux.softmax(model(X)) .> 0.5) .== Y)

    model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
    state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

    Fluxperimental.shinkansen!(model, X, Y; state, epochs=100, batchsize=16, shuffle=true) do m, x, y
        Flux.logitcrossentropy(m(x), y)
    end

    @test all((Flux.softmax(model(X)) .> 0.5) .== Y)

end
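
A possible extra check, not in this commit and only a hedged sketch: the unbatched call returns a plain vector with one loss per epoch, so its type and length could be tested directly alongside the convergence checks above.

```
using Test
import Flux, Fluxperimental, Optimisers

X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)

model = Flux.Chain(Flux.Dense(2 => 3, Flux.sigmoid), Flux.BatchNorm(3), Flux.Dense(3 => 2))
state = Optimisers.setup(Optimisers.Adam(0.1, (0.7, 0.95)), model)

losses = Fluxperimental.shinkansen!(model, X, Y; state, epochs=100) do m, x, y
    Flux.logitcrossentropy(m(x), y)
end

@test losses isa Vector{Float32}   # one entry per epoch in the unbatched path
@test length(losses) == 100
```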