diff --git a/src/classifier.jl b/src/classifier.jl index c690c61..5fdf152 100644 --- a/src/classifier.jl +++ b/src/classifier.jl @@ -13,6 +13,7 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y) n_input = Tables.schema(X).names |> length return (n_input, n_output) end +is_embedding_enabled_type(::NeuralNetworkClassifier) = true # builds the end-to-end Flux chain needed, given the `model` and `shape`: MLJFlux.build( @@ -59,6 +60,7 @@ function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y) n_input = Tables.schema(X).names |> length return (n_input, 1) # n_output is always 1 for a binary classifier end +is_embedding_enabled_type(::NeuralNetworkBinaryClassifier) = true function MLJModelInterface.predict( model::NeuralNetworkBinaryClassifier, diff --git a/src/entity_embedding_utils.jl b/src/entity_embedding_utils.jl index 7477431..c6b0fbb 100644 --- a/src/entity_embedding_utils.jl +++ b/src/entity_embedding_utils.jl @@ -1,21 +1,7 @@ """ A file containing functions or constants used in the `fit` and `update` methods in `mlj_model_interface.jl` for setups supporting entity embeddings """ - -EMBEDDING_ENABLED_MODELS = [ - NeuralNetworkClassifier, - NeuralNetworkBinaryClassifier, - NeuralNetworkRegressor, - MultitargetNeuralNetworkRegressor, -] - -EMBEDDING_ENABLED_MODELS_UNION = Union{EMBEDDING_ENABLED_MODELS...} - - -# A function to check if a model is among those in EMBEDDING_ENABLED_MODELS -function is_embedding_enabled_type(model_type) - return any(model_type <: T for T in EMBEDDING_ENABLED_MODELS) -end +is_empty_enabled_type(model) = false # function to set default new embedding dimension function set_default_new_embedding_dim(numlevels) @@ -135,10 +121,11 @@ end # Transformer for entity-enabled models function MLJModelInterface.transform( - transformer::EMBEDDING_ENABLED_MODELS_UNION, + transformer, fitresult, Xnew, ) + is_embedding_enabled_type(transformer) || return Xnew ordinal_mappings, embedding_matrices = fitresult[3:4] Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) Xnew_transf = embedding_transform(Xnew, embedding_matrices) diff --git a/src/image.jl b/src/image.jl index dc8d563..012325f 100644 --- a/src/image.jl +++ b/src/image.jl @@ -10,6 +10,8 @@ function shape(model::ImageClassifier, X, y) end return (n_input, n_output, n_channels) end +is_embedding_enabled_type(::ImageClassifier) = false + build(model::ImageClassifier, rng, shape) = Flux.Chain(build(model.builder, rng, shape...), diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl index ced6344..39e4421 100644 --- a/src/mlj_model_interface.jl +++ b/src/mlj_model_interface.jl @@ -76,7 +76,7 @@ function MLJModelInterface.fit(model::MLJFluxModel, pure_continuous_input = isempty(cat_inds) # Decide whether to enable entity embeddings (e.g., ImageClassifier won't) - enable_entity_embs = is_embedding_enabled_type(typeof(model)) && !pure_continuous_input + enable_entity_embs = is_embedding_enabled_type(model) && !pure_continuous_input # Prepare entity embeddings inputs and encode X if entity embeddings enabled if enable_entity_embs @@ -165,7 +165,7 @@ function MLJModelInterface.update(model::MLJFluxModel, # Decide whether to enable entity embeddings (e.g., ImageClassifier won't) cat_inds = get_cat_inds(X) pure_continuous_input = (length(cat_inds) == 0) - enable_entity_embs = is_embedding_enabled_type(typeof(model)) && !pure_continuous_input + enable_entity_embs = is_embedding_enabled_type(model) && !pure_continuous_input # Unpack cache from previous fit old_model, data, old_history, shape, regularized_optimiser, optimiser_state, rng, move = diff --git a/src/regressor.jl b/src/regressor.jl index e55acbf..41cc4fb 100644 --- a/src/regressor.jl +++ b/src/regressor.jl @@ -12,6 +12,8 @@ function shape(model::NeuralNetworkRegressor, X, y) return (n_input, 1) end +is_embedding_enabled_type(::NeuralNetworkRegressor) = true + build(model::NeuralNetworkRegressor, rng, shape) = build(model.builder, rng, shape...) @@ -19,6 +21,7 @@ fitresult(model::NeuralNetworkRegressor, chain, y, ordinal_mappings=nothing, emb (chain, nothing, ordinal_mappings, embedding_matrices) + function MLJModelInterface.predict(model::NeuralNetworkRegressor, fitresult, Xnew) @@ -47,6 +50,7 @@ A private method that returns the shape of the input and output of the model for data `X` and `y`. """ shape(model::MultitargetNeuralNetworkRegressor, X, y) = (ncols(X), ncols(y)) +is_embedding_enabled_type(::MultitargetNeuralNetworkRegressor) = true build(model::MultitargetNeuralNetworkRegressor, rng, shape) = build(model.builder, rng, shape...) diff --git a/test/entity_embedding_utils.jl b/test/entity_embedding_utils.jl index d7aaaf9..3b94ca2 100644 --- a/test/entity_embedding_utils.jl +++ b/test/entity_embedding_utils.jl @@ -6,14 +6,14 @@ batch_size = 8, epochs = 100, ) - @test MLJFlux.is_embedding_enabled_type(typeof(clf)) + @test MLJFlux.is_embedding_enabled_type(clf) clf = MLJFlux.ImageClassifier( batch_size = 50, epochs = 10, rng = 123, ) - @test !MLJFlux.is_embedding_enabled_type(typeof(clf)) + @test !MLJFlux.is_embedding_enabled_type(clf) end