Skip to content

Commit

Permalink
✅ Better default for embedding dims
Browse files Browse the repository at this point in the history
  • Loading branch information
EssamWisam committed Sep 1, 2024
1 parent f1f7dfe commit daa4b2a
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 15 deletions.
6 changes: 2 additions & 4 deletions src/entity_embedding_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ end

"""
    set_default_new_embedding_dim(numlevels)

Return the default entity-embedding dimension for a categorical feature that has
`numlevels` levels, namely `min(numlevels - 1, 10)`.

# Arguments
- `numlevels`: number of levels of the categorical feature.

# Returns
One less than the number of levels, capped at 10.
"""
function set_default_new_embedding_dim(numlevels)
    # The earlier ratio-based rule (0.5 / 0.2 of `numlevels` around a threshold of 20)
    # has been superseded; only the capped "levels minus one" rule remains.
    return min(numlevels - 1, 10)
end

MISMATCH_INDS(wrong_feats) =
Expand Down
5 changes: 4 additions & 1 deletion src/mlj_model_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ end
include("fit_utils.jl")
include("entity_embedding_utils.jl")

# Message reported when the builder's architecture does not fit the supplied data.
# NOTE(review): the trailing space suggests callers append further detail — confirm
# against the call sites, which are not visible here.
const ERR_BUILDER =
    "Builder does not appear to build an architecture compatible with supplied data. "

# Resolve `model.rng` to a random number generator: an integer is treated as a seed
# for a fresh `Random.Xoshiro`; any other value is returned unchanged.
function true_rng(model)
    rng = model.rng
    rng isa Integer || return rng
    return Random.Xoshiro(rng)
end

Expand All @@ -83,6 +84,8 @@ function MLJModelInterface.fit(model::MLJFluxModel,
if enable_entity_embs
X = convert_to_table(X)
featnames = Tables.schema(X).names
# entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i])
# for each categorical feature
entityprops, entityemb_output_dim =
prepare_entityembs(X, featnames, cat_inds, model.embedding_dims)
X, ordinal_mappings = ordinal_encoder_fit_transform(X; featinds = cat_inds)
Expand Down
4 changes: 2 additions & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ const XDOC = """
# Documentation fragment describing the `embedding_dims` hyperparameter.
# NOTE(review): presumably interpolated into model docstrings alongside the other
# `*DOC` constants in this file — confirm at the use sites.
const EMBDOC = """
- `embedding_dims` is a `Dict` whose keys are names of categorical features, given as symbols, and whose values are real numbers representing the desired dimensionality
of the entity embeddings of such features. An integer value such as `7` would set the embedding dimensionality of such feature to `7`. Meanwhile, a float value such as `0.5`
would set the embedding dimensionality of such column to `ceil(0.5 * number of levels in feature)`.
Any unspecified features will by default have their values set to min(number of levels - 1, 10).
"""

const TRANSFORMDOC = """
Expand Down
2 changes: 1 addition & 1 deletion test/entity_embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ end
expected_dims = [
[(3, 5), (2, 4), (1, 3)],
[(1, 5), (4, 4), (2, 3)],
[(3, 5), (2, 4), (2, 3)],
[(4, 5), (3, 4), (2, 3)],
]

size([
Expand Down
10 changes: 3 additions & 7 deletions test/entity_embedding_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,8 @@ end


@testset "set_default_new_embedding_dim" begin
    # The default is min(numlevels - 1, 10).
    # Above the cap: numlevels - 1 would exceed 10, so the result is 10.
    @test MLJFlux.set_default_new_embedding_dim(15) == 10
    # Exactly at the cap boundary: 11 - 1 == 10.
    @test MLJFlux.set_default_new_embedding_dim(11) == 10
    # Below the cap: the result is numlevels - 1.
    @test MLJFlux.set_default_new_embedding_dim(9) == 8
end

@testset "check_mismatch_in_cat_feats" begin
Expand Down Expand Up @@ -73,7 +69,7 @@ end
# Test case 2: Handling of unspecified dimensions with defaults
embedding_dims = Dict("color" => 0.5) # "size" is not specified
result = MLJFlux.set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims)
@test result == [2, MLJFlux.set_default_new_embedding_dim(5)] # Expected to be ceil(1.5) = 2 for "color", and default 1 for "size"
@test result == [2, MLJFlux.set_default_new_embedding_dim(5)]

# Test case 3: All embedding dimensions are unspecified, default for all
embedding_dims = Dict()
Expand Down

0 comments on commit daa4b2a

Please sign in to comment.