From f60f789acb0b02452565a4654619f133be6fd60b Mon Sep 17 00:00:00 2001 From: Essam Date: Sun, 1 Sep 2024 13:52:25 -0500 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20Get=20rid=20of=20case=20distinction?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/classifier.jl | 2 - src/encoders.jl | 2 +- src/entity_embedding_utils.jl | 9 ++-- src/image.jl | 3 +- src/mlj_model_interface.jl | 82 ++++++++++++++++------------------ src/regressor.jl | 2 - test/entity_embedding_utils.jl | 16 ------- 7 files changed, 47 insertions(+), 69 deletions(-) diff --git a/src/classifier.jl b/src/classifier.jl index 5fdf152..c690c61 100644 --- a/src/classifier.jl +++ b/src/classifier.jl @@ -13,7 +13,6 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y) n_input = Tables.schema(X).names |> length return (n_input, n_output) end -is_embedding_enabled_type(::NeuralNetworkClassifier) = true # builds the end-to-end Flux chain needed, given the `model` and `shape`: MLJFlux.build( @@ -60,7 +59,6 @@ function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y) n_input = Tables.schema(X).names |> length return (n_input, 1) # n_output is always 1 for a binary classifier end -is_embedding_enabled_type(::NeuralNetworkBinaryClassifier) = true function MLJModelInterface.predict( model::NeuralNetworkBinaryClassifier, diff --git a/src/encoders.jl b/src/encoders.jl index be835b6..de51714 100644 --- a/src/encoders.jl +++ b/src/encoders.jl @@ -120,7 +120,7 @@ each level in each categorical columns using the columns of the matrix. This is used with the embedding matrices of the entity embedding layer in entity enabled models to implement entity embeddings. 
""" function embedding_transform(X, mapping_matrices) - isnothing(mapping_matrices) && return X + (isempty(mapping_matrices)) && return X feat_names = Tables.schema(X).names new_feat_names = Symbol[] new_cols = [] diff --git a/src/entity_embedding_utils.jl b/src/entity_embedding_utils.jl index c6b0fbb..4a68a3a 100644 --- a/src/entity_embedding_utils.jl +++ b/src/entity_embedding_utils.jl @@ -63,7 +63,8 @@ Returns the number of levels in each categorical column in the table `X`. function get_num_levels(X, cat_inds) num_levels = [] for i in cat_inds - num_levels = push!(num_levels, length(levels(Tables.getcolumn(Tables.columns(X), i)))) + num_levels = + push!(num_levels, length(levels(Tables.getcolumn(Tables.columns(X), i)))) end return num_levels end @@ -79,7 +80,8 @@ function prepare_entityembs(X, featnames, cat_inds, embedding_dims) i in eachindex(cat_inds) ] # 2. Compute entityemb_output_dim - entityemb_output_dim = sum(newdims) + numfeats - length(cat_inds) + sum_newdims = length(newdims) == 0 ? 
0 : sum(newdims) + entityemb_output_dim = sum_newdims + numfeats - length(cat_inds) return entityprops, entityemb_output_dim end @@ -125,7 +127,8 @@ function MLJModelInterface.transform( fitresult, Xnew, ) - is_embedding_enabled_type(transformer) || return Xnew + # if it doesn't have the property, it's not an entity-enabled model + hasproperty(transformer, :embedding_dims) || return Xnew ordinal_mappings, embedding_matrices = fitresult[3:4] Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) Xnew_transf = embedding_transform(Xnew, embedding_matrices) diff --git a/src/image.jl b/src/image.jl index 012325f..251941f 100644 --- a/src/image.jl +++ b/src/image.jl @@ -10,14 +10,13 @@ function shape(model::ImageClassifier, X, y) end return (n_input, n_output, n_channels) end -is_embedding_enabled_type(::ImageClassifier) = false build(model::ImageClassifier, rng, shape) = Flux.Chain(build(model.builder, rng, shape...), model.finaliser) -fitresult(model::ImageClassifier, chain, y) = +fitresult(model::ImageClassifier, chain, y, ordinal_mappings=nothing, embedding_matrices=nothing) = (chain, MLJModelInterface.classes(y[1])) function MLJModelInterface.predict(model::ImageClassifier, fitresult, Xnew) diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl index 39e4421..5edef6f 100644 --- a/src/mlj_model_interface.jl +++ b/src/mlj_model_interface.jl @@ -56,8 +56,7 @@ end # # FIT AND UPDATE -const ERR_BUILDER = - "Builder does not appear to build an architecture compatible with supplied data. " +const ERR_BUILDER = "Builder does not appear to build an architecture compatible with supplied data. " true_rng(model) = model.rng isa Integer ? 
Random.Xoshiro(model.rng) : model.rng @@ -76,19 +75,22 @@ function MLJModelInterface.fit(model::MLJFluxModel, pure_continuous_input = isempty(cat_inds) # Decide whether to enable entity embeddings (e.g., ImageClassifier won't) - enable_entity_embs = is_embedding_enabled_type(model) && !pure_continuous_input + enable_entity_embs = hasproperty(model, :embedding_dims) && !pure_continuous_input # Prepare entity embeddings inputs and encode X if entity embeddings enabled + featnames = [] if enable_entity_embs X = convert_to_table(X) featnames = Tables.schema(X).names - # entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i]) - # for each categorical feature - entityprops, entityemb_output_dim = - prepare_entityembs(X, featnames, cat_inds, model.embedding_dims) - X, ordinal_mappings = ordinal_encoder_fit_transform(X; featinds = cat_inds) end + # entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i]) + # for each categorical feature + default_embedding_dims = enable_entity_embs ? model.embedding_dims : Dict{Symbol, Real}() + entityprops, entityemb_output_dim = + prepare_entityembs(X, featnames, cat_inds, default_embedding_dims) + X, ordinal_mappings = ordinal_encoder_fit_transform(X; featinds = cat_inds) + ## Construct model chain chain = (!enable_entity_embs) ? 
construct_model_chain(model, rng, shape, move) : @@ -122,6 +124,9 @@ function MLJModelInterface.fit(model::MLJFluxModel, data[2], ) + # Extract embedding matrices + embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames) + # Prepare cache for potential warm restarts cache = ( deepcopy(model), @@ -132,27 +137,19 @@ function MLJModelInterface.fit(model::MLJFluxModel, optimiser_state, deepcopy(rng), move, + entityprops, + entityemb_output_dim, + ordinal_mappings, + featnames, ) - # Extract embedding matrices - enable_entity_embs && - (embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames)) - # Prepare fitresult - fitresult_args = (model, Flux.cpu(chain), y) + fitresult = + MLJFlux.fitresult(model, Flux.cpu(chain), y, ordinal_mappings, embedding_matrices) # Prepare report report = (training_losses = history,) - # Modify cache and fitresult if entity embeddings enabled - if enable_entity_embs - cache = (cache..., entityprops, entityemb_output_dim, ordinal_mappings, featnames) - fitresult = - MLJFlux.fitresult(fitresult_args..., ordinal_mappings, embedding_matrices) - else - fitresult = MLJFlux.fitresult(fitresult_args...,) - end - return fitresult, cache, report end @@ -165,15 +162,22 @@ function MLJModelInterface.update(model::MLJFluxModel, # Decide whether to enable entity embeddings (e.g., ImageClassifier won't) cat_inds = get_cat_inds(X) pure_continuous_input = (length(cat_inds) == 0) - enable_entity_embs = is_embedding_enabled_type(model) && !pure_continuous_input + enable_entity_embs = hasproperty(model, :embedding_dims) && !pure_continuous_input # Unpack cache from previous fit - old_model, data, old_history, shape, regularized_optimiser, optimiser_state, rng, move = - old_cache[1:8] - if enable_entity_embs - entityprops, entityemb_output_dim, ordinal_mappings, featnames = old_cache[9:12] - cat_inds = [prop.index for prop in entityprops] - end + old_model, + data, + old_history, + shape, + regularized_optimiser, + optimiser_state, 
+ rng, + move, + entityprops, + entityemb_output_dim, + ordinal_mappings, + featnames = old_cache + cat_inds = [prop.index for prop in entityprops] # Extract chain old_chain = old_fitresult[1] @@ -196,6 +200,8 @@ function MLJModelInterface.update(model::MLJFluxModel, else move = Mover(model.acceleration) rng = true_rng(model) + X = convert_to_table(X) + X = ordinal_encoder_transform(X, ordinal_mappings) if enable_entity_embs chain = construct_model_chain_with_entityembs( @@ -206,8 +212,6 @@ function MLJModelInterface.update(model::MLJFluxModel, entityprops, entityemb_output_dim, ) - X = convert_to_table(X) - X = ordinal_encoder_transform(X, ordinal_mappings) else chain = construct_model_chain(model, rng, shape, move) end @@ -237,8 +241,7 @@ function MLJModelInterface.update(model::MLJFluxModel, end # Extract embedding matrices - enable_entity_embs && - (embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames)) + embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames) # Prepare cache, fitresult, and report cache = ( @@ -250,17 +253,10 @@ function MLJModelInterface.update(model::MLJFluxModel, optimiser_state, deepcopy(rng), move, + entityprops, entityemb_output_dim, ordinal_mappings, featnames, ) - - fitresult_args = (model, Flux.cpu(chain), y) - if enable_entity_embs - cache = (cache..., entityprops, entityemb_output_dim, ordinal_mappings, featnames) - fitresult = - MLJFlux.fitresult(fitresult_args..., ordinal_mappings, embedding_matrices) - else - fitresult = MLJFlux.fitresult(fitresult_args...) 
- end - + fitresult = + MLJFlux.fitresult(model, Flux.cpu(chain), y, ordinal_mappings, embedding_matrices) report = (training_losses = history,) return fitresult, cache, report diff --git a/src/regressor.jl b/src/regressor.jl index 41cc4fb..9b8411a 100644 --- a/src/regressor.jl +++ b/src/regressor.jl @@ -12,7 +12,6 @@ function shape(model::NeuralNetworkRegressor, X, y) return (n_input, 1) end -is_embedding_enabled_type(::NeuralNetworkRegressor) = true build(model::NeuralNetworkRegressor, rng, shape) = build(model.builder, rng, shape...) @@ -50,7 +49,6 @@ A private method that returns the shape of the input and output of the model for data `X` and `y`. """ shape(model::MultitargetNeuralNetworkRegressor, X, y) = (ncols(X), ncols(y)) -is_embedding_enabled_type(::MultitargetNeuralNetworkRegressor) = true build(model::MultitargetNeuralNetworkRegressor, rng, shape) = build(model.builder, rng, shape...) diff --git a/test/entity_embedding_utils.jl b/test/entity_embedding_utils.jl index 3b94ca2..56983a0 100644 --- a/test/entity_embedding_utils.jl +++ b/test/entity_embedding_utils.jl @@ -1,20 +1,4 @@ -@testset "Embedding Enabled Types" begin - clf = MLJFlux.NeuralNetworkClassifier( - builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2), - optimiser = Optimisers.Adam(0.01), - batch_size = 8, - epochs = 100, - ) - @test MLJFlux.is_embedding_enabled_type(clf) - - clf = MLJFlux.ImageClassifier( - batch_size = 50, - epochs = 10, - rng = 123, - ) - @test !MLJFlux.is_embedding_enabled_type(clf) -end @testset "set_default_new_embedding_dim" begin