-
Notifications
You must be signed in to change notification settings - Fork 17
/
classifier.jl
39 lines (34 loc) · 1.53 KB
/
classifier.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# if `b` is a builder, then `build(b, rng, shape...)` is called to make a
# new chain, where `shape` is the return value of this method:
"""
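    shape(model::NeuralNetworkClassifier, X, y)

A private method that returns the tuple `(n_input, n_output)` for the given data `X` and `y`: `n_input` is the number of columns (features) of `X`, and `n_output` is the number of classes returned by `MLJModelInterface.classes(y[1])`.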
"""
function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
    # The input declared for this model admits any `AbstractMatrix`, not only a
    # concrete `Matrix`, so every matrix (adjoint, view, ...) is wrapped as a table
    # before its schema is queried; non-matrix input is assumed to be a table already.
    Xtable = X isa AbstractMatrix ? Tables.table(X) : X

    # one output per class of the target, read off the first element of `y`
    n_output = length(MLJModelInterface.classes(y[1]))

    # one input per column name in the table schema
    n_input = length(Tables.schema(Xtable).names)

    return (n_input, n_output)
end
# builds the end-to-end Flux chain needed, given the `model` and `shape`:
function MLJFlux.build(model::NeuralNetworkClassifier, rng, shape)
    # the model's builder creates the inner network from `rng` and the splatted
    # `shape`; the model's finaliser is then appended as the last layer of the chain
    inner = build(model.builder, rng, shape...)
    return Flux.Chain(inner, model.finaliser)
end
# returns the model `fitresult` (see "Adding Models for General Use"
# section of the MLJ manual) which must always have the form `(chain,
# metadata)`, where `metadata` is anything extra needed by `predict`:
function MLJFlux.fitresult(model::NeuralNetworkClassifier, chain, y)
    # the metadata stored alongside the chain is the result of `classes` applied to
    # the first target element; `predict` uses it to label the predicted probabilities
    class_pool = MLJModelInterface.classes(y[1])
    return (chain, class_pool)
end
function MLJModelInterface.predict(model::NeuralNetworkClassifier, fitresult, Xnew)
    # `fitresult` has the form `(chain, metadata)`; here the metadata are the classes
    network, class_pool = fitresult
    data = reformat(Xnew)

    # feed the columns of `data` through the network one at a time, taking the
    # adjoint of each output
    rows = map(1:size(data, 2)) do j
        column = tomat(data[:, j])
        return network(column)'
    end

    # stack the per-column results vertically into a single array
    probabilities = vcat(rows...)

    return MLJModelInterface.UnivariateFinite(class_pool, probabilities)
end
# declare the model metadata: the types accepted for input and target, and the path
MLJModelInterface.metadata_model(
    NeuralNetworkClassifier;
    input=Union{AbstractMatrix{Continuous},Table(Continuous)},
    target=AbstractVector{<:Finite},
    path="MLJFlux.NeuralNetworkClassifier",
)