From a8b336e74c629a3b77392139208b388f8c49bb0d Mon Sep 17 00:00:00 2001 From: tiemvanderdeure Date: Tue, 23 Apr 2024 13:17:55 +0200 Subject: [PATCH 1/2] add NeuralNetworkBinaryClassifier --- src/classifier.jl | 26 ++++++++++++++++++++++++-- src/core.jl | 6 ++++++ src/types.jl | 9 +++++++-- 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/classifier.jl b/src/classifier.jl index ed9d4cf9..d7523fc5 100644 --- a/src/classifier.jl +++ b/src/classifier.jl @@ -14,14 +14,14 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y) end # builds the end-to-end Flux chain needed, given the `model` and `shape`: -MLJFlux.build(model::NeuralNetworkClassifier, rng, shape) = +MLJFlux.build(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, rng, shape) = Flux.Chain(build(model.builder, rng, shape...), model.finaliser) # returns the model `fitresult` (see "Adding Models for General Use" # section of the MLJ manual) which must always have the form `(chain, # metadata)`, where `metadata` is anything extra needed by `predict`: -MLJFlux.fitresult(model::NeuralNetworkClassifier, chain, y) = +MLJFlux.fitresult(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, chain, y) = (chain, MLJModelInterface.classes(y[1])) function MLJModelInterface.predict(model::NeuralNetworkClassifier, @@ -37,3 +37,25 @@ MLJModelInterface.metadata_model(NeuralNetworkClassifier, input=Union{AbstractMatrix{Continuous},Table(Continuous)}, target=AbstractVector{<:Finite}, path="MLJFlux.NeuralNetworkClassifier") + +#### Binary Classifier + +function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y) + X = X isa Matrix ? Tables.table(X) : X + n_input = Tables.schema(X).names |> length + return (n_input, 1) # n_output is always 1 for a binary classifier +end + +function MLJModelInterface.predict(model::NeuralNetworkBinaryClassifier, + fitresult, + Xnew) + chain, levels = fitresult + X = reformat(Xnew) + probs = vec(chain(X)) + return MLJModelInterface.UnivariateFinite(levels, probs; augment = true) +end + +MLJModelInterface.metadata_model(NeuralNetworkBinaryClassifier, + input=Union{AbstractMatrix{Continuous},Table(Continuous)}, + target=AbstractVector{<:Finite{2}}, + path="MLJFlux.NeuralNetworkBinaryClassifier") diff --git a/src/core.jl b/src/core.jl index cca5a145..b67d8aae 100644 --- a/src/core.jl +++ b/src/core.jl @@ -234,3 +234,9 @@ function collate(model, X, y) ymatrix = reformat(y) return [_get(Xmatrix, b) for b in row_batches], [_get(ymatrix, b) for b in row_batches] end +function collate(model::NeuralNetworkBinaryClassifier, X, y) + row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size) + Xmatrix = reformat(X) + yvec = (y .== classes(y)[2])' # convert to boolean + return [_get(Xmatrix, b) for b in row_batches], [_get(yvec, b) for b in row_batches] +end diff --git a/src/types.jl b/src/types.jl index c608abbf..dc4fc939 100644 --- a/src/types.jl +++ b/src/types.jl @@ -3,10 +3,15 @@ abstract type MLJFluxDeterministic <: MLJModelInterface.Deterministic end const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic} -for Model in [:NeuralNetworkClassifier, :ImageClassifier] +for Model in [:NeuralNetworkClassifier, :NeuralNetworkBinaryClassifier, :ImageClassifier] + # default settings that are not equal across models default_builder_ex = Model == :ImageClassifier ? :(image_builder(VGGHack)) : Short() + default_finaliser = + Model == :NeuralNetworkBinaryClassifier ? Flux.σ : Flux.softmax + default_loss = + Model == :NeuralNetworkBinaryClassifier ? 
Flux.binarycrossentropy : Flux.crossentropy ex = quote mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic @@ -23,7 +28,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier] acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` end - function $Model(; builder::B=$default_builder_ex, finaliser::F=Flux.softmax, optimiser::O=Flux.Optimise.Adam(), loss::L=Flux.crossentropy, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1() + function $Model(; builder::B=$default_builder_ex, finaliser::F=$default_finaliser, optimiser::O=Flux.Optimise.Adam(), loss::L=$default_loss, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1() ) where {B,F,O,L} model = $Model{B,F,O,L}(builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration From 5842ed9d03161a0663abcbdb7a123a5d960870c3 Mon Sep 17 00:00:00 2001 From: tiemvanderdeure Date: Thu, 30 May 2024 16:56:02 +0200 Subject: [PATCH 2/2] add some docs for the binary classifier --- docs/src/interface/Classification.md | 4 + docs/src/interface/Summary.md | 1 + src/MLJFlux.jl | 2 +- src/types.jl | 193 ++++++++++++++++++++++++++- 4 files changed, 198 insertions(+), 2 deletions(-) diff --git a/docs/src/interface/Classification.md b/docs/src/interface/Classification.md index 0491e8fc..d45d7a2b 100644 --- a/docs/src/interface/Classification.md +++ b/docs/src/interface/Classification.md @@ -1,3 +1,7 @@ ```@docs MLJFlux.NeuralNetworkClassifier +``` + +```@docs +MLJFlux.NeuralNetworkBinaryClassifier ``` \ No newline at end of file diff --git a/docs/src/interface/Summary.md b/docs/src/interface/Summary.md index ecff99d5..a8f7b383 100644 --- a/docs/src/interface/Summary.md +++ b/docs/src/interface/Summary.md @@ -12,6 +12,7 @@ Model Type | Prediction type | `scitype(X) <: _` | `scitype(y) <: _` `NeuralNetworkRegressor` | `Deterministic` | `Table(Continuous)` with `n_in` columns | `AbstractVector{<:Continuous)` (`n_out = 1`) `MultitargetNeuralNetworkRegressor` | `Deterministic` | `Table(Continuous)` with `n_in` columns | `<: Table(Continuous)` with `n_out` columns `NeuralNetworkClassifier` | `Probabilistic` | `<:Table(Continuous)` with `n_in` columns | `AbstractVector{<:Finite}` with `n_out` classes +`NeuralNetworkBinaryClassifier` | `Probabilistic` | `<:Table(Continuous)` with `n_in` columns | `AbstractVector{<:Finite{2}}` (`n_out = 2`) `ImageClassifier` | `Probabilistic` | `AbstractVector(<:Image{W,H})` with `n_in = (W, H)` | `AbstractVector{<:Finite}` with `n_out` classes diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index 5091d798..bd6011eb 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -29,7 +29,7 @@ include("image.jl") include("mlj_model_interface.jl") export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor -export NeuralNetworkClassifier, ImageClassifier +export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier export CUDALibs, CPU1 diff --git a/src/types.jl b/src/types.jl index dc4fc939..3bc7ce05 100644 --- a/src/types.jl +++ b/src/types.jl @@ -282,11 +282,202 @@ plot(curve.parameter_values, ``` -See also [`ImageClassifier`](@ref). +See also [`ImageClassifier`](@ref), [`NeuralNetworkBinaryClassifier`](@ref). 
""" NeuralNetworkClassifier +""" +$(MMI.doc_header(NeuralNetworkBinaryClassifier)) + +`NeuralNetworkBinaryClassifier` is for training a data-dependent Flux.jl neural network +for making probabilistic predictions of a binary (`Multiclass{2}` or `OrderedFactor{2}`) target, +given a table of `Continuous` features. Users provide a recipe for constructing + the network, based on properties of the data that is encountered, by specifying + an appropriate `builder`. See MLJFlux documentation for more on builders. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with + + mach = machine(model, X, y) + +Here: + +- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose columns are of scitype + `Continuous`; check column scitypes with `schema(X)`. If `X` is a `Matrix`, + it is assumed to have columns corresponding to features and rows corresponding to observations. + +- `y` is the target, which can be any `AbstractVector` whose element scitype is `Multiclass{2}` + or `OrderedFactor{2}`; check the scitype with `scitype(y)` + +Train the machine with `fit!(mach, rows=...)`. + + +# Hyper-parameters + +- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible + `builders` include: `MLJFlux.Linear`, `MLJFlux.Short`, and `MLJFlux.MLP`. See + MLJFlux.jl documentation for examples of user-defined builders. See also `finaliser` + below. + +- `optimiser::Flux.Adam()`: A `Flux.Optimise` optimiser. The optimiser performs the + updating of the weights of the network. For further reference, see [the Flux optimiser + documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a + learning rate (the update rate of the optimizer), a good rule of thumb is to start out + at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. + +- `loss=Flux.binarycrossentropy`: The loss function which the network will optimize. Should be a + function which can be called in the form `loss(yhat, y)`. Possible loss functions are + listed in [the Flux loss function + documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification + task, the most natural loss functions are: + + - `Flux.binarycrossentropy`: Standard binary classification loss, also known as the log + loss. + + - `Flux.logitbinarycrossentropy`: Mathematically equal to crossentropy, but numerically more + stable than finalising the outputs with `σ` and then calculating + crossentropy. You will need to specify `finaliser=identity` to remove MLJFlux's + default sigmoid finaliser, and understand that the output of `predict` is then + unnormalized (no longer probabilistic). + + - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. + + - `Flux.binary_focal_loss`: Used with highly imbalanced data. Weights harder examples more than + easier examples. + + Currently MLJ measures are not supported values of `loss`. + +- `epochs::Int=10`: The duration of training, in epochs. Typically, one epoch represents + one pass through the complete the training dataset. + +- `batch_size::int=1`: the batch size to be used for training, representing the number of + samples per update of the network weights. Typically, batch size is between 8 and + 512. Increassing batch size may accelerate training if `acceleration=CUDALibs()` and a + GPU is available. + +- `lambda::Float64=0`: The strength of the weight regularization penalty. Can be any value + in the range `[0, ∞)`. 

- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0
  represents L2 regularization, and a value of 1 represents L1 regularization.

- `rng::Union{AbstractRNG, Int64}`: The random number generator or seed used during
  training.

- `optimiser_changes_trigger_retraining::Bool=false`: Defines what happens when re-fitting
  a machine if the associated optimiser has changed. If `true`, the associated machine
  will retrain from scratch on `fit!` call, otherwise it will not.

- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For
  training on a GPU, use `CUDALibs()`.

- `finaliser=Flux.σ`: The final activation function of the neural network (applied
  after the network defined by `builder`). Defaults to `Flux.σ`.


# Operations

- `predict(mach, Xnew)`: Return predictions of the target given new features `Xnew`, which
  should have the same scitype as `X` above. Predictions are probabilistic but uncalibrated.

- `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions returned
  above.


# Fitted parameters

The fields of `fitted_params(mach)` are:

- `chain`: The trained "chain" (Flux.jl model), namely the series of layers,
  functions, and activations which make up the neural network. This includes
  the final layer specified by `finaliser` (eg, `σ`).


# Report

The fields of `report(mach)` are:

- `training_losses`: A vector of training losses (penalised if `lambda != 0`) in
  historical order, of length `epochs + 1`. The first element is the pre-training loss.

# Examples

In this example we build a classification model for the `mtcars` dataset, predicting the
binary engine-shape variable `:VS` (V-shaped or straight) from a few numeric features. This
is a very basic example, using a default builder and no standardization. For a more advanced
illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref), and
examples in the MLJFlux.jl documentation.

```julia
using MLJ, Flux
import RDatasets
```

First, we can load the data:

```julia
mtcars = RDatasets.dataset("datasets", "mtcars");
y, X = unpack(mtcars, ==(:VS), in([:MPG, :Cyl, :Disp, :HP, :WT, :QSec])); # a vector and a table
y = categorical(y) # the classifier takes a categorical target
X_f32 = Float32.(X) # to match the floating point type of the neural network layers
NeuralNetworkBinaryClassifier = @load NeuralNetworkBinaryClassifier pkg=MLJFlux
bclf = NeuralNetworkBinaryClassifier()
```

Next, we can train the model:

```julia
mach = machine(bclf, X_f32, y)
fit!(mach)
```

We can train the model in an incremental fashion, altering the learning rate as we go,
provided `optimiser_changes_trigger_retraining` is `false` (the default).
Here, we also +change the number of (total) iterations: + +```julia +bclf.optimiser.eta = bclf.optimiser.eta * 2 +bclf.epochs = bclf.epochs + 5 + +fit!(mach, verbosity=2) # trains 5 more epochs +``` + +We can inspect the mean training loss using the `cross_entropy` function: + +```julia +training_loss = cross_entropy(predict(mach, X_f32), y) |> mean +``` + +And we can access the Flux chain (model) using `fitted_params`: + +```julia +chain = fitted_params(mach).chain +``` + +Finally, we can see how the out-of-sample performance changes over time, using MLJ's +`learning_curve` function: + +```julia +r = range(bclf, :epochs, lower=1, upper=200, scale=:log10) +curve = learning_curve(bclf, X_f32, y, + range=r, + resampling=Holdout(fraction_train=0.7), + measure=cross_entropy) +using Plots +plot(curve.parameter_values, + curve.measurements, + xlab=curve.parameter_name, + xscale=curve.parameter_scale, + ylab = "Cross Entropy") + +``` + +See also [`ImageClassifier`](@ref). + +""" +NeuralNetworkBinaryClassifier + """ $(MMI.doc_header(ImageClassifier))
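The `augment = true` keyword in the new `predict` method is what turns the single sigmoid
output per observation into a full two-class distribution: the supplied probabilities are
interpreted as belonging to the *second* class level (the level the new `collate` method
encodes as `true`), and the first class is filled in as the complement. Below is a minimal
sketch of that behaviour, assuming `using MLJ`; the class labels and probability values are
made up for illustration and are not part of the patch.

```julia
using MLJ  # re-exports `categorical`, `classes`, `UnivariateFinite` and `pdf`

# A two-level categorical pool standing in for the target's classes (hypothetical labels):
class_pool = classes(categorical(["neg", "pos"]))

# Pretend these are the raw sigmoid outputs of the chain, one per observation; with
# `augment=true` they are read as probabilities of the second class level, "pos":
probs = Float32[0.2, 0.9]

yhat = UnivariateFinite(class_pool, probs; augment=true)

pdf.(yhat, "pos")  # ≈ [0.2, 0.9]  (the supplied probabilities)
pdf.(yhat, "neg")  # ≈ [0.8, 0.1]  (filled in so each distribution sums to one)
```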