
Commit a8b336e

add NeuralNetworkBinaryClassifier

tiemvanderdeure committed Apr 23, 2024
1 parent c127b67
Showing 3 changed files with 37 additions and 4 deletions.
26 changes: 24 additions & 2 deletions src/classifier.jl
@@ -14,14 +14,14 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
end

# builds the end-to-end Flux chain needed, given the `model` and `shape`:
-MLJFlux.build(model::NeuralNetworkClassifier, rng, shape) =
+MLJFlux.build(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, rng, shape) =
     Flux.Chain(build(model.builder, rng, shape...),
                model.finaliser)

# returns the model `fitresult` (see "Adding Models for General Use"
# section of the MLJ manual) which must always have the form `(chain,
# metadata)`, where `metadata` is anything extra needed by `predict`:
-MLJFlux.fitresult(model::NeuralNetworkClassifier, chain, y) =
+MLJFlux.fitresult(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, chain, y) =
     (chain, MLJModelInterface.classes(y[1]))

function MLJModelInterface.predict(model::NeuralNetworkClassifier,
@@ -37,3 +37,25 @@ MLJModelInterface.metadata_model(NeuralNetworkClassifier,
                                 input=Union{AbstractMatrix{Continuous},Table(Continuous)},
                                 target=AbstractVector{<:Finite},
                                 path="MLJFlux.NeuralNetworkClassifier")

+#### Binary Classifier
+
+function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y)
+    X = X isa Matrix ? Tables.table(X) : X
+    n_input = Tables.schema(X).names |> length
+    return (n_input, 1) # n_output is always 1 for a binary classifier
+end
+
+function MLJModelInterface.predict(model::NeuralNetworkBinaryClassifier,
+                                   fitresult,
+                                   Xnew)
+    chain, levels = fitresult
+    X = reformat(Xnew)
+    probs = vec(chain(X))
+    return MLJModelInterface.UnivariateFinite(levels, probs; augment = true)
+end
+
+MLJModelInterface.metadata_model(NeuralNetworkBinaryClassifier,
+                                 input=Union{AbstractMatrix{Continuous},Table(Continuous)},
+                                 target=AbstractVector{<:Finite{2}},
+                                 path="MLJFlux.NeuralNetworkBinaryClassifier")
6 changes: 6 additions & 0 deletions src/core.jl
@@ -234,3 +234,9 @@ function collate(model, X, y)
    ymatrix = reformat(y)
    return [_get(Xmatrix, b) for b in row_batches], [_get(ymatrix, b) for b in row_batches]
end
+function collate(model::NeuralNetworkBinaryClassifier, X, y)
+    row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
+    Xmatrix = reformat(X)
+    yvec = (y .== classes(y)[2])' # convert to boolean
+    return [_get(Xmatrix, b) for b in row_batches], [_get(yvec, b) for b in row_batches]
+end
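
The `yvec` conversion above can be illustrated standalone (a sketch using CategoricalArrays directly; `levels` plays the role here of `classes`, which MLJ defines for categorical vectors):

using CategoricalArrays

y = categorical(["no", "yes", "yes", "no"])

# comparing against the second class gives a Bool vector; the adjoint
# turns it into a 1×n row, matching the 1-row output of the
# sigmoid-finalised chain
yvec = (y .== levels(y)[2])'  # 1×4 row: [0 1 1 0], true where y == "yes"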
9 changes: 7 additions & 2 deletions src/types.jl
@@ -3,10 +3,15 @@ abstract type MLJFluxDeterministic <: MLJModelInterface.Deterministic end

const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic}

-for Model in [:NeuralNetworkClassifier, :ImageClassifier]
+for Model in [:NeuralNetworkClassifier, :NeuralNetworkBinaryClassifier, :ImageClassifier]

     # default settings that are not equal across models
     default_builder_ex =
         Model == :ImageClassifier ? :(image_builder(VGGHack)) : Short()
+    default_finaliser =
+        Model == :NeuralNetworkBinaryClassifier ? Flux.σ : Flux.softmax
+    default_loss =
+        Model == :NeuralNetworkBinaryClassifier ? Flux.binarycrossentropy : Flux.crossentropy

    ex = quote
        mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic
@@ -23,7 +28,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier]
            acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()`
        end

-        function $Model(; builder::B=$default_builder_ex, finaliser::F=Flux.softmax, optimiser::O=Flux.Optimise.Adam(), loss::L=Flux.crossentropy, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1()
+        function $Model(; builder::B=$default_builder_ex, finaliser::F=$default_finaliser, optimiser::O=Flux.Optimise.Adam(), loss::L=$default_loss, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1()
        ) where {B,F,O,L}

            model = $Model{B,F,O,L}(builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration
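
A quick sketch of what these new defaults mean in practice (assertions assume the defaults introduced above; `Flux.σ` is Flux's sigmoid):

using MLJFlux, Flux

m = NeuralNetworkBinaryClassifier()
@assert m.finaliser === Flux.σ              # sigmoid over one output neuron
@assert m.loss === Flux.binarycrossentropy

# the multiclass model keeps its previous defaults
c = NeuralNetworkClassifier()
@assert c.finaliser === Flux.softmax
@assert c.loss === Flux.crossentropy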
