Commit

Merge branch 'binaryclassifier' of https://github.com/tiemvanderdeure/MLJFlux.jl into tiemvanderdeure-binaryclassifier
ablaom committed Jun 10, 2024
2 parents f38e0cf + 5842ed9 commit dd00443
Showing 6 changed files with 274 additions and 47 deletions.
4 changes: 4 additions & 0 deletions docs/src/interface/Classification.md
@@ -1,3 +1,7 @@
```@docs
MLJFlux.NeuralNetworkClassifier
```

```@docs
MLJFlux.NeuralNetworkBinaryClassifier
```
1 change: 1 addition & 0 deletions docs/src/interface/Summary.md
@@ -12,6 +12,7 @@ Model Type | Prediction type | `scitype(X) <: _` | `scitype(y) <: _`
`NeuralNetworkRegressor` | `Deterministic` | `Table(Continuous)` with `n_in` columns | `AbstractVector{<:Continuous}` (`n_out = 1`)
`MultitargetNeuralNetworkRegressor` | `Deterministic` | `Table(Continuous)` with `n_in` columns | `<: Table(Continuous)` with `n_out` columns
`NeuralNetworkClassifier` | `Probabilistic` | `<:Table(Continuous)` with `n_in` columns | `AbstractVector{<:Finite}` with `n_out` classes
`NeuralNetworkBinaryClassifier` | `Probabilistic` | `<:Table(Continuous)` with `n_in` columns | `AbstractVector{<:Finite{2}}` (`n_out = 2`)
`ImageClassifier` | `Probabilistic` | `AbstractVector{<:Image{W,H}}` with `n_in = (W, H)` | `AbstractVector{<:Finite}` with `n_out` classes


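The new table row pins down the data the binary model accepts: a `Table(Continuous)` input with `n_in` columns and a two-class `Finite` target. Below is a minimal usage sketch under those constraints; the model name comes from this PR, while the `epochs` value and the synthetic data are illustrative only and the remaining hyperparameters are left at their defaults.

```julia
using MLJBase, MLJFlux, CategoricalArrays

X = (x1 = rand(100), x2 = rand(100), x3 = rand(100))   # Table(Continuous), n_in = 3
y = categorical(rand(["neg", "pos"], 100))             # AbstractVector{<:Finite{2}}

model = NeuralNetworkBinaryClassifier(epochs = 5)      # other hyperparameters at defaults
mach = machine(model, X, y)
fit!(mach, verbosity = 0)
yhat = predict(mach, X)                                # probabilistic predictions
pdf.(yhat, "pos")[1:3]                                 # P("pos") for the first three rows
```
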
2 changes: 1 addition & 1 deletion src/MLJFlux.jl
@@ -29,7 +29,7 @@ include("image.jl")
include("mlj_model_interface.jl")

export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
export NeuralNetworkClassifier, ImageClassifier
export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier
export CUDALibs, CPU1

include("deprecated.jl")
26 changes: 24 additions & 2 deletions src/classifier.jl
@@ -14,14 +14,14 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
end

# builds the end-to-end Flux chain needed, given the `model` and `shape`:
MLJFlux.build(model::NeuralNetworkClassifier, rng, shape) =
MLJFlux.build(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, rng, shape) =
    Flux.Chain(build(model.builder, rng, shape...),
               model.finaliser)

# returns the model `fitresult` (see "Adding Models for General Use"
# section of the MLJ manual) which must always have the form `(chain,
# metadata)`, where `metadata` is anything extra needed by `predict`:
MLJFlux.fitresult(model::NeuralNetworkClassifier, chain, y) =
MLJFlux.fitresult(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, chain, y) =
    (chain, MLJModelInterface.classes(y[1]))
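
Storing `MLJModelInterface.classes(y[1])` in the `fitresult` keeps the complete class pool, not just the labels that happen to appear in a given training batch. A small illustration of what that call returns, using `MLJBase` (which implements the interface) and made-up labels:

```julia
using MLJBase, CategoricalArrays

y = categorical(["no", "yes", "no"])
classes(y[1])   # the full pool, in order: ["no", "yes"]
```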

function MLJModelInterface.predict(model::NeuralNetworkClassifier,
@@ -37,3 +37,25 @@ MLJModelInterface.metadata_model(NeuralNetworkClassifier,
    input=Union{AbstractMatrix{Continuous},Table(Continuous)},
    target=AbstractVector{<:Finite},
    path="MLJFlux.NeuralNetworkClassifier")

#### Binary Classifier

function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y)
    X = X isa Matrix ? Tables.table(X) : X
    n_input = Tables.schema(X).names |> length
    return (n_input, 1) # n_output is always 1 for a binary classifier
end
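
For the binary variant, `shape` fixes the output dimension at 1, since a single unit suffices to score two classes. A quick sketch of the input-dimension calculation it performs, on a hypothetical three-column table:

```julia
using Tables

X = (x1 = rand(10), x2 = rand(10), x3 = rand(10))
n_input = length(Tables.schema(X).names)   # 3
# so shape(model, X, y) returns (3, 1): three input features, one output neuron
```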

function MLJModelInterface.predict(model::NeuralNetworkBinaryClassifier,
                                   fitresult,
                                   Xnew)
    chain, levels = fitresult
    X = reformat(Xnew)
    probs = vec(chain(X))
    return MLJModelInterface.UnivariateFinite(levels, probs; augment = true)
end
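
Because the chain emits a single probability per observation, `predict` relies on the `augment = true` option of `UnivariateFinite`: the supplied numbers are taken as probabilities of the second class and the first class receives the complement. A standalone sketch of that call, with labels and probabilities invented for illustration:

```julia
using MLJBase, CategoricalArrays

y = categorical(["neg", "pos"])
levels_ = classes(y[1])                      # ["neg", "pos"]
probs = [0.2, 0.9, 0.4]                      # model outputs: P("pos") per observation
d = UnivariateFinite(levels_, probs; augment = true)
pdf.(d, "pos")                               # [0.2, 0.9, 0.4]
pdf.(d, "neg")                               # [0.8, 0.1, 0.6]
```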

MLJModelInterface.metadata_model(NeuralNetworkBinaryClassifier,
    input=Union{AbstractMatrix{Continuous},Table(Continuous)},
    target=AbstractVector{<:Finite{2}},
    path="MLJFlux.NeuralNetworkBinaryClassifier")
6 changes: 6 additions & 0 deletions src/core.jl
@@ -274,3 +274,9 @@ function collate(model, X, y)
    ymatrix = reformat(y)
    return [_get(Xmatrix, b) for b in row_batches], [_get(ymatrix, b) for b in row_batches]
end
function collate(model::NeuralNetworkBinaryClassifier, X, y)
    row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
    Xmatrix = reformat(X)
    yvec = (y .== classes(y)[2])' # convert to boolean
    return [_get(Xmatrix, b) for b in row_batches], [_get(yvec, b) for b in row_batches]
end
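
The specialised `collate` turns the categorical target into a single row of Booleans (true for the second pool class), matching the one-row output that `shape` prescribes for the binary chain. A sketch of that conversion on made-up data, using `levels` from CategoricalArrays as a stand-in for the package's `classes` helper:

```julia
using CategoricalArrays

y = categorical(["neg", "pos", "pos", "neg"])
positive = levels(y)[2]        # "pos" — the second pool class plays the positive role
yvec = (y .== positive)'       # 1×4 Bool row: [0 1 1 0], ready for batching
```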