From 86fa74c3ee0a00f6c345a78d906f2530c0f61d58 Mon Sep 17 00:00:00 2001 From: Essam Date: Mon, 5 Aug 2024 18:30:57 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=A8=E2=80=8D=F0=9F=92=BB=20Refactor=20?= =?UTF-8?q?fit=20and=20update=20in=20mlj=5Fmodel=5Finterface?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/entity_embedding_utils.jl | 148 +++++++++++++++++++++++ src/fit_utils.jl | 68 +++++++++++ src/mlj_model_interface.jl | 214 +++++++++++++++++++++------------- 3 files changed, 349 insertions(+), 81 deletions(-) create mode 100644 src/entity_embedding_utils.jl create mode 100644 src/fit_utils.jl diff --git a/src/entity_embedding_utils.jl b/src/entity_embedding_utils.jl new file mode 100644 index 00000000..fe1a9135 --- /dev/null +++ b/src/entity_embedding_utils.jl @@ -0,0 +1,148 @@ +""" +A file containing functions or constants used in the `fit` and `update` methods in `mlj_model_interface.jl` for setups supporting entity embeddings +""" + +EMBEDDING_ENABLED_MODELS = [ + NeuralNetworkClassifier, + NeuralNetworkBinaryClassifier, + NeuralNetworkRegressor, + MultitargetNeuralNetworkRegressor, +] + +EMBEDDING_ENABLED_MODELS_UNION = Union{EMBEDDING_ENABLED_MODELS...} + + +# A function to check if a model is among those in EMBEDDING_ENABLED_MODELS +function is_embedding_enabled_type(model_type) + return any(model_type <: T for T in EMBEDDING_ENABLED_MODELS) +end + +# function to set default new embedding dimension +function set_default_new_embedding_dim(numlevels) + # Either min_ratio or max_ratio of numlevels depending on >= threshold or < threshold + min_ratio, max_ratio = 0.2, 0.5 + threshold = 20 + return ceil(Int, ((numlevels >= threshold) ? min_ratio : max_ratio) * numlevels) +end + +MISMATCH_INDS(wrong_feats) = + "Features $(join(wrong_feats, ", ")) were specified in embedding_dims hyperparameter but were not recognized as categorical variables because their scitypes are not `Multiclass` or `OrderedFactor`." +function check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds) + wrong_feats = [featnames[i] for i in specified_featinds if !(i in cat_inds)] + length(wrong_feats) > 0 && throw(ArgumentError(MISMATCH_INDS(wrong_feats))) +end + +# function to set new embedding dimensions +function set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims) + specified_featnames = keys(embedding_dims) + specified_featinds = + [i for i in 1:length(featnames) if featnames[i] in specified_featnames] + check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds) + catind2numlevels = Dict(zip(cat_inds, num_levels)) + # for each value of embedding dim if float then multiply it by the number of levels + for featname in specified_featnames + if embedding_dims[featname] isa AbstractFloat + embedding_dims[featname] = ceil( + Int, + embedding_dims[featname] * + catind2numlevels[findfirst(x -> x == featname, featnames)], + ) + end + end + newdims = [ + (cat_ind in specified_featinds) ? embedding_dims[featnames[cat_ind]] : + set_default_new_embedding_dim(num_levels[i]) for + (i, cat_ind) in enumerate(cat_inds) + ] + return newdims +end + + +""" +**Private Method** + +Returns the indices of the categorical columns in the table `X`. 
+""" +function get_cat_inds(X) + # if input is a matrix; conclude no categorical columns + Tables.istable(X) || return Int[] + types = schema(X).scitypes + cat_inds = findall(x -> x <: Finite, types) + return cat_inds +end + +""" +**Private Method** + +Returns the number of levels in each categorical column in the table `X`. +""" +function get_num_levels(X, cat_inds) + num_levels = [] + for i in cat_inds + num_levels = push!(num_levels, length(levels(Tables.getcolumn(X, i)))) + end + return num_levels +end + +# A function to prepare the inputs for entity embeddings layer +function prepare_entityembs(X, featnames, cat_inds, embedding_dims) + # 1. Construct entityprops + numfeats = length(featnames) + num_levels = get_num_levels(X, cat_inds) + newdims = set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims) + entityprops = [ + (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i]) for + i in eachindex(cat_inds) + ] + # 2. Compute entityemb_output_dim + entityemb_output_dim = sum(newdims) + numfeats - length(cat_inds) + return entityprops, entityemb_output_dim +end + +# A function to construct model chain including entity embeddings as the first layer +function construct_model_chain_with_entityembs( + model, + rng, + shape, + move, + entityprops, + entityemb_output_dim, +) + chain = try + Flux.Chain( + EntityEmbedder(entityprops, shape[1]; init = Flux.glorot_uniform(rng)), + build(model, rng, (entityemb_output_dim, shape[2])), + ) |> move + catch ex + @error ERR_BUILDER + rethrow() + end + return chain +end + + +# A function that given a model chain, returns a dictionary of embedding matrices +function get_embedding_matrices(chain, cat_inds, featnames) + embedder_layer = chain.layers[1] + embedding_matrices = Dict{Symbol, Matrix{Float32}}() + for cat_ind in cat_inds + featname = featnames[cat_ind] + matrix = Flux.params(embedder_layer.embedders[cat_ind])[1] + embedding_matrices[featname] = matrix + end + return embedding_matrices +end + + + +# Transformer for entity-enabled models +function MLJModelInterface.transform( + transformer::EMBEDDING_ENABLED_MODELS_UNION, + fitresult, + Xnew, +) + ordinal_mappings, embedding_matrices = fitresult[3:4] + Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) + Xnew_transf = embedding_transform(Xnew, embedding_matrices) + return Xnew_transf +end \ No newline at end of file diff --git a/src/fit_utils.jl b/src/fit_utils.jl new file mode 100644 index 00000000..b2062791 --- /dev/null +++ b/src/fit_utils.jl @@ -0,0 +1,68 @@ +""" +A file containing functions used in the `fit` and `update` methods in `mlj_model_interface.jl` +""" + +# Converts input to table if it's a matrix +convert_to_table(X) = X isa Matrix ? Tables.table(X) : X + + +# Construct model chain and throws error if it fails +function construct_model_chain(model, rng, shape, move) + chain = try + build(model, rng, shape) |> move + catch ex + @error ERR_BUILDER + rethrow() + end + return chain +end + +# Test whether constructed chain works else throws error +function test_chain_works(x, chain) + try + chain(x) + catch ex + @error ERR_BUILDER + throw(ex) + end +end + +# Models implement L1/L2 regularization by chaining the chosen optimiser with weight/sign +# decay. Note that the weight/sign decay must be scaled down by the number of batches to +# ensure penalization over an epoch does not scale with the choice of batch size; see +# https://github.com/FluxML/MLJFlux.jl/issues/213. 
+ +function regularized_optimiser(model, nbatches) + model.lambda == 0 && return model.optimiser + λ_L1 = model.alpha * model.lambda + λ_L2 = (1 - model.alpha) * model.lambda + λ_sign = λ_L1 / nbatches + λ_weight = 2 * λ_L2 / nbatches + + # recall components in an optimiser chain are executed from left to right: + if model.alpha == 0 + return Optimisers.OptimiserChain( + Optimisers.WeightDecay(λ_weight), + model.optimiser, + ) + elseif model.alpha == 1 + return Optimisers.OptimiserChain( + Optimisers.SignDecay(λ_sign), + model.optimiser, + ) + else + return Optimisers.OptimiserChain( + Optimisers.SignDecay(λ_sign), + Optimisers.WeightDecay(λ_weight), + model.optimiser, + ) + end +end + +# Prepares optimiser for training +function prepare_optimiser(data, model, chain) + nbatches = length(data[2]) + regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) + optimiser_state = Optimisers.setup(regularized_optimiser, chain) + return regularized_optimiser, optimiser_state +end \ No newline at end of file diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl index aa9850c4..88650a71 100644 --- a/src/mlj_model_interface.jl +++ b/src/mlj_model_interface.jl @@ -8,8 +8,8 @@ MLJModelInterface.deep_properties(::Type{<:MLJFluxModel}) = # # CLEAN METHOD const ERR_BAD_OPTIMISER = ArgumentError( - "Flux.jl optimiser detected. Only optimisers from Optimisers.jl are supported. "* - "For example, use `optimiser=Optimisers.Momentum()` after `import Optimisers`. " + "Flux.jl optimiser detected. Only optimisers from Optimisers.jl are supported. " * + "For example, use `optimiser=Optimisers.Momentum()` after `import Optimisers`. ", ) function MLJModelInterface.clean!(model::MLJFluxModel) @@ -19,8 +19,8 @@ function MLJModelInterface.clean!(model::MLJFluxModel) model.lambda = 0 end if model.alpha < 0 || model.alpha > 1 - warning *= "Need alpha in the interval `[0, 1]`. "* - "Resetting `alpha = 0`. " + warning *= "Need alpha in the interval `[0, 1]`. " * + "Resetting `alpha = 0`. " model.alpha = 0 end if model.epochs < 0 @@ -32,7 +32,8 @@ function MLJModelInterface.clean!(model::MLJFluxModel) model.batch_size = 1 end if model.acceleration isa CUDALibs && gpu_isdead() - warning *= "`acceleration isa CUDALibs` "* + warning *= + "`acceleration isa CUDALibs` " * "but no CUDA device (GPU) currently live. " end if !(model.acceleration isa CUDALibs || model.acceleration isa CPU1) @@ -40,9 +41,10 @@ function MLJModelInterface.clean!(model::MLJFluxModel) model.acceleration = CPU1() end if model.acceleration isa CUDALibs && model.rng isa Integer - warning *= "Specifying an RNG seed when "* - "`acceleration isa CUDALibs()` may fail for layers depending "* - "on an RNG during training, such as `Dropout`. Consider using "* + warning *= + "Specifying an RNG seed when " * + "`acceleration isa CUDALibs()` may fail for layers depending " * + "on an RNG during training, such as `Dropout`. Consider using " * " `Random.default_rng()` instead. `" end # TODO: This could be removed in next breaking release (0.6.0): @@ -53,74 +55,61 @@ end # # FIT AND UPDATE +include("fit_utils.jl") +include("entity_embedding_utils.jl") -const ERR_BUILDER = - "Builder does not appear to build an architecture compatible with supplied data. " +const ERR_BUILDER = "Builder does not appear to build an architecture compatible with supplied data. " true_rng(model) = model.rng isa Integer ? Random.Xoshiro(model.rng) : model.rng -# Models implement L1/L2 regularization by chaining the chosen optimiser with weight/sign -# decay. 
Note that the weight/sign decay must be scaled down by the number of batches to -# ensure penalization over an epoch does not scale with the choice of batch size; see -# https://github.com/FluxML/MLJFlux.jl/issues/213. - -function regularized_optimiser(model, nbatches) - model.lambda == 0 && return model.optimiser - λ_L1 = model.alpha*model.lambda - λ_L2 = (1 - model.alpha)*model.lambda - λ_sign = λ_L1/nbatches - λ_weight = 2*λ_L2/nbatches - - # recall components in an optimiser chain are executed from left to right: - if model.alpha == 0 - return Optimisers.OptimiserChain( - Optimisers.WeightDecay(λ_weight), - model.optimiser, - ) - elseif model.alpha == 1 - return Optimisers.OptimiserChain( - Optimisers.SignDecay(λ_sign), - model.optimiser, - ) - else return Optimisers.OptimiserChain( - Optimisers.SignDecay(λ_sign), - Optimisers.WeightDecay(λ_weight), - model.optimiser, - ) - end -end function MLJModelInterface.fit(model::MLJFluxModel, - verbosity, - X, - y) - + verbosity, + X, + y) + # GPU and rng related variables move = Mover(model.acceleration) - rng = true_rng(model) - shape = MLJFlux.shape(model, X, y) - chain = try - build(model, rng, shape) |> move - catch ex - @error ERR_BUILDER - rethrow() + # Get input properties + shape = MLJFlux.shape(model, X, y) + cat_inds = get_cat_inds(X) + pure_continuous_input = (length(cat_inds) == 0) + + # Decide whether to enable entity embeddings (e.g., ImageClassifier won't) + enable_entity_embs = is_embedding_enabled_type(typeof(model)) && !pure_continuous_input + + # Prepare entity embeddings inputs and encode X if entity embeddings enabled + if enable_entity_embs + X = convert_to_table(X) + featnames = Tables.schema(X).names + entityprops, entityemb_output_dim = + prepare_entityembs(X, featnames, cat_inds, model.embedding_dims) + X, ordinal_mappings = ordinal_encoder_fit_transform(X; featinds = cat_inds) end - data = move.(collate(model, X, y)) - x = data[1][1] + ## Construct model chain + chain = + (!enable_entity_embs) ? 
construct_model_chain(model, rng, shape, move) : + construct_model_chain_with_entityembs( + model, + rng, + shape, + move, + entityprops, + entityemb_output_dim, + ) - try - chain(x) - catch ex - @error ERR_BUILDER - throw(ex) - end + # Format data as needed by Flux and move to GPU + data = move.(collate(model, X, y)) - nbatches = length(data[2]) - regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) - optimiser_state = Optimisers.setup(regularized_optimiser, chain) + # Test chain works (as it may be custom) + x = data[1][1] + test_chain_works(x, chain) + # Train model with Flux + regularized_optimiser, optimiser_state = + prepare_optimiser(data, model, chain) chain, optimiser_state, history = train( model, chain, @@ -132,6 +121,7 @@ function MLJModelInterface.fit(model::MLJFluxModel, data[2], ) + # Prepare cache for potential warm restarts cache = ( deepcopy(model), data, @@ -142,30 +132,62 @@ function MLJModelInterface.fit(model::MLJFluxModel, deepcopy(rng), move, ) - fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) - report = (training_losses=history, ) + # Extract embedding matrices + enable_entity_embs && + (embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames)) + + # Prepare fitresult + fitresult_args = (model, Flux.cpu(chain), y) + + # Prepare report + report = (training_losses = history,) + + # Modify cache and fitresult if entity embeddings enabled + if enable_entity_embs + cache = (cache..., entityprops, entityemb_output_dim, ordinal_mappings, featnames) + fitresult = + MLJFlux.fitresult(fitresult_args..., ordinal_mappings, embedding_matrices) + else + fitresult = MLJFlux.fitresult(fitresult_args...,) + end return fitresult, cache, report end function MLJModelInterface.update(model::MLJFluxModel, - verbosity, - old_fitresult, - old_cache, - X, - y) - - old_model, data, old_history, shape, regularized_optimiser, - optimiser_state, rng, move = old_cache + verbosity, + old_fitresult, + old_cache, + X, + y) + # Decide whether to enable entity embeddings (e.g., ImageClassifier won't) + cat_inds = get_cat_inds(X) + pure_continuous_input = (length(cat_inds) == 0) + enable_entity_embs = is_embedding_enabled_type(typeof(model)) && !pure_continuous_input + + # Unpack cache from previous fit + old_model, data, old_history, shape, regularized_optimiser, optimiser_state, rng, move = + old_cache[1:8] + if enable_entity_embs + entityprops, entityemb_output_dim, ordinal_mappings, featnames = old_cache[9:12] + cat_inds = [prop.index for prop in entityprops] + end + + # Extract chain old_chain = old_fitresult[1] - optimiser_flag = model.optimiser_changes_trigger_retraining && + # Decide whether optimiser should trigger retraining from scratch + optimiser_flag = + model.optimiser_changes_trigger_retraining && model.optimiser != old_model.optimiser - keep_chain = !optimiser_flag && model.epochs >= old_model.epochs && + # Decide whether to retrain from scratch + keep_chain = + !optimiser_flag && model.epochs >= old_model.epochs && MLJModelInterface.is_same_except(model, old_model, :optimiser, :epochs) + # Use old chain if not retraining from scratch or reconstruct and prepare to retrain if keep_chain chain = move(old_chain) epochs = model.epochs - old_model.epochs @@ -173,15 +195,29 @@ function MLJModelInterface.update(model::MLJFluxModel, else move = Mover(model.acceleration) rng = true_rng(model) - chain = build(model, rng, shape) |> move + if enable_entity_embs + chain = + construct_model_chain_with_entityembs( + model, + rng, + shape, + move, + 
entityprops, + entityemb_output_dim, + ) + X = convert_to_table(X) + X = ordinal_encoder_transform(X, ordinal_mappings) + else + chain = construct_model_chain(model, rng, shape, move) + end # reset `optimiser_state`: data = move.(collate(model, X, y)) - nbatches = length(data[2]) - regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) - optimiser_state = Optimisers.setup(regularized_optimiser, chain) + regularized_optimiser, optimiser_state = + prepare_optimiser(data, model, chain) epochs = model.epochs end + # Train model with Flux chain, optimiser_state, history = train( model, chain, @@ -192,12 +228,18 @@ function MLJModelInterface.update(model::MLJFluxModel, data[1], data[2], ) + + # Properly set history if keep_chain # note: history[1] = old_history[end] history = vcat(old_history[1:end-1], history) end - fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) + # Extract embedding matrices + enable_entity_embs && + (embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames)) + + # Prepare cache, fitresult, and report cache = ( deepcopy(model), data, @@ -208,14 +250,24 @@ function MLJModelInterface.update(model::MLJFluxModel, deepcopy(rng), move, ) - report = (training_losses=history, ) + + fitresult_args = (model, Flux.cpu(chain), y) + if enable_entity_embs + cache = (cache..., entityprops, entityemb_output_dim, ordinal_mappings, featnames) + fitresult = + MLJFlux.fitresult(fitresult_args..., ordinal_mappings, embedding_matrices) + else + fitresult = MLJFlux.fitresult(fitresult_args...) + end + + report = (training_losses = history,) return fitresult, cache, report end MLJModelInterface.fitted_params(::MLJFluxModel, fitresult) = - (chain=fitresult[1],) + (chain = fitresult[1],) # # SUPPORT FOR MLJ ITERATION API
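
For reference, the embedding-width heuristics introduced in src/entity_embedding_utils.jl behave as follows on a few hypothetical cardinalities (worked examples assuming that file is included; the numbers are not taken from the patch):

    # Default width: ceil(0.5 * levels) below the 20-level threshold,
    # ceil(0.2 * levels) at or above it.
    set_default_new_embedding_dim(10)   # ceil(0.5 * 10) == 5
    set_default_new_embedding_dim(19)   # ceil(0.5 * 19) == 10
    set_default_new_embedding_dim(30)   # ceil(0.2 * 30) == 6

    # A Float entry in `embedding_dims` is read as a fraction of the levels: a
    # hypothetical feature with 12 levels and an entry of 0.25 is resolved by
    # `set_new_embedding_dims` to ceil(0.25 * 12) == 3; an Int entry is kept as is.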
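
For the shapes that flow into chain construction, consider a hypothetical table with four features whose second and fourth columns are categorical with 3 and 10 levels respectively, and no `embedding_dims` overrides. `prepare_entityembs` then produces

    entityprops = [
        (index = 2, levels = 3,  newdim = 2),   # ceil(0.5 * 3)
        (index = 4, levels = 10, newdim = 5),   # ceil(0.5 * 10)
    ]
    # Continuous features pass through the embedder untouched, so
    entityemb_output_dim = (2 + 5) + 4 - 2      # == 9

and that output width is what `construct_model_chain_with_entityembs` hands to `build` as the input dimension of the user's network.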
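
Finally, a rough end-to-end sketch of how the entity-embedding path is intended to be driven from the MLJ side. It assumes the embedding-enabled constructors expose the `embedding_dims` hyperparameter read as `model.embedding_dims` above; the data, column names, and widths are made up for illustration:

    using MLJ, MLJFlux

    # Hypothetical table with one Continuous and one Multiclass column.
    X = (
        height = rand(100),
        color  = coerce(rand(["red", "green", "blue"], 100), Multiclass),
    )
    y = coerce(rand(["yes", "no"], 100), Multiclass)

    # Request a 2-dimensional embedding for :color; any other categorical
    # feature would fall back to set_default_new_embedding_dim.
    clf = NeuralNetworkClassifier(embedding_dims = Dict(:color => 2))

    mach = machine(clf, X, y)
    fit!(mach)

    # The transform method added above ordinally encodes Xnew and replaces each
    # categorical column by its learned embedding columns.
    Xembedded = transform(mach, X)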