fix defaults for NNBinaryClassifier constructor
ablaom committed Jun 10, 2024
1 parent dd00443 commit b7291f6
Showing 3 changed files with 125 additions and 23 deletions.
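The effect of the change can be illustrated with a short usage sketch (a hypothetical check, assuming `MLJFlux` and `Flux` are loaded as in the test suite; the multiclass defaults are taken from the pre-existing constructor code and are presumably unchanged by this commit):

```julia
using MLJFlux, Flux

# With this fix, the binary classifier's keyword constructor defaults to a
# sigmoid finaliser and binary cross-entropy loss:
binary = NeuralNetworkBinaryClassifier()
@assert binary.finaliser == Flux.σ
@assert binary.loss == Flux.binarycrossentropy

# The multiclass classifier keeps softmax and (multiclass) cross-entropy:
multi = NeuralNetworkClassifier()
@assert multi.finaliser == Flux.softmax
@assert multi.loss == Flux.crossentropy
```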
51 changes: 33 additions & 18 deletions src/classifier.jl
@@ -3,7 +3,9 @@
"""
shape(model::NeuralNetworkClassifier, X, y)
A private method that returns the shape of the input and output of the model for given data `X` and `y`.
A private method that returns the shape of the input and output of the model for given
data `X` and `y`.
"""
function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
X = X isa Matrix ? Tables.table(X) : X
@@ -14,29 +16,38 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
end

# builds the end-to-end Flux chain needed, given the `model` and `shape`:
MLJFlux.build(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, rng, shape) =
Flux.Chain(build(model.builder, rng, shape...),
model.finaliser)
MLJFlux.build(
model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier},
rng,
shape,
) = Flux.Chain(build(model.builder, rng, shape...), model.finaliser)

# returns the model `fitresult` (see "Adding Models for General Use"
# section of the MLJ manual) which must always have the form `(chain,
# metadata)`, where `metadata` is anything extra needed by `predict`:
MLJFlux.fitresult(model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier}, chain, y) =
(chain, MLJModelInterface.classes(y[1]))
MLJFlux.fitresult(
model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier},
chain,
y,
) = (chain, MLJModelInterface.classes(y[1]))

function MLJModelInterface.predict(model::NeuralNetworkClassifier,
function MLJModelInterface.predict(
model::NeuralNetworkClassifier,
fitresult,
Xnew)
Xnew,
)
chain, levels = fitresult
X = reformat(Xnew)
probs = vcat([chain(tomat(X[:, i]))' for i in 1:size(X, 2)]...)
return MLJModelInterface.UnivariateFinite(levels, probs)
end

MLJModelInterface.metadata_model(NeuralNetworkClassifier,
input=Union{AbstractMatrix{Continuous},Table(Continuous)},
target=AbstractVector{<:Finite},
path="MLJFlux.NeuralNetworkClassifier")
MLJModelInterface.metadata_model(
NeuralNetworkClassifier,
input_scitype=Union{AbstractMatrix{Continuous},Table(Continuous)},
target_scitype=AbstractVector{<:Finite},
load_path="MLJFlux.NeuralNetworkClassifier",
)

#### Binary Classifier

@@ -46,16 +57,20 @@ function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y)
return (n_input, 1) # n_output is always 1 for a binary classifier
end

function MLJModelInterface.predict(model::NeuralNetworkBinaryClassifier,
function MLJModelInterface.predict(
model::NeuralNetworkBinaryClassifier,
fitresult,
Xnew)
Xnew,
)
chain, levels = fitresult
X = reformat(Xnew)
probs = vec(chain(X))
return MLJModelInterface.UnivariateFinite(levels, probs; augment = true)
end

MLJModelInterface.metadata_model(NeuralNetworkBinaryClassifier,
input=Union{AbstractMatrix{Continuous},Table(Continuous)},
target=AbstractVector{<:Finite{2}},
path="MLJFlux.NeuralNetworkBinaryClassifier")
MLJModelInterface.metadata_model(
NeuralNetworkBinaryClassifier,
input_scitype=Union{AbstractMatrix{Continuous},Table(Continuous)},
target_scitype=AbstractVector{<:Finite{2}},
load_path="MLJFlux.NeuralNetworkBinaryClassifier",
)
4 changes: 2 additions & 2 deletions src/types.jl
@@ -30,9 +30,9 @@ for Model in [:NeuralNetworkClassifier, :NeuralNetworkBinaryClassifier, :ImageCl

function $Model(
;builder::B=$default_builder_ex,
finaliser::F=Flux.softmax,
finaliser::F=$default_finaliser,
optimiser::O=Optimisers.Adam(),
loss::L=Flux.crossentropy,
loss::L=$default_loss,
epochs=10,
batch_size=1,
lambda=0,
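In `src/types.jl`, the hard-coded `Flux.softmax` and `Flux.crossentropy` defaults are replaced by interpolated `$default_finaliser` and `$default_loss` values. A minimal sketch of how those values could be chosen per model (hypothetical wiring, not necessarily the repository's exact code; the binary-classifier values match what the new tests assert):

```julia
# Hypothetical sketch: pick the expressions interpolated into each generated
# keyword constructor as `$default_finaliser` and `$default_loss`.
for Model in [:NeuralNetworkClassifier, :NeuralNetworkBinaryClassifier]
    if Model == :NeuralNetworkBinaryClassifier
        default_finaliser = :(Flux.σ)
        default_loss = :(Flux.binarycrossentropy)
    else
        default_finaliser = :(Flux.softmax)
        default_loss = :(Flux.crossentropy)
    end
    # ... the @eval-generated constructor for `$Model` then uses
    # `finaliser::F=$default_finaliser` and `loss::L=$default_loss` ...
end
```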
93 changes: 90 additions & 3 deletions test/classifier.jl
@@ -1,4 +1,4 @@
## NEURAL NETWORK CLASSIFIER
# # NEURAL NETWORK CLASSIFIER

seed!(1234)
N = 300
@@ -59,7 +59,7 @@ losses = []
end
dist = MLJBase.UnivariateFinite(prob_given_class)
loss_baseline =
StatisticalMeasures.cross_entropy(fill(dist, length(test)), y[test]) |> mean
StatisticalMeasures.cross_entropy(fill(dist, length(test)), y[test])

# check flux model is an improvement on predicting constant
# distribution
@@ -76,7 +76,7 @@ losses = []
first_last_training_loss = MLJBase.report(mach)[1][[1, end]]
push!(losses, first_last_training_loss[2])
yhat = MLJBase.predict(mach, rows=test);
@test mean(StatisticalMeasures.cross_entropy(yhat, y[test])) < 0.95*loss_baseline
@test StatisticalMeasures.cross_entropy(yhat, y[test]) < 0.95*loss_baseline

optimisertest(MLJFlux.NeuralNetworkClassifier,
X,
@@ -91,4 +91,91 @@ end
reference = losses[1]
@test all(x->abs(x - reference)/reference < 1e-5, losses[2:end])


# # NEURAL NETWORK BINARY CLASSIFIER

@testset "NeuralNetworkBinaryClassifier constructor" begin
model = NeuralNetworkBinaryClassifier()
@test model.loss == Flux.binarycrossentropy
@test model.builder isa MLJFlux.Short
@test model.finaliser == Flux.σ
end

seed!(1234)
N = 300
X = MLJBase.table(rand(Float32, N, 4));
ycont = 2*X.x1 - X.x3 + 0.1*rand(N)
m, M = minimum(ycont), maximum(ycont)
_, a, _ = range(m, stop=M, length=3) |> collect
y = map(ycont) do η
if η < 0.9*a
'a'
else
'b'
end
end |> categorical;

builder = MLJFlux.MLP(hidden=(8,))
optimiser = Optimisers.Adam(0.03)

@testset_accelerated "NeuralNetworkBinaryClassifier" accel begin

# Table input:
@testset "Table input" begin
basictest(
MLJFlux.NeuralNetworkBinaryClassifier,
X,
y,
builder,
optimiser,
0.85,
accel,
)
end

# Matrix input:
@testset "Matrix input" begin
basictest(
MLJFlux.NeuralNetworkBinaryClassifier,
matrix(X),
y,
builder,
optimiser,
0.85,
accel,
)
end

train, test = MLJBase.partition(1:N, 0.7)

# baseline loss (predict constant probability distribution):
dict = StatsBase.countmap(y[train])
prob_given_class = Dict{CategoricalArrays.CategoricalValue,Float64}()
for (k, v) in dict
prob_given_class[k] = dict[k]/length(train)
end
dist = MLJBase.UnivariateFinite(prob_given_class)
loss_baseline =
StatisticalMeasures.cross_entropy(fill(dist, length(test)), y[test])

# check flux model is an improvement on predicting constant
# distribution
# (GPUs only support `default_rng`):
rng = Random.default_rng()
seed!(rng, 123)
model = MLJFlux.NeuralNetworkBinaryClassifier(
epochs=50,
builder=builder,
optimiser=optimiser,
acceleration=accel,
batch_size=10,
rng=rng,
)
@time mach = fit!(machine(model, X, y), rows=train, verbosity=0)
first_last_training_loss = MLJBase.report(mach)[1][[1, end]]
yhat = MLJBase.predict(mach, rows=test);
@test StatisticalMeasures.cross_entropy(yhat, y[test]) < 0.95*loss_baseline

end

true
