-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
👨💻 Refactor fit and update in mlj_model_interface
- Loading branch information
1 parent
afab1e0
commit 86fa74c
Showing
3 changed files
with
349 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
""" | ||
A file containing functions or constants used in the `fit` and `update` methods in `mlj_model_interface.jl` for setups supporting entity embeddings | ||
""" | ||
|
||
# Model types that support an entity-embedding layer. Declared `const` so the binding
# has a fixed type; a non-constant global is untyped, which makes every function reading
# it (e.g. `is_embedding_enabled_type`) dynamically dispatched.
const EMBEDDING_ENABLED_MODELS = [
    NeuralNetworkClassifier,
    NeuralNetworkBinaryClassifier,
    NeuralNetworkRegressor,
    MultitargetNeuralNetworkRegressor,
]

# Union of the model types listed above, usable as a single dispatch target.
const EMBEDDING_ENABLED_MODELS_UNION = Union{EMBEDDING_ENABLED_MODELS...}
|
||
|
||
# Return `true` when `model_type` is a subtype of at least one entry of
# `EMBEDDING_ENABLED_MODELS`, and `false` otherwise.
function is_embedding_enabled_type(model_type)
    return any(T -> model_type <: T, EMBEDDING_ENABLED_MODELS)
end
|
||
# Default embedding dimension for a categorical feature with `numlevels` levels:
# half the number of levels (rounded up) for features with fewer than 20 levels, and
# a fifth of the number of levels (rounded up) for features with 20 levels or more.
function set_default_new_embedding_dim(numlevels)
    threshold = 20
    ratio_many_levels = 0.2
    ratio_few_levels = 0.5
    if numlevels >= threshold
        ratio = ratio_many_levels
    else
        ratio = ratio_few_levels
    end
    return ceil(Int, ratio * numlevels)
end
|
||
# Build the error message listing the features that were given an embedding dimension
# but are not categorical.
function MISMATCH_INDS(wrong_feats)
    listed = join(wrong_feats, ", ")
    return "Features $listed were specified in embedding_dims hyperparameter but were not recognized as categorical variables because their scitypes are not `Multiclass` or `OrderedFactor`."
end
# Throw an `ArgumentError` naming every feature whose index appears in
# `specified_featinds` but not in `cat_inds`; return `false` when there is no such
# feature.
function check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds)
    not_categorical = Iterators.filter(i -> !(i in cat_inds), specified_featinds)
    unrecognized = [featnames[i] for i in not_categorical]
    if length(unrecognized) > 0
        throw(ArgumentError(MISMATCH_INDS(unrecognized)))
    end
    return false
end
|
||
# Compute the embedding dimension of every categorical feature.
#
# Arguments:
# - `featnames`: names of all features, indexable by column position.
# - `cat_inds`: column positions of the categorical features.
# - `num_levels`: number of levels of each categorical feature, aligned with `cat_inds`.
# - `embedding_dims`: mapping from feature name to the requested dimension. An
#   `AbstractFloat` value is a fraction of the feature's number of levels (rounded up);
#   any other value is used as is.
#
# Returns a vector aligned with `cat_inds`. Categorical features without an entry in
# `embedding_dims` get `set_default_new_embedding_dim` of their number of levels.
#
# Throws an `ArgumentError` (via `check_mismatch_in_cat_feats`) if `embedding_dims`
# names a feature that is present in `featnames` but is not categorical. Names absent
# from `featnames` are ignored.
#
# `embedding_dims` is never mutated: writing the resolved integers back into the caller's
# dictionary would make a second call rescale an already-resolved value whenever the
# dictionary's value type keeps it a float.
function set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims)
    specified_featnames = keys(embedding_dims)
    specified_featinds =
        [i for i in eachindex(featnames) if featnames[i] in specified_featnames]
    check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds)

    # Resolve the dimension of the `i`-th categorical feature, located at `cat_ind`.
    function resolve_dim(i, cat_ind)
        cat_ind in specified_featinds ||
            return set_default_new_embedding_dim(num_levels[i])
        requested = embedding_dims[featnames[cat_ind]]
        # A float is a ratio of the number of levels; anything else is an absolute size.
        return requested isa AbstractFloat ? ceil(Int, requested * num_levels[i]) :
               requested
    end

    return [resolve_dim(i, cat_ind) for (i, cat_ind) in enumerate(cat_inds)]
end
|
||
|
||
""" | ||
**Private Method** | ||
Returns the indices of the categorical columns in the table `X`. | ||
""" | ||
function get_cat_inds(X)
    # Anything that is not a table (e.g. a matrix) has no categorical columns.
    if !Tables.istable(X)
        return Int[]
    end
    # Keep the positions of the columns whose scientific type is a subtype of `Finite`.
    column_scitypes = schema(X).scitypes
    return findall(T -> T <: Finite, column_scitypes)
end
|
||
""" | ||
**Private Method** | ||
Returns the number of levels in each categorical column in the table `X`. | ||
""" | ||
function get_num_levels(X, cat_inds)
    # Typed comprehension: the result is always a `Vector{Int}` (also when `cat_inds`
    # is empty) instead of an untyped `Vector{Any}` grown with `push!`.
    return Int[length(levels(Tables.getcolumn(X, i))) for i in cat_inds]
