add constructor for binary classifiers #248

Closed
wants to merge 4 commits
4 changes: 4 additions & 0 deletions docs/src/interface/Classification.md
@@ -1,3 +1,7 @@
```@docs
MLJFlux.NeuralNetworkClassifier
```

```@docs
MLJFlux.NeuralNetworkBinaryClassifier
```
1 change: 1 addition & 0 deletions docs/src/interface/Summary.md
@@ -12,6 +12,7 @@ Model Type | Prediction type | `scitype(X) <: _` | `scitype(y) <: _`
`NeuralNetworkRegressor` | `Deterministic` | `Table(Continuous)` with `n_in` columns | `AbstractVector{<:Continuous}` (`n_out = 1`)
`MultitargetNeuralNetworkRegressor` | `Deterministic` | `Table(Continuous)` with `n_in` columns | `<: Table(Continuous)` with `n_out` columns
`NeuralNetworkClassifier` | `Probabilistic` | `<:Table(Continuous)` with `n_in` columns | `AbstractVector{<:Finite}` with `n_out` classes
`NeuralNetworkBinaryClassifier` | `Probabilistic` | `<:Table(Continuous)` with `n_in` columns | `AbstractVector{<:Finite{2}}` (`n_out = 2`)
`ImageClassifier` | `Probabilistic` | `AbstractVector{<:Image{W,H}}` with `n_in = (W, H)` | `AbstractVector{<:Finite}` with `n_out` classes
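
As a quick editorial illustration (not part of the file change above), the target scitype required by the new table row can be checked as follows; the toy labels are invented for the example:

```julia
# Hypothetical sketch: confirm a two-class target matches the scitype
# `AbstractVector{<:Finite{2}}` required by `NeuralNetworkBinaryClassifier`.
using MLJ  # re-exports `coerce`, `scitype`, `Multiclass` and `Finite` from ScientificTypes

y = coerce(["yes", "no", "yes", "no"], Multiclass)   # two levels => Multiclass{2}
@assert scitype(y) <: AbstractVector{<:Finite{2}}    # holds, since Multiclass{2} <: Finite{2}
```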


2 changes: 1 addition & 1 deletion src/MLJFlux.jl
@@ -29,7 +29,7 @@ include("image.jl")
include("mlj_model_interface.jl")

export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
export NeuralNetworkClassifier, ImageClassifier
export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier
export CUDALibs, CPU1


26 changes: 24 additions & 2 deletions src/classifier.jl
@@ -14,14 +14,14 @@
end

# builds the end-to-end Flux chain needed, given the `model` and `shape`:
MLJFlux.build(model::NeuralNetworkClassifier, rng, shape) =
MLJFlux.build(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, rng, shape) =
Flux.Chain(build(model.builder, rng, shape...),
model.finaliser)

# returns the model `fitresult` (see "Adding Models for General Use"
# section of the MLJ manual) which must always have the form `(chain,
# metadata)`, where `metadata` is anything extra needed by `predict`:
MLJFlux.fitresult(model::NeuralNetworkClassifier, chain, y) =
MLJFlux.fitresult(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, chain, y) =
(chain, MLJModelInterface.classes(y[1]))

function MLJModelInterface.predict(model::NeuralNetworkClassifier,
@@ -37,3 +37,25 @@
input=Union{AbstractMatrix{Continuous},Table(Continuous)},
target=AbstractVector{<:Finite},
path="MLJFlux.NeuralNetworkClassifier")

#### Binary Classifier

function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y)
X = X isa Matrix ? Tables.table(X) : X
n_input = Tables.schema(X).names |> length
return (n_input, 1) # n_output is always 1 for a binary classifier
end

function MLJModelInterface.predict(model::NeuralNetworkBinaryClassifier,
fitresult,
Xnew)
chain, levels = fitresult
X = reformat(Xnew)
probs = vec(chain(X))
return MLJModelInterface.UnivariateFinite(levels, probs; augment = true)
end

MLJModelInterface.metadata_model(NeuralNetworkBinaryClassifier,
input=Union{AbstractMatrix{Continuous},Table(Continuous)},
target=AbstractVector{<:Finite{2}},
path="MLJFlux.NeuralNetworkBinaryClassifier")
6 changes: 6 additions & 0 deletions src/core.jl
@@ -234,3 +234,9 @@
ymatrix = reformat(y)
return [_get(Xmatrix, b) for b in row_batches], [_get(ymatrix, b) for b in row_batches]
end
function collate(model::NeuralNetworkBinaryClassifier, X, y)
row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
Xmatrix = reformat(X)
yvec = (y .== classes(y)[2])' # convert to boolean
return [_get(Xmatrix, b) for b in row_batches], [_get(yvec, b) for b in row_batches]
end
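
A small illustration (not part of the diff; data invented) of what the new `collate` method does to the target: the second class maps to `true` and the result is transposed to a row, so each batch is a `1×batch_size` matrix matching the single sigmoid output of the chain:

```julia
# Illustrative only: the boolean row-vector target encoding used by the new
# `collate` method for `NeuralNetworkBinaryClassifier`.
using CategoricalArrays

y = categorical(["neg", "pos", "pos", "neg", "pos"])
yvec = (y .== levels(y)[2])'          # 1×5 row: true where y is the second class, "pos"

batch_size = 2
row_batches = Base.Iterators.partition(1:length(y), batch_size)
y_batches = [yvec[:, b] for b in row_batches]   # three batches: 1×2, 1×2, 1×1
```
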
202 changes: 199 additions & 3 deletions src/types.jl
@@ -3,10 +3,15 @@ abstract type MLJFluxDeterministic <: MLJModelInterface.Deterministic end

const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic}

for Model in [:NeuralNetworkClassifier, :ImageClassifier]
for Model in [:NeuralNetworkClassifier, :NeuralNetworkBinaryClassifier, :ImageClassifier]

# default settings that are not equal across models
default_builder_ex =
Model == :ImageClassifier ? :(image_builder(VGGHack)) : Short()
default_finaliser =
Model == :NeuralNetworkBinaryClassifier ? Flux.σ : Flux.softmax
default_loss =
Model == :NeuralNetworkBinaryClassifier ? Flux.binarycrossentropy : Flux.crossentropy

ex = quote
mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic
@@ -23,7 +28,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier]
acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()`
end

function $Model(; builder::B=$default_builder_ex, finaliser::F=Flux.softmax, optimiser::O=Flux.Optimise.Adam(), loss::L=Flux.crossentropy, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1()
function $Model(; builder::B=$default_builder_ex, finaliser::F=$default_finaliser, optimiser::O=Flux.Optimise.Adam(), loss::L=$default_loss, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1()
) where {B,F,O,L}

model = $Model{B,F,O,L}(builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration
@@ -277,11 +282,202 @@ plot(curve.parameter_values,

```

See also [`ImageClassifier`](@ref).
See also [`ImageClassifier`](@ref), [`NeuralNetworkBinaryClassifier`](@ref).

"""
NeuralNetworkClassifier

"""
$(MMI.doc_header(NeuralNetworkBinaryClassifier))

`NeuralNetworkBinaryClassifier` is for training a data-dependent Flux.jl neural network
for making probabilistic predictions of a binary (`Multiclass{2}` or `OrderedFactor{2}`) target,
given a table of `Continuous` features. Users provide a recipe for constructing
the network, based on properties of the data that is encountered, by specifying
an appropriate `builder`. See MLJFlux documentation for more on builders.

# Training data

In MLJ or MLJBase, bind an instance `model` to data with

mach = machine(model, X, y)

Here:

- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose columns are of scitype
`Continuous`; check column scitypes with `schema(X)`. If `X` is a `Matrix`,
it is assumed to have columns corresponding to features and rows corresponding to observations.

- `y` is the target, which can be any `AbstractVector` whose element scitype is `Multiclass{2}`
or `OrderedFactor{2}`; check the scitype with `scitype(y)`.

Train the machine with `fit!(mach, rows=...)`.


# Hyper-parameters

- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible
`builders` include: `MLJFlux.Linear`, `MLJFlux.Short`, and `MLJFlux.MLP`. See
MLJFlux.jl documentation for examples of user-defined builders. See also `finaliser`
below.

- `optimiser=Flux.Adam()`: A `Flux.Optimise` optimiser. The optimiser performs the
updating of the weights of the network. For further reference, see [the Flux optimiser
documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a
learning rate (the update rate of the optimiser), a good rule of thumb is to start out
at `10e-3`, and tune using powers of 10 between `1` and `1e-7`.

- `loss=Flux.binarycrossentropy`: The loss function which the network will optimize. Should be a
function which can be called in the form `loss(yhat, y)`. Possible loss functions are
listed in [the Flux loss function
documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification
task, the most natural loss functions are:

- `Flux.binarycrossentropy`: Standard binary classification loss, also known as the log
loss.

- `Flux.logitbinarycrossentropy`: Mathematically equal to binary crossentropy, but
numerically more stable than finalising the outputs with `σ` and then calculating
binary crossentropy. You will need to specify `finaliser=identity` to remove MLJFlux's
default sigmoid finaliser, and understand that the output of `predict` is then
unnormalized (no longer probabilistic); see the configuration sketch after this
hyper-parameter list.

- `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives.

- `Flux.binary_focal_loss`: Used with highly imbalanced data. Weights harder examples more than
easier examples.

Currently MLJ measures are not supported values of `loss`.

- `epochs::Int=10`: The duration of training, in epochs. Typically, one epoch represents
one pass through the complete training dataset.

- `batch_size::Int=1`: the batch size to be used for training, representing the number of
samples per update of the network weights. Typically, batch size is between 8 and
512. Increasing batch size may accelerate training if `acceleration=CUDALibs()` and a
GPU is available.

- `lambda::Float64=0`: The strength of the weight regularization penalty. Can be any value
in the range `[0, ∞)`.

- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0
represents L2 regularization, and a value of 1 represents L1 regularization.

- `rng::Union{AbstractRNG, Int64}`: The random number generator or seed used during
training.

- `optimiser_changes_trigger_retraining::Bool=false`: Defines what happens when re-fitting
a machine if the associated optimiser has changed. If `true`, the associated machine
will retrain from scratch on `fit!` call, otherwise it will not.

- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For
training on a GPU, use `CUDALibs()`.

- `finaliser=Flux.σ`: The final activation function of the neural network (applied
after the network defined by `builder`). Defaults to `Flux.σ`.
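
As referenced in the `Flux.logitbinarycrossentropy` bullet above, a configuration along these lines (an illustrative sketch only, not a tested recipe) swaps in the logit loss and drops the default sigmoid finaliser:

```julia
using MLJ, Flux
NeuralNetworkBinaryClassifier = @load NeuralNetworkBinaryClassifier pkg=MLJFlux

clf = NeuralNetworkBinaryClassifier(
    loss = Flux.logitbinarycrossentropy,
    finaliser = identity,   # no sigmoid; `predict` output is then unnormalized
)
```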


# Operations

- `predict(mach, Xnew)`: return predictions of the target given new features `Xnew`, which
should have the same scitype as `X` above. Predictions are probabilistic but uncalibrated.

- `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions returned
above.
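
A compact sketch of these operations (illustrative only; the synthetic data and the `epochs=5` setting are invented for the example):

```julia
using MLJ
NeuralNetworkBinaryClassifier = @load NeuralNetworkBinaryClassifier pkg=MLJFlux

X = MLJ.table(rand(Float32, 20, 3))               # 20 observations, 3 Continuous features
y = coerce(rand(["neg", "pos"], 20), Multiclass)  # binary target

mach = machine(NeuralNetworkBinaryClassifier(epochs=5), X, y)
fit!(mach, verbosity=0)

yhat = predict(mach, X)      # vector of two-class UnivariateFinite distributions
pdf.(yhat, "pos")            # probability of the "pos" class for each observation
predict_mode(mach, X)        # most likely class for each observation
```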


# Fitted parameters

The fields of `fitted_params(mach)` are:

- `chain`: The trained "chain" (Flux.jl model), namely the series of layers,
functions, and activations which make up the neural network. This includes
the final layer specified by `finaliser` (eg, `σ`).


# Report

The fields of `report(mach)` are:

- `training_losses`: A vector of training losses (penalised if `lambda != 0`) in
historical order, of length `epochs + 1`. The first element is the pre-training loss.

# Examples

In this example we build a binary classification model for the mtcars dataset. This is a very
basic example, using a default builder and no standardization. For a more advanced
illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref), and
examples in the MLJFlux.jl documentation.

```julia
using MLJ, Flux
import RDatasets
```

First, we can load the data:

```julia
mtcars = RDatasets.dataset("datasets", "mtcars");
y, X = unpack(mtcars, ==(:VS), in([:MPG, :Cyl, :Disp, :HP, :WT, :QSec])); # a vector and a table
y = categorical(y) # the classifier expects a categorical target
X_f32 = Float32.(X) # to match the floating-point type of the neural network layers
NeuralNetworkBinaryClassifier = @load NeuralNetworkBinaryClassifier pkg=MLJFlux
bclf = NeuralNetworkBinaryClassifier()
```

Next, we can train the model:

```julia
mach = machine(bclf, X_f32, y)
fit!(mach)
```

We can train the model in an incremental fashion, altering the learning rate as we go,
provided `optimiser_changes_trigger_retraining` is `false` (the default). Here, we also
change the number of (total) iterations:

```julia
bclf.optimiser.eta = bclf.optimiser.eta * 2
bclf.epochs = bclf.epochs + 5

fit!(mach, verbosity=2) # trains 5 more epochs
```

We can inspect the mean training loss using the `cross_entropy` function:

```julia
training_loss = cross_entropy(predict(mach, X_f32), y) |> mean
```

And we can access the Flux chain (model) using `fitted_params`:

```julia
chain = fitted_params(mach).chain
```

Finally, we can see how the out-of-sample performance changes over time, using MLJ's
`learning_curve` function:

```julia
r = range(bclf, :epochs, lower=1, upper=200, scale=:log10)
curve = learning_curve(bclf, X_f32, y,
range=r,
resampling=Holdout(fraction_train=0.7),
measure=cross_entropy)
using Plots
plot(curve.parameter_values,
curve.measurements,
xlab=curve.parameter_name,
xscale=curve.parameter_scale,
ylab = "Cross Entropy")

```

See also [`ImageClassifier`](@ref).

"""
NeuralNetworkBinaryClassifier

"""
$(MMI.doc_header(ImageClassifier))
