diff --git a/src/entity_embedding_utils.jl b/src/entity_embedding_utils.jl index a87c184..c75b7a9 100644 --- a/src/entity_embedding_utils.jl +++ b/src/entity_embedding_utils.jl @@ -19,10 +19,8 @@ end # function to set default new embedding dimension function set_default_new_embedding_dim(numlevels) - # Either min_ratio or max_ratio of numlevels depending on >= threshold or < threshold - min_ratio, max_ratio = 0.2, 0.5 - threshold = 20 - return ceil(Int, ((numlevels >= threshold) ? min_ratio : max_ratio) * numlevels) + # Set default to the minimum of num_levels-1 and 10 + return min(numlevels - 1, 10) end MISMATCH_INDS(wrong_feats) = diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl index 51bea15..1ba0d03 100644 --- a/src/mlj_model_interface.jl +++ b/src/mlj_model_interface.jl @@ -58,7 +58,8 @@ end include("fit_utils.jl") include("entity_embedding_utils.jl") -const ERR_BUILDER = "Builder does not appear to build an architecture compatible with supplied data. " +const ERR_BUILDER = + "Builder does not appear to build an architecture compatible with supplied data. " true_rng(model) = model.rng isa Integer ? Random.Xoshiro(model.rng) : model.rng @@ -83,6 +84,8 @@ function MLJModelInterface.fit(model::MLJFluxModel, if enable_entity_embs X = convert_to_table(X) featnames = Tables.schema(X).names + # entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i]) + # for each categorical feature entityprops, entityemb_output_dim = prepare_entityembs(X, featnames, cat_inds, model.embedding_dims) X, ordinal_mappings = ordinal_encoder_fit_transform(X; featinds = cat_inds) diff --git a/src/types.jl b/src/types.jl index f0edd14..36f61d2 100644 --- a/src/types.jl +++ b/src/types.jl @@ -210,8 +210,8 @@ const XDOC = """ const EMBDOC = """ - `embedding_dims` is a `Dict` whose keys are names of categorical features, given as symbols, and whose values are real numbers representing the desired dimensionality of the entity embeddings of such features. An integer value such as `7` would set the embedding dimensionality of such feature to `7`. Meanwhile, a float value such as `0.5` - would set the embedding dimensionality of such column to `ceil(0.5 * number of levels in feature)`. Any unspecified features will by default have their values set to either 0.5 - or 0.2 depending on whether the number of levels in the column is less than 20 or greater than 20 respectively. + would set the embedding dimensionality of such column to `ceil(0.5 * number of levels in feature)`. + Any unspecified features will by default have their values set to min(number of levels - 1, 10). """ const TRANSFORMDOC = """ diff --git a/test/entity_embedding.jl b/test/entity_embedding.jl index 0202a41..dcfda7c 100644 --- a/test/entity_embedding.jl +++ b/test/entity_embedding.jl @@ -178,7 +178,7 @@ end expected_dims = [ [(3, 5), (2, 4), (1, 3)], [(1, 5), (4, 4), (2, 3)], - [(3, 5), (2, 4), (2, 3)], + [(4, 5), (3, 4), (2, 3)], ] size([ diff --git a/test/entity_embedding_utils.jl b/test/entity_embedding_utils.jl index fbd64a0..d7aaaf9 100644 --- a/test/entity_embedding_utils.jl +++ b/test/entity_embedding_utils.jl @@ -18,12 +18,8 @@ end @testset "set_default_new_embedding_dim" begin - # <= 20 - @test MLJFlux.set_default_new_embedding_dim(10) == 5 - @test MLJFlux.set_default_new_embedding_dim(15) == 8 - # > 20 - @test MLJFlux.set_default_new_embedding_dim(25) == 5 - @test MLJFlux.set_default_new_embedding_dim(30) == 6 + @test MLJFlux.set_default_new_embedding_dim(15) == 10 + @test MLJFlux.set_default_new_embedding_dim(9) == 8 end @testset "check_mismatch_in_cat_feats" begin @@ -73,7 +69,7 @@ end # Test case 2: Handling of unspecified dimensions with defaults embedding_dims = Dict("color" => 0.5) # "size" is not specified result = MLJFlux.set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims) - @test result == [2, MLJFlux.set_default_new_embedding_dim(5)] # Expected to be ceil(1.5) = 2 for "color", and default 1 for "size" + @test result == [2, MLJFlux.set_default_new_embedding_dim(5)] # Test case 3: All embedding dimensions are unspecified, default for all embedding_dims = Dict()