✅ Add trait for embedding model
EssamWisam committed Sep 1, 2024
1 parent 46577f7 commit e5d8141
Showing 6 changed files with 15 additions and 20 deletions.
2 changes: 2 additions & 0 deletions src/classifier.jl
@@ -13,6 +13,7 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
n_input = Tables.schema(X).names |> length
return (n_input, n_output)
end
is_embedding_enabled_type(::NeuralNetworkClassifier) = true

# builds the end-to-end Flux chain needed, given the `model` and `shape`:
MLJFlux.build(
@@ -59,6 +60,7 @@ function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y)
n_input = Tables.schema(X).names |> length
return (n_input, 1) # n_output is always 1 for a binary classifier
end
is_embedding_enabled_type(::NeuralNetworkBinaryClassifier) = true

function MLJModelInterface.predict(
model::NeuralNetworkBinaryClassifier,
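The two additions above turn embedding support into a dispatch-based trait: each supported model type declares it with a one-line `true` method. A minimal, self-contained sketch of the pattern, using stand-in types rather than MLJFlux's own:

# Stand-in types for illustration; these are not MLJFlux models.
abstract type SketchModel end
struct TabularSketch <: SketchModel end
struct ImageSketch <: SketchModel end

# Catch-all: a model does not support entity embeddings unless it opts in.
is_embedding_enabled_type(::Any) = false

# One-line opt-in, mirroring the additions to classifier.jl above.
is_embedding_enabled_type(::TabularSketch) = true

@assert is_embedding_enabled_type(TabularSketch())
@assert !is_embedding_enabled_type(ImageSketch())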
19 changes: 3 additions & 16 deletions src/entity_embedding_utils.jl
@@ -1,21 +1,7 @@
"""
A file containing functions or constants used in the `fit` and `update` methods in `mlj_model_interface.jl` for setups supporting entity embeddings
"""

EMBEDDING_ENABLED_MODELS = [
NeuralNetworkClassifier,
NeuralNetworkBinaryClassifier,
NeuralNetworkRegressor,
MultitargetNeuralNetworkRegressor,
]

EMBEDDING_ENABLED_MODELS_UNION = Union{EMBEDDING_ENABLED_MODELS...}


# A function to check if a model is among those in EMBEDDING_ENABLED_MODELS
function is_embedding_enabled_type(model_type)
return any(model_type <: T for T in EMBEDDING_ENABLED_MODELS)
end
is_embedding_enabled_type(model) = false

Codecov / codecov/patch: added line src/entity_embedding_utils.jl#L4 was not covered by tests.

# function to set default new embedding dimension
function set_default_new_embedding_dim(numlevels)
@@ -135,10 +121,11 @@ end

# Transformer for entity-enabled models
function MLJModelInterface.transform(
transformer::EMBEDDING_ENABLED_MODELS_UNION,
transformer,
fitresult,
Xnew,
)
is_embedding_enabled_type(transformer) || return Xnew
ordinal_mappings, embedding_matrices = fitresult[3:4]
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
Xnew_transf = embedding_transform(Xnew, embedding_matrices)
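The guard `is_embedding_enabled_type(transformer) || return Xnew` is what lets `transform` drop its `EMBEDDING_ENABLED_MODELS_UNION` signature and stay generic: models that never opt in just pass the data through. A self-contained sketch of that early-return flow, with hypothetical types and a placeholder fitresult layout (not MLJFlux's actual internals):

struct DummyTabular end   # hypothetical embedding-enabled model
struct DummyImage end     # hypothetical model without embedding support

is_embedding_enabled_type(::Any) = false
is_embedding_enabled_type(::DummyTabular) = true

function transform_sketch(transformer, fitresult, Xnew)
    # Non-embedding models return the input unchanged, as in the guard above.
    is_embedding_enabled_type(transformer) || return Xnew
    ordinal_mappings, embedding_matrices = fitresult[3], fitresult[4]
    # ... ordinally encode Xnew, then replace categorical columns using embedding_matrices ...
    return Xnew   # placeholder for the transformed table
end

Xnew = (; colour = ["red", "blue"])
@assert transform_sketch(DummyImage(), (nothing, nothing, nothing, nothing), Xnew) === Xnew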
2 changes: 2 additions & 0 deletions src/image.jl
@@ -10,6 +10,8 @@ function shape(model::ImageClassifier, X, y)
end
return (n_input, n_output, n_channels)
end
is_embedding_enabled_type(::ImageClassifier) = false


build(model::ImageClassifier, rng, shape) =
Flux.Chain(build(model.builder, rng, shape...),
4 changes: 2 additions & 2 deletions src/mlj_model_interface.jl
@@ -76,7 +76,7 @@ function MLJModelInterface.fit(model::MLJFluxModel,
pure_continuous_input = isempty(cat_inds)

# Decide whether to enable entity embeddings (e.g., ImageClassifier won't)
enable_entity_embs = is_embedding_enabled_type(typeof(model)) && !pure_continuous_input
enable_entity_embs = is_embedding_enabled_type(model) && !pure_continuous_input

# Prepare entity embeddings inputs and encode X if entity embeddings enabled
if enable_entity_embs
@@ -165,7 +165,7 @@ function MLJModelInterface.update(model::MLJFluxModel,
# Decide whether to enable entity embeddings (e.g., ImageClassifier won't)
cat_inds = get_cat_inds(X)
pure_continuous_input = (length(cat_inds) == 0)
enable_entity_embs = is_embedding_enabled_type(typeof(model)) && !pure_continuous_input
enable_entity_embs = is_embedding_enabled_type(model) && !pure_continuous_input

# Unpack cache from previous fit
old_model, data, old_history, shape, regularized_optimiser, optimiser_state, rng, move =
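In `fit` and `update` the trait is now queried on the model instance instead of `typeof(model)`, and it is combined with a check that the input actually contains categorical features. A sketch of that decision; `cat_inds_sketch` is a hypothetical stand-in for MLJFlux's internal `get_cat_inds` and simply treats non-numeric columns as categorical:

using Tables

is_embedding_enabled_type(::Any) = false   # trait fallback, as in entity_embedding_utils.jl

# Hypothetical stand-in for `get_cat_inds`: indices of columns whose eltype is not numeric.
cat_inds_sketch(X) = findall(T -> !(T <: Number), collect(Tables.schema(X).types))

function enable_embeddings_sketch(model, X)
    pure_continuous_input = isempty(cat_inds_sketch(X))
    return is_embedding_enabled_type(model) && !pure_continuous_input
end

X = (height = [1.7, 1.8], colour = ["red", "blue"])   # column table with one categorical column
@assert !enable_embeddings_sketch("not a model", X)    # unknown models fall back to `false`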
4 changes: 4 additions & 0 deletions src/regressor.jl
@@ -12,13 +12,16 @@ function shape(model::NeuralNetworkRegressor, X, y)
return (n_input, 1)
end

is_embedding_enabled_type(::NeuralNetworkRegressor) = true

build(model::NeuralNetworkRegressor, rng, shape) =
build(model.builder, rng, shape...)

fitresult(model::NeuralNetworkRegressor, chain, y, ordinal_mappings=nothing, embedding_matrices=nothing) =
(chain, nothing, ordinal_mappings, embedding_matrices)



function MLJModelInterface.predict(model::NeuralNetworkRegressor,
fitresult,
Xnew)
@@ -47,6 +50,7 @@ A private method that returns the shape of the input and output of the model for
data `X` and `y`.
"""
shape(model::MultitargetNeuralNetworkRegressor, X, y) = (ncols(X), ncols(y))
is_embedding_enabled_type(::MultitargetNeuralNetworkRegressor) = true

build(model::MultitargetNeuralNetworkRegressor, rng, shape) =
build(model.builder, rng, shape...)
4 changes: 2 additions & 2 deletions test/entity_embedding_utils.jl
@@ -6,14 +6,14 @@
batch_size = 8,
epochs = 100,
)
@test MLJFlux.is_embedding_enabled_type(typeof(clf))
@test MLJFlux.is_embedding_enabled_type(clf)

clf = MLJFlux.ImageClassifier(
batch_size = 50,
epochs = 10,
rng = 123,
)
@test !MLJFlux.is_embedding_enabled_type(typeof(clf))
@test !MLJFlux.is_embedding_enabled_type(clf)
end


