This tutorial is available as a Jupyter notebook or Julia script here.
The Julia version is assumed to be 1.10.*
```julia
using MLJ
using Flux
import MLJFlux
import MLUtils
import MLJIteration # for `skip`
```
If running on a GPU, you will also need to `import CUDA` and `import cuDNN`.
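For example (a quick sketch; it assumes CUDA.jl and cuDNN.jl are already installed in your environment):

```julia
# GPU-only dependencies; skip these imports when training on the CPU:
import CUDA
import cuDNN
```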
```julia
using Plots
gr(size=(600, 300*(sqrt(5)-1)));
```
Downloading the MNIST image dataset:
```julia
import MLDatasets: MNIST

ENV["DATADEPS_ALWAYS_ACCEPT"] = true
images, labels = MNIST(split=:train)[:];
```
In MLJ, integers cannot be used for encoding categorical data, so we must force the labels to have the `Multiclass` scientific type. For more on this, see Working with Categorical Data.
```julia
labels = coerce(labels, Multiclass);
images = coerce(images, GrayImage);
```
Checking scientific types:
```julia
@assert scitype(images) <: AbstractVector{<:Image}
@assert scitype(labels) <: AbstractVector{<:Finite}
```
Looks good.
For general instructions on coercing image data, see Type coercion for image data.
```julia
images[1]
```
We start by defining a suitable `Builder` object. This is a recipe for building the neural network. Our builder will work for images of any (constant) size, whether they be color or black and white (i.e., single or multi-channel). The architecture always consists of six alternating convolution and max-pool layers, and a final dense layer; the filter size and the number of channels after each convolution layer are customisable.
```julia
import MLJFlux

struct MyConvBuilder
    filter_size::Int
    channels1::Int
    channels2::Int
    channels3::Int
end

function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels)
    k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3
    mod(k, 2) == 1 || error("`filter_size` must be odd.")
    p = div(k - 1, 2) # padding to preserve image size
    init = Flux.glorot_uniform(rng)
    front = Chain(
        Conv((k, k), n_channels => c1, pad=(p, p), relu, init=init),
        MaxPool((2, 2)),
        Conv((k, k), c1 => c2, pad=(p, p), relu, init=init),
        MaxPool((2, 2)),
        Conv((k, k), c2 => c3, pad=(p, p), relu, init=init),
        MaxPool((2, 2)),
        MLUtils.flatten)
    d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
    return Chain(front, Dense(d, n_out, init=init))
end
```
Notes.

- There is no final `softmax` here, as this is applied by default in all MLJFlux classifiers. Customisation of this behaviour is controlled using the `finaliser` hyperparameter of the classifier.
- Instead of calculating the padding `p`, Flux can infer the required padding in each dimension, which you enable by replacing `pad = (p, p)` with `pad = SamePad()`.
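To illustrate the second note, here is a hedged sketch of the convolutional front end with `SamePad()` substituted for the hand-computed padding; the helper name `front_with_samepad` is ours and not part of the builder defined above:

```julia
# Same front end as in `build`, but letting Flux infer the padding needed
# to preserve the spatial dimensions of the input:
function front_with_samepad(k, n_channels, c1, c2, c3, init)
    Chain(
        Conv((k, k), n_channels => c1, relu, pad=SamePad(), init=init),
        MaxPool((2, 2)),
        Conv((k, k), c1 => c2, relu, pad=SamePad(), init=init),
        MaxPool((2, 2)),
        Conv((k, k), c2 => c3, relu, pad=SamePad(), init=init),
        MaxPool((2, 2)),
        MLUtils.flatten,
    )
end
```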
We now define the MLJ model.
```julia
ImageClassifier = @load ImageClassifier
clf = ImageClassifier(
    builder=MyConvBuilder(3, 16, 32, 32),
    batch_size=50,
    epochs=10,
    rng=123,
)
```
```
ImageClassifier(
  builder = Main.MyConvBuilder(3, 16, 32, 32),
  finaliser = NNlib.softmax,
  optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8),
  loss = Flux.Losses.crossentropy,
  epochs = 10,
  batch_size = 50,
  lambda = 0.0,
  alpha = 0.0,
  rng = 123,
  optimiser_changes_trigger_retraining = false,
  acceleration = CPU1{Nothing}(nothing))
```
You can add Flux options `optimiser=...` and `loss=...` in the above constructor call. At present, `loss` must be a Flux-compatible loss, not an MLJ measure. To run on a GPU, add `acceleration=CUDALibs()` to the constructor and omit `rng`.
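As a hedged illustration (the name `clf_custom` and the particular option values are ours, not part of the tutorial), the same classifier with these options spelled out might look like this:

```julia
clf_custom = ImageClassifier(
    builder=MyConvBuilder(3, 16, 32, 32),
    optimiser=Adam(0.001),           # any Flux-compatible optimiser
    loss=Flux.Losses.crossentropy,   # a Flux loss, not an MLJ measure
    batch_size=50,
    epochs=10,
    rng=123,                         # omit this when training on a GPU ...
    # acceleration=CUDALibs(),       # ... and use this instead
)
```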
For illustration purposes, we won't use all the data here:
```julia
train = 1:500
test = 501:1000
```

```
501:1000
```
Binding the model with data in an MLJ machine:
```julia
mach = machine(clf, images, labels);
```
Training for 10 epochs on the first 500 images:
```julia
fit!(mach, rows=train, verbosity=2);
```
```
[ Info: Training machine(ImageClassifier(builder = Main.MyConvBuilder(3, 16, 32, 32), …), …).
[ Info: Loss is 2.28
[ Info: Loss is 2.171
[ Info: Loss is 1.942
[ Info: Loss is 1.505
[ Info: Loss is 0.9922
[ Info: Loss is 0.6912
[ Info: Loss is 0.5584
[ Info: Loss is 0.4542
[ Info: Loss is 0.3809
[ Info: Loss is 0.3272
```
Inspecting:
```julia
report(mach)
```

```
(training_losses = Float32[2.3174262, 2.280439, 2.1711705, 1.9420795, 1.5045885, 0.99224484, 0.69117606, 0.5583703, 0.45424515, 0.38085267, 0.3271538],)
```
```julia
chain = fitted_params(mach)
```

```
(chain = Chain(Chain(Chain(Conv((3, 3), 1 => 16, relu, pad=1), MaxPool((2, 2)), Conv((3, 3), 16 => 32, relu, pad=1), MaxPool((2, 2)), Conv((3, 3), 32 => 32, relu, pad=1), MaxPool((2, 2)), flatten), Dense(288 => 10)), softmax),)
```
```julia
Flux.params(chain)[2]
```

```
16-element Vector{Float32}:
 0.003225543
 0.019304937
 0.062040687
 0.024518687
 0.05317823
 0.069572166
 0.044410173
 0.024950704
 0.015806748
 0.015081032
 0.017513964
 0.02133927
 0.040562775
 0.0018777152
 0.055122323
 0.057923194
```
Adding 20 more epochs:
```julia
clf.epochs = clf.epochs + 20
fit!(mach, rows=train);
```
```
[ Info: Updating machine(ImageClassifier(builder = Main.MyConvBuilder(3, 16, 32, 32), …), …).
Optimising neural net: 100%[=========================] Time: 0:00:10
```
Computing an out-of-sample estimate of the loss:
```julia
predicted_labels = predict(mach, rows=test);
cross_entropy(predicted_labels, labels[test])
```

```
0.4883231265583621
```
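As a complementary sanity check, one can also estimate the misclassification rate on the same holdout images (a sketch using MLJ's `predict_mode` operation and `misclassification_rate` measure):

```julia
# Point predictions (modes of the predictive distributions):
yhat = predict_mode(mach, rows=test)

# Proportion of test images assigned the wrong label:
misclassification_rate(yhat, labels[test])
```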
Or, to fit and predict in one line:
```julia
evaluate!(mach,
          resampling=Holdout(fraction_train=0.5),
          measure=cross_entropy,
          rows=1:1000,
          verbosity=0)
```

```
PerformanceEvaluation object with these fields:
  model, measure, operation,
  measurement, per_fold, per_observation,
  fitted_params_per_fold, report_per_fold,
  train_test_rows, resampling, repeats
Extract:
┌──────────────────────┬───────────┬─────────────┐
│ measure              │ operation │ measurement │
├──────────────────────┼───────────┼─────────────┤
│ LogLoss(             │ predict   │ 0.488       │
│   tol = 2.22045e-16) │           │             │
└──────────────────────┴───────────┴─────────────┘
```
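Several measures can be estimated in the same call by passing a vector to `measure`; a sketch (MLJ should automatically pair `accuracy` with the `predict_mode` operation):

```julia
evaluate!(mach,
          resampling=Holdout(fraction_train=0.5),
          measure=[cross_entropy, accuracy],
          rows=1:1000,
          verbosity=0)
```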
Any iterative MLJFlux model can be wrapped in iteration controls, as we demonstrate next. For more on MLJ's `IteratedModel` wrapper, see the MLJ documentation.
The "self-iterating" classifier, called iterated_clf
below, is for iterating the image classifier defined above until one of the following stopping criterion apply:
Patience(3)
: 3 consecutive increases in the lossInvalidValue()
: an out-of-sample loss, or a training loss, is NaN
, Inf
, or -Inf
TimeLimit(t=5/60)
: training time has exceeded 5 minutes
These checks (and other controls) will be applied every two epochs (because of the Step(2)
control). Additionally, training a machine bound to iterated_clf
will:
- save a snapshot of the machine every three control cycles (every six epochs)
- record traces of the out-of-sample loss and training losses for plotting
- record mean value traces of each Flux parameter for plotting
For a complete list of controls, see this table.
Some helpers
To extract Flux params from an MLJFlux machine:
```julia
parameters(mach) = vec.(Flux.params(fitted_params(mach)));
```
To store the traces:
```julia
losses = []
training_losses = []
parameter_means = Float32[];
epochs = []
```

```
Any[]
```
To update the traces:
```julia
update_loss(loss) = push!(losses, loss)
update_training_loss(losses) = push!(training_losses, losses[end])
update_means(mach) = append!(parameter_means, mean.(parameters(mach)));
update_epochs(epoch) = push!(epochs, epoch)
```

```
update_epochs (generic function with 1 method)
```
The controls to apply:
```julia
save_control =
    MLJIteration.skip(Save(joinpath(tempdir(), "mnist.jls")), predicate=3)

controls=[
    Step(2),
    Patience(3),
    InvalidValue(),
    TimeLimit(5/60),
    save_control,
    WithLossDo(),
    WithLossDo(update_loss),
    WithTrainingLossesDo(update_training_loss),
    Callback(update_means),
    WithIterationsDo(update_epochs),
];
```
The "self-iterating" classifier:
```julia
iterated_clf = IteratedModel(
    clf,
    controls=controls,
    resampling=Holdout(fraction_train=0.7),
    measure=log_loss,
)
```
```
ProbabilisticIteratedModel(
  model = ImageClassifier(
        builder = Main.MyConvBuilder(3, 16, 32, 32),
        finaliser = NNlib.softmax,
        optimiser = Adam(0.001, (0.9, 0.999), 1.0e-8),
        loss = Flux.Losses.crossentropy,
        epochs = 30,
        batch_size = 50,
        lambda = 0.0,
        alpha = 0.0,
        rng = 123,
        optimiser_changes_trigger_retraining = false,
        acceleration = CPU1{Nothing}(nothing)),
  controls = Any[IterationControl.Step(2), EarlyStopping.Patience(3), EarlyStopping.InvalidValue(), EarlyStopping.TimeLimit(Dates.Millisecond(300000)), IterationControl.Skip{MLJIteration.Save{typeof(Serialization.serialize)}, IterationControl.var"#8#9"{Int64}}(MLJIteration.Save{typeof(Serialization.serialize)}("/tmp/mnist.jls", Serialization.serialize), IterationControl.var"#8#9"{Int64}(3)), IterationControl.WithLossDo{IterationControl.var"#20#22"}(IterationControl.var"#20#22"(), false, nothing), IterationControl.WithLossDo{typeof(Main.update_loss)}(Main.update_loss, false, nothing), IterationControl.WithTrainingLossesDo{typeof(Main.update_training_loss)}(Main.update_training_loss, false, nothing), IterationControl.Callback{typeof(Main.update_means)}(Main.update_means, false, nothing, false), MLJIteration.WithIterationsDo{typeof(Main.update_epochs)}(Main.update_epochs, false, nothing)],
  resampling = Holdout(
        fraction_train = 0.7,
        shuffle = false,
        rng = Random._GLOBAL_RNG()),
  measure = LogLoss(tol = 2.22045e-16),
  weights = nothing,
  class_weights = nothing,
  operation = MLJModelInterface.predict,
  retrain = false,
  check_measure = true,
  iteration_parameter = nothing,
  cache = true)
```
```julia
mach = machine(iterated_clf, images, labels);
```
```julia
fit!(mach, rows=train);
```
```
[ Info: Training machine(ProbabilisticIteratedModel(model = ImageClassifier(builder = Main.MyConvBuilder(3, 16, 32, 32), …), …), …).
[ Info: No iteration parameter specified. Using `iteration_parameter=:(epochs)`.
[ Info: loss: 2.195050130190149
[ Info: loss: 1.8450074691283658
[ Info: Saving "/tmp/mnist1.jls".
[ Info: loss: 1.1388123685158849
[ Info: loss: 0.702997545486733
[ Info: loss: 0.5778269559910739
[ Info: Saving "/tmp/mnist2.jls".
[ Info: loss: 0.5222495075757826
[ Info: loss: 0.49847208228951995
[ Info: loss: 0.4897800580510804
[ Info: Saving "/tmp/mnist3.jls".
[ Info: loss: 0.4893840844808948
[ Info: loss: 0.49094569068535143
[ Info: loss: 0.49593260647952264
[ Info: Saving "/tmp/mnist4.jls".
[ Info: loss: 0.5062357308150314
[ Info: final loss: 0.5062357308150314
[ Info: final training loss: 0.059303638
[ Info: Stop triggered by EarlyStopping.Patience(3) stopping criterion.
[ Info: Total of 24 iterations.
```
```julia
plot(
    epochs,
    losses,
    xlab = "epoch",
    ylab = "cross entropy",
    label="out-of-sample",
)
plot!(epochs, training_losses, label="training")

savefig(joinpath(tempdir(), "loss.png"))
```

```
"/tmp/loss.png"
```
```julia
n_epochs = length(losses)
n_parameters = div(length(parameter_means), n_epochs)
parameter_means2 = reshape(copy(parameter_means), n_parameters, n_epochs)'
plot(
    epochs,
    parameter_means2,
    title="Flux parameter mean weights",
    xlab = "epoch",
)
```
Note. The higher the number in the plot legend, the deeper the layer we are weight-averaging.
savefig(joinpath(tempdir(), "weights.png"))
"/tmp/weights.png"
Retrieving one of the snapshots saved by the `Save` control and using it for prediction:

```julia
mach2 = machine(joinpath(tempdir(), "mnist3.jls"))
predict_mode(mach2, images[501:503])
```

```
3-element CategoricalArrays.CategoricalArray{Int64,1,UInt32}:
 7
 9
 5
```
Mutating `iterated_clf.controls` or `clf.epochs` (which is otherwise ignored) will allow you to restart training from where it left off.
```julia
iterated_clf.controls[2] = Patience(4)
fit!(mach, rows=train)

plot(
    epochs,
    losses,
    xlab = "epoch",
    ylab = "cross entropy",
    label="out-of-sample",
)
plot!(epochs, training_losses, label="training")
```
This page was generated using Literate.jl.