From a8b336e74c629a3b77392139208b388f8c49bb0d Mon Sep 17 00:00:00 2001
From: tiemvanderdeure
Date: Tue, 23 Apr 2024 13:17:55 +0200
Subject: [PATCH] add NeuralNetworkBinaryClassifier

---
 src/classifier.jl | 26 ++++++++++++++++++++++++--
 src/core.jl       |  6 ++++++
 src/types.jl      |  9 +++++++--
 3 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/src/classifier.jl b/src/classifier.jl
index ed9d4cf9..d7523fc5 100644
--- a/src/classifier.jl
+++ b/src/classifier.jl
@@ -14,14 +14,14 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
 end

 # builds the end-to-end Flux chain needed, given the `model` and `shape`:
-MLJFlux.build(model::NeuralNetworkClassifier, rng, shape) =
+MLJFlux.build(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, rng, shape) =
     Flux.Chain(build(model.builder, rng, shape...), model.finaliser)

 # returns the model `fitresult` (see "Adding Models for General Use"
 # section of the MLJ manual) which must always have the form `(chain,
 # metadata)`, where `metadata` is anything extra needed by `predict`:
-MLJFlux.fitresult(model::NeuralNetworkClassifier, chain, y) =
+MLJFlux.fitresult(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, chain, y) =
     (chain, MLJModelInterface.classes(y[1]))

 function MLJModelInterface.predict(model::NeuralNetworkClassifier,
@@ -37,3 +37,25 @@ MLJModelInterface.metadata_model(NeuralNetworkClassifier,
     input=Union{AbstractMatrix{Continuous},Table(Continuous)},
     target=AbstractVector{<:Finite},
     path="MLJFlux.NeuralNetworkClassifier")
+
+#### Binary Classifier
+
+function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y)
+    X = X isa Matrix ? Tables.table(X) : X
+    n_input = Tables.schema(X).names |> length
+    return (n_input, 1) # n_output is always 1 for a binary classifier
+end
+
+function MLJModelInterface.predict(model::NeuralNetworkBinaryClassifier,
+    fitresult,
+    Xnew)
+    chain, levels = fitresult
+    X = reformat(Xnew)
+    probs = vec(chain(X))
+    return MLJModelInterface.UnivariateFinite(levels, probs; augment = true)
+end
+
+MLJModelInterface.metadata_model(NeuralNetworkBinaryClassifier,
+    input=Union{AbstractMatrix{Continuous},Table(Continuous)},
+    target=AbstractVector{<:Finite{2}},
+    path="MLJFlux.NeuralNetworkBinaryClassifier")
diff --git a/src/core.jl b/src/core.jl
index cca5a145..b67d8aae 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -234,3 +234,9 @@ function collate(model, X, y)
     ymatrix = reformat(y)
     return [_get(Xmatrix, b) for b in row_batches], [_get(ymatrix, b) for b in row_batches]
 end
+function collate(model::NeuralNetworkBinaryClassifier, X, y)
+    row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
+    Xmatrix = reformat(X)
+    yvec = (y .== classes(y)[2])' # convert to boolean
+    return [_get(Xmatrix, b) for b in row_batches], [_get(yvec, b) for b in row_batches]
+end
diff --git a/src/types.jl b/src/types.jl
index c608abbf..dc4fc939 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -3,10 +3,15 @@ abstract type MLJFluxDeterministic <: MLJModelInterface.Deterministic end

 const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic}

-for Model in [:NeuralNetworkClassifier, :ImageClassifier]
+for Model in [:NeuralNetworkClassifier, :NeuralNetworkBinaryClassifier, :ImageClassifier]

+    # default settings that are not equal across models
     default_builder_ex =
         Model == :ImageClassifier ? :(image_builder(VGGHack)) : Short()
+    default_finaliser =
+        Model == :NeuralNetworkBinaryClassifier ? Flux.σ : Flux.softmax
+    default_loss =
+        Model == :NeuralNetworkBinaryClassifier ? Flux.binarycrossentropy : Flux.crossentropy

     ex = quote
         mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic
@@ -23,7 +28,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier]
             acceleration::AbstractResource  # eg, `CPU1()` or `CUDALibs()`
         end

-        function $Model(; builder::B=$default_builder_ex, finaliser::F=Flux.softmax, optimiser::O=Flux.Optimise.Adam(), loss::L=Flux.crossentropy, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1()
+        function $Model(; builder::B=$default_builder_ex, finaliser::F=$default_finaliser, optimiser::O=Flux.Optimise.Adam(), loss::L=$default_loss, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1()
             ) where {B,F,O,L}
             model = $Model{B,F,O,L}(builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration
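
For reviewers, a minimal usage sketch of the new model (illustration only, not part of the patch; it assumes NeuralNetworkBinaryClassifier ends up exported from MLJFlux like the existing NeuralNetworkClassifier, and uses MLJ's make_moons, machine, predict_mode and accuracy):

    using MLJ, MLJFlux

    # toy two-class data: a feature table and a two-level categorical target
    X, y = make_moons(200)

    model = MLJFlux.NeuralNetworkBinaryClassifier(epochs=20, batch_size=8)
    mach = machine(model, X, y)
    fit!(mach)

    yhat = predict(mach, X)              # UnivariateFinite over the two levels
    accuracy(predict_mode(mach, X), y)   # point predictions via the mode

Because predict builds the result with UnivariateFinite(levels, probs; augment=true), the single sigmoid output is interpreted as the probability of the second level, and the first level receives the complement.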