Merge pull request #171 from FluxML/poster
Update mnist example
ablaom authored Jun 29, 2021
2 parents 5e5d698 + 886fca8 commit 3c5fc86
Showing 7 changed files with 758 additions and 1,664 deletions.
280 changes: 140 additions & 140 deletions examples/mnist/Manifest.toml

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions examples/mnist/Project.toml
@@ -4,5 +4,6 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Binary file added examples/mnist/loss.png
1,748 changes: 361 additions & 1,387 deletions examples/mnist/mnist.ipynb

Large diffs are not rendered by default.

186 changes: 116 additions & 70 deletions examples/mnist/mnist.jl
@@ -1,17 +1,17 @@
# # Using MLJ to classify the MNIST image dataset


using Pkg
const DIR = @__DIR__
Pkg.activate(DIR)
Pkg.instantiate()

# **Julia version** is assumed to be ^1.6

using MLJ
using Flux
import MLJFlux
import MLJIteration # for `skip`

MLJ.color_off() # suppress ANSI color in MLJ's printed output

@@ -29,7 +29,7 @@ images, labels = MNIST.traindata();

# In MLJ, integers cannot be used for encoding categorical data, so we
# must force the labels to have the `Multiclass` [scientific
# type](https://juliaai.github.io/ScientificTypes.jl/dev/). For
# more on this, see [Working with Categorical
# Data](https://alan-turing-institute.github.io/MLJ.jl/dev/working_with_categorical_data/).
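
# For instance, the coercion itself is a one-liner along the following lines
# (a sketch only; `coerce`, `Multiclass` and `GrayImage` are MLJ scientific
# type names, and the image coercion is included purely for illustration):
#
#     labels = coerce(labels, Multiclass);
#     images = coerce(images, GrayImage);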

@@ -65,29 +65,23 @@ struct MyConvBuilder
channels3::Int
end

make2d(x::AbstractArray) = reshape(x, :, size(x)[end])

function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels)
    k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3

    mod(k, 2) == 1 || error("`filter_size` must be odd.")

    p = div(k - 1, 2) # padding to preserve image size
    init = Flux.glorot_uniform(rng)
    front = Chain(
        Conv((k, k), n_channels => c1, pad=(p, p), relu, init=init),
        MaxPool((2, 2)),
        Conv((k, k), c1 => c2, pad=(p, p), relu, init=init),
        MaxPool((2, 2)),
        Conv((k, k), c2 => c3, pad=(p, p), relu, init=init),
        MaxPool((2, 2)),
        make2d)
    d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
    return Chain(front, Dense(d, n_out, init=init))
end
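
# As a quick sanity check, the builder can be called directly to inspect the
# chain it generates (a sketch only; the `(28, 28)` image size, single
# colour channel and 10 classes are MNIST assumptions, and any `AbstractRNG`
# can stand in for `rng`):
#
#     using Random: MersenneTwister
#     chain = MLJFlux.build(MyConvBuilder(3, 16, 32, 32),
#                           MersenneTwister(123), (28, 28), 10, 1)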

# **Note.** There is no final `softmax` here, as this is applied by
@@ -100,12 +94,13 @@ end

ImageClassifier = @load ImageClassifier
clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32),
                      batch_size=50,
                      epochs=10,
                      rng=123)

# You can add Flux options `optimiser=...` and `loss=...` here. At
# present, `loss` must be a Flux-compatible loss, not an MLJ
# measure. To run on a GPU, set `acceleration=CUDALib()`.
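
# For example (an illustrative sketch only; `ADAM` and `crossentropy` are
# Flux names, and the learning rate shown is an arbitrary choice):
#
#     clf.optimiser = ADAM(0.001)
#     clf.loss = Flux.crossentropy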

# Binding the model with data in an MLJ machine:
mach = machine(clf, images, labels);
@@ -138,10 +133,8 @@ fit!(mach, rows=1:500);
predicted_labels = predict(mach, rows=501:1000);
cross_entropy(predicted_labels, labels[501:1000]) |> mean

# Or, in one line:

evaluate!(mach,
          resampling=Holdout(fraction_train=0.5),
          measure=cross_entropy,
@@ -151,76 +144,129 @@ evaluate!(mach,

# ## Wrapping the MLJFlux model with iteration controls

# Any iterative MLJFlux model can be wrapped in *iteration controls*,
# as we demonstrate next. For more on MLJ's `IteratedModel` wrapper,
# see the [MLJ
# documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/).

# The "self-iterating" model, called `imodel` below, is for iterating the
# image classifier defined above until one of the following stopping
# criterion apply:
# The "self-iterating" classifier, called `iterated_clf` below, is for
# iterating the image classifier defined above until one of the
# following stopping criterion apply:

# - `Patience(3)`: 3 consecutive increases in the loss
# - `InvalidValue()`: an out-of-sample loss, or a training loss, is `NaN`, `Inf`, or `-Inf`
# - `TimeLimit(t=5/60)`: training time has exceeded 5 minutes

# These checks (and other controls) will be applied every two epochs
# (because of the `Step(2)` control). Additionally, training a
# machine bound to `iterated_clf` will:

# - save a snapshot of the machine every three control cycles (every six epochs)
# - record traces of the out-of-sample loss and training losses for plotting
# - record mean value traces of each Flux parameter for plotting

# For a complete list of controls, see [this
# table](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/#Controls-provided).

# ### Wrapping the classifier

# Some helpers

make2d(x::AbstractArray) = reshape(x, :, size(x)[end])
make1d(x::AbstractArray) = reshape(x, length(x));

# To extract Flux params from an MLJFlux machine

parameters(mach) = make1d.(Flux.params(fitted_params(mach)));

# To store the traces:

losses = []
training_losses = []
parameter_means = Float32[];

# To update the traces:

update_loss(loss) = push!(losses, loss)
update_training_loss(losses) = push!(training_losses, losses[end])
update_means(mach) = append!(parameter_means, mean.(parameters(mach)));

# The controls to apply:

save_control =
MLJIteration.skip(Save(joinpath(DIR, "mnist.jlso")), predicate=3)
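
# Here the `skip` wrapper means the `Save` control only fires on every third
# application of the controls, that is, every third control cycle.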

controls=[Step(2),                    # train two epochs between control applications
          Patience(3),
          InvalidValue(),
          TimeLimit(5/60),
          save_control,
          WithLossDo(),               # for logging to `Info`
          WithLossDo(update_loss),
          WithTrainingLossesDo(update_training_loss),
          Callback(update_means)
];

# The "self-iterating" classifier:

iterated_clf = IteratedModel(model=clf,
                             controls=controls,
                             resampling=Holdout(fraction_train=0.7),
                             measure=log_loss)
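
# Note that the `Holdout(fraction_train=0.7)` resampling means the controls
# compute their out-of-sample loss on the 30% of supplied rows held back
# from training.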

# ### Binding the wrapped model to data:

mach = machine(iterated_clf, images, labels);

# ### Training

fit!(mach, rows=1:500);

# ### Comparison of the training and out-of-sample losses:

plot(losses,
     title="Cross Entropy",
     xlab = "epoch",
     ylab = "cross entropy",
     label="out-of-sample")
plot!(training_losses, label="training")

savefig(joinpath(DIR, "loss.png"))

# ### Evolution of weights

n_epochs = length(losses)
n_parameters = div(length(parameter_means), n_epochs)
parameter_means2 = reshape(copy(parameter_means), n_parameters, n_epochs)'
plot(parameter_means2,
     title="Flux parameter mean weights",
     xlab = "epoch")

# **Note.** The higher the number, the deeper the chain parameter.

mach2 = machine(joinpath(DIR, "mnist_machine5.jlso"))
savefig(joinpath(DIR, "weights.png"))


# ### Retrieving a snapshot for a prediction:
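
# The `Save` control introduced earlier numbers its snapshot files
# automatically; `mnist3.jlso` below is one such snapshot written during
# training.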

mach2 = machine(joinpath(DIR, "mnist3.jlso"))
predict_mode(mach2, images[501:503])

#-

# ### Restarting training

# Mutating `iterated_clf.controls` or `clf.epochs` (which is otherwise
# ignored) and calling `fit!` again will restart training from where it
# left off.

iterated_clf.controls[2] = Patience(4) # relax the early-stopping patience
fit!(mach, rows=1:500)

plot(losses,
     xlab = "epoch",
     ylab = "cross entropy",
     label="out-of-sample")
plot!(training_losses, label="training")

using Literate #src
Literate.markdown(@__FILE__, @__DIR__, execute=false) #src
Literate.notebook(@__FILE__, @__DIR__, execute=false) #src