👨‍💻 Refactor fit and update in mlj_model_interface
EssamWisam committed Aug 5, 2024
1 parent afab1e0 commit 86fa74c
Showing 3 changed files with 349 additions and 81 deletions.
148 changes: 148 additions & 0 deletions src/entity_embedding_utils.jl
@@ -0,0 +1,148 @@
"""
A file containing functions or constants used in the `fit` and `update` methods in `mlj_model_interface.jl` for setups supporting entity embeddings
"""

EMBEDDING_ENABLED_MODELS = [
    NeuralNetworkClassifier,
    NeuralNetworkBinaryClassifier,
    NeuralNetworkRegressor,
    MultitargetNeuralNetworkRegressor,
]

EMBEDDING_ENABLED_MODELS_UNION = Union{EMBEDDING_ENABLED_MODELS...}


# Check whether a model type is among those in EMBEDDING_ENABLED_MODELS
function is_embedding_enabled_type(model_type)
    return any(model_type <: T for T in EMBEDDING_ENABLED_MODELS)
end

# Set the default embedding dimension for a categorical feature with `numlevels` levels
function set_default_new_embedding_dim(numlevels)
    # use min_ratio of numlevels for high-cardinality features (numlevels >= threshold)
    # and max_ratio of numlevels otherwise
    min_ratio, max_ratio = 0.2, 0.5
    threshold = 20
    return ceil(Int, ((numlevels >= threshold) ? min_ratio : max_ratio) * numlevels)
end
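
For instance, a feature with 10 levels falls below the threshold and gets ceil(0.5 * 10) = 5 embedding dimensions, while one with 30 levels gets ceil(0.2 * 30) = 6:

set_default_new_embedding_dim(10)  # 5
set_default_new_embedding_dim(30)  # 6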

MISMATCH_INDS(wrong_feats) =
    "Features $(join(wrong_feats, ", ")) were specified in the `embedding_dims` hyperparameter but were not recognized as categorical variables because their scitypes are neither `Multiclass` nor `OrderedFactor`."
function check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds)
    wrong_feats = [featnames[i] for i in specified_featinds if !(i in cat_inds)]
    length(wrong_feats) > 0 && throw(ArgumentError(MISMATCH_INDS(wrong_feats)))
end

# Set the new embedding dimension for each categorical feature
function set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims)
    specified_featnames = keys(embedding_dims)
    specified_featinds =
        [i for i in 1:length(featnames) if featnames[i] in specified_featnames]
    check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds)
    catind2numlevels = Dict(zip(cat_inds, num_levels))
    # a float embedding dimension is interpreted as a proportion of the feature's
    # number of levels and rounded up to an integer
    for featname in specified_featnames
        if embedding_dims[featname] isa AbstractFloat
            embedding_dims[featname] = ceil(
                Int,
                embedding_dims[featname] *
                catind2numlevels[findfirst(x -> x == featname, featnames)],
            )
        end
    end
    # unspecified categorical features fall back to the default dimension
    newdims = [
        (cat_ind in specified_featinds) ? embedding_dims[featnames[cat_ind]] :
        set_default_new_embedding_dim(num_levels[i]) for
        (i, cat_ind) in enumerate(cat_inds)
    ]
    return newdims
end
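
A sketch with hypothetical feature names, where a float entry in `embedding_dims` is read as a proportion of the feature's number of levels:

featnames = [:age, :color]            # hypothetical features
cat_inds = [2]                        # :color is categorical
num_levels = [8]                      # :color has 8 levels
embedding_dims = Dict(:color => 0.5)  # request half the number of levels

set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims)
# returns [4], since ceil(Int, 0.5 * 8) == 4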


"""
**Private Method**
Returns the indices of the categorical columns in the table `X`.
"""
function get_cat_inds(X)
# if input is a matrix; conclude no categorical columns
Tables.istable(X) || return Int[]
types = schema(X).scitypes
cat_inds = findall(x -> x <: Finite, types)
return cat_inds
end

"""
**Private Method**
Returns the number of levels in each categorical column in the table `X`.
"""
function get_num_levels(X, cat_inds)
num_levels = []
for i in cat_inds
num_levels = push!(num_levels, length(levels(Tables.getcolumn(X, i))))
end
return num_levels
end
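
A minimal sketch of both helpers, assuming `coerce` from ScientificTypes (re-exported by MLJ) is available:

X = (age = [25, 40, 33], color = coerce(["red", "blue", "red"], Multiclass))

get_cat_inds(X)         # [2]: only :color has a Finite scitype
get_num_levels(X, [2])  # [2]: :color has levels "blue" and "red"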

# Prepare the inputs for the entity embedding layer
function prepare_entityembs(X, featnames, cat_inds, embedding_dims)
    # 1. Construct entityprops
    numfeats = length(featnames)
    num_levels = get_num_levels(X, cat_inds)
    newdims = set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims)
    entityprops = [
        (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i]) for
        i in eachindex(cat_inds)
    ]
    # 2. Compute entityemb_output_dim: embedded categorical features contribute
    # their new dimensions, while continuous features pass through unchanged
    entityemb_output_dim = sum(newdims) + numfeats - length(cat_inds)
    return entityprops, entityemb_output_dim
end
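
As a worked example: with numfeats = 5 features, of which length(cat_inds) = 2 are categorical and receive newdims = [3, 4], the embedding layer outputs 3 + 4 = 7 embedding components plus the 5 - 2 = 3 untouched continuous features, so entityemb_output_dim = 7 + 5 - 2 = 10.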

# Construct a model chain with the entity embedding layer as its first layer
function construct_model_chain_with_entityembs(
    model,
    rng,
    shape,
    move,
    entityprops,
    entityemb_output_dim,
)
    chain = try
        Flux.Chain(
            EntityEmbedder(entityprops, shape[1]; init = Flux.glorot_uniform(rng)),
            build(model, rng, (entityemb_output_dim, shape[2])),
        ) |> move
    catch ex
        @error ERR_BUILDER
        rethrow()
    end
    return chain
end


# Given a trained model chain, return a dictionary of embedding matrices keyed by feature name
function get_embedding_matrices(chain, cat_inds, featnames)
    embedder_layer = chain.layers[1]
    embedding_matrices = Dict{Symbol, Matrix{Float32}}()
    for cat_ind in cat_inds
        featname = featnames[cat_ind]
        matrix = Flux.params(embedder_layer.embedders[cat_ind])[1]
        embedding_matrices[featname] = matrix
    end
    return embedding_matrices
end
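
Each returned matrix is expected to have size (newdim, levels) for its feature, with one column per level: a feature with 8 levels embedded into 4 dimensions would give a 4×8 Matrix{Float32}.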



# Transformer for embedding-enabled models: `transform` encodes the categorical
# features of new data using the learned embeddings
function MLJModelInterface.transform(
    transformer::EMBEDDING_ENABLED_MODELS_UNION,
    fitresult,
    Xnew,
)
    ordinal_mappings, embedding_matrices = fitresult[3:4]
    Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
    Xnew_transf = embedding_transform(Xnew, embedding_matrices)
    return Xnew_transf
end
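
With this method defined, an embedding-enabled model can be used through the standard MLJ machine interface to embed the categorical features of new data. A minimal sketch, in which the data and hyperparameter values are illustrative:

using MLJ, MLJFlux

# toy data: one continuous and one categorical feature
X = (x1 = rand(100), x2 = coerce(rand(["a", "b", "c"], 100), Multiclass))
y = coerce(rand(["yes", "no"], 100), Multiclass)

clf = MLJFlux.NeuralNetworkClassifier(embedding_dims = Dict(:x2 => 2))
mach = machine(clf, X, y)
fit!(mach)

Xembedded = transform(mach, X)  # categorical columns replaced by learned embeddings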
68 changes: 68 additions & 0 deletions src/fit_utils.jl
@@ -0,0 +1,68 @@
"""
A file containing functions used in the `fit` and `update` methods in `mlj_model_interface.jl`
"""

# Converts input to table if it's a matrix
convert_to_table(X) = X isa Matrix ? Tables.table(X) : X


# Construct the model chain, logging an error and rethrowing if the builder fails
function construct_model_chain(model, rng, shape, move)
    chain = try
        build(model, rng, shape) |> move
    catch ex
        @error ERR_BUILDER
        rethrow()
    end
    return chain
end

# Check that the constructed chain runs on a sample input, logging an error otherwise
function test_chain_works(x, chain)
    try
        chain(x)
    catch ex
        @error ERR_BUILDER
        rethrow()
    end
end

# Models implement L1/L2 regularization by chaining the chosen optimiser with weight/sign
# decay. Note that the weight/sign decay must be scaled down by the number of batches to
# ensure penalization over an epoch does not scale with the choice of batch size; see
# https://github.com/FluxML/MLJFlux.jl/issues/213.

function regularized_optimiser(model, nbatches)
    model.lambda == 0 && return model.optimiser
    λ_L1 = model.alpha * model.lambda
    λ_L2 = (1 - model.alpha) * model.lambda
    λ_sign = λ_L1 / nbatches
    λ_weight = 2 * λ_L2 / nbatches

    # recall that components in an optimiser chain are executed from left to right:
    if model.alpha == 0
        return Optimisers.OptimiserChain(
            Optimisers.WeightDecay(λ_weight),
            model.optimiser,
        )
    elseif model.alpha == 1
        return Optimisers.OptimiserChain(
            Optimisers.SignDecay(λ_sign),
            model.optimiser,
        )
    else
        return Optimisers.OptimiserChain(
            Optimisers.SignDecay(λ_sign),
            Optimisers.WeightDecay(λ_weight),
            model.optimiser,
        )
    end
end
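
As a worked example, take model.lambda == 0.02, model.alpha == 0.5 and nbatches == 10. Then λ_L1 = λ_L2 = 0.01, so λ_sign = 0.01 / 10 = 0.001 and λ_weight = 2 * 0.01 / 10 = 0.002, and (assuming the base optimiser is Optimisers.Adam()) the returned optimiser is OptimiserChain(SignDecay(0.001), WeightDecay(0.002), Adam()).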

# Prepare the regularized optimiser and its state for training
function prepare_optimiser(data, model, chain)
    nbatches = length(data[2])
    regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches)
    optimiser_state = Optimisers.setup(regularized_optimiser, chain)
    return regularized_optimiser, optimiser_state
end