end
|
||
# Prepare the inputs needed to build the entity-embedding layer.
#
# Returns a tuple `(entityprops, entityemb_output_dim)` where `entityprops` holds one
# named tuple `(index, levels, newdim)` per categorical column and
# `entityemb_output_dim` is the total of the embedding dimensions plus the number of
# non-categorical features.
function prepare_entityembs(X, featnames, cat_inds, embedding_dims)
    num_levels = get_num_levels(X, cat_inds)
    newdims = set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims)

    # One property record per categorical column.
    entityprops = map(eachindex(cat_inds)) do k
        (index = cat_inds[k], levels = num_levels[k], newdim = newdims[k])
    end

    # Each categorical column is replaced by `newdim` columns; the others are kept.
    numfeats = length(featnames)
    entityemb_output_dim = sum(newdims) + numfeats - length(cat_inds)
    return entityprops, entityemb_output_dim
end
|
||
# Build the model chain with an `EntityEmbedder` as its first layer, followed by the
# chain produced by `build` for an input of size `entityemb_output_dim`, and pass the
# result through `move`. If any step fails, `ERR_BUILDER` is logged and the exception
# is rethrown.
function construct_model_chain_with_entityembs(
    model,
    rng,
    shape,
    move,
    entityprops,
    entityemb_output_dim,
)
    try
        embedder = EntityEmbedder(entityprops, shape[1]; init = Flux.glorot_uniform(rng))
        remainder = build(model, rng, (entityemb_output_dim, shape[2]))
        return move(Flux.Chain(embedder, remainder))
    catch
        @error ERR_BUILDER
        rethrow()
    end
end
|
||
|
||
# Collect, from the first layer of `chain`, the embedding matrix of every categorical
# feature. Returns a dictionary keyed by feature name.
function get_embedding_matrices(chain, cat_inds, featnames)
    first_layer = chain.layers[1]
    matrices_by_name = Dict{Symbol, Matrix{Float32}}()
    foreach(cat_inds) do idx
        name = featnames[idx]
        # The first parameter array of the embedder is taken as its embedding matrix.
        weights = Flux.params(first_layer.embedders[idx])[1]
        matrices_by_name[name] = weights
    end
    return matrices_by_name
end
|
||
|
||
|
||
# `transform` for embedding-enabled models: ordinal-encode `Xnew` with the mappings
# stored in the third entry of `fitresult`, then apply the embedding matrices stored in
# its fourth entry.
function MLJModelInterface.transform(
    transformer::EMBEDDING_ENABLED_MODELS_UNION,
    fitresult,
    Xnew,
)
    ordinal_mappings, embedding_matrices = fitresult[3:4]
    encoded = ordinal_encoder_transform(Xnew, ordinal_mappings)
    return embedding_transform(encoded, embedding_matrices)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
""" | ||
A file containing functions used in the `fit` and `update` methods in `mlj_model_interface.jl` | ||
""" | ||
|
||
# Wrap a `Matrix` in a table; return any other input unchanged.
convert_to_table(X) = X
convert_to_table(X::Matrix) = Tables.table(X)
|
||
|
||
# Build the model chain with `build` and pass it through `move`. If either step fails,
# `ERR_BUILDER` is logged and the exception is rethrown.
function construct_model_chain(model, rng, shape, move)
    try
        return move(build(model, rng, shape))
    catch
        @error ERR_BUILDER
        rethrow()
    end
end
|
||
# Check that `chain` can be applied to `x`.
#
# Returns `chain(x)` on success. On failure, logs `ERR_BUILDER` and rethrows the
# exception with `rethrow()`, matching the chain constructors in this file and keeping
# the original backtrace.
function test_chain_works(x, chain)
    try
        return chain(x)
    catch
        @error ERR_BUILDER
        rethrow()
    end
end
|
||
# Models implement L1/L2 regularization by chaining the chosen optimiser with weight/sign
# decay. Note that the weight/sign decay must be scaled down by the number of batches to
# ensure penalization over an epoch does not scale with the choice of batch size; see
# https://github.com/FluxML/MLJFlux.jl/issues/213.
#
# Returns `model.optimiser` unchanged when `model.lambda == 0`; otherwise returns an
# `Optimisers.OptimiserChain` ending with `model.optimiser`.
function regularized_optimiser(model, nbatches)
    if model.lambda == 0
        return model.optimiser
    end

    λ_L1 = model.alpha * model.lambda
    λ_L2 = (1 - model.alpha) * model.lambda
    λ_sign = λ_L1 / nbatches
    λ_weight = 2 * λ_L2 / nbatches

    # Components of an optimiser chain are executed from left to right, so the decay
    # rules come before the user's optimiser. A pure L2 penalty (`alpha == 0`) needs only
    # weight decay, a pure L1 penalty (`alpha == 1`) only sign decay.
    decay_rules = if model.alpha == 0
        (Optimisers.WeightDecay(λ_weight),)
    elseif model.alpha == 1
        (Optimisers.SignDecay(λ_sign),)
    else
        (Optimisers.SignDecay(λ_sign), Optimisers.WeightDecay(λ_weight))
    end
    return Optimisers.OptimiserChain(decay_rules..., model.optimiser)
end
|
||
# Prepare the optimiser for training: build the regularized optimiser for `model` and
# initialise its state for `chain`. Returns the pair `(optimiser, optimiser_state)`.
function prepare_optimiser(data, model, chain)
    # NOTE(review): `data[2]` is presumably the collection of target batches, so its
    # length is the number of batches — confirm against the caller.
    batch_count = length(data[2])
    optimiser = MLJFlux.regularized_optimiser(model, batch_count)
    state = Optimisers.setup(optimiser, chain)
    return optimiser, state
end
Oops, something went wrong.