Skip to content

Commit

Permalink
🤔 Make all input tables float
Browse files Browse the repository at this point in the history
  • Loading branch information
EssamWisam committed Aug 6, 2024
1 parent e7b0fff commit 833b845
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/encoders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function ordinal_encoder_fit(X; featinds)
(Set(1:length(feat_levels)) == Set(feat_levels)) && continue
# Compute the dict using the given feature_mapper function
mapping_matrix[i] =
Dict{Any, Integer}(value => index for (index, value) in enumerate(feat_levels))
Dict{Any, Integer}(value => float(index) for (index, value) in enumerate(feat_levels))
end
return mapping_matrix
end
Expand Down Expand Up @@ -127,7 +127,7 @@ function embedding_transform(X, mapping_matrices)
# Create the transformation function for each column
if feat_name in keys(mapping_matrices)
level2vector = mapping_matrices[feat_name]
new_multi_col = map(x -> level2vector[:, unwrap(x)], col)
new_multi_col = map(x -> level2vector[:, Int.(unwrap(x))], col)
new_multi_col = [col for col in eachrow(hcat(new_multi_col...))]
push!(new_cols, new_multi_col...)
feat_names_with_inds = generate_new_feat_names(
Expand Down
10 changes: 5 additions & 5 deletions test/encoders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

@testset "ordinal encoder" begin
X = (
Column1 = [1, 2, 3, 4, 5],
Column1 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column2 = categorical(['a', 'b', 'c', 'd', 'e']),
Column3 = categorical(["b", "c", "d"]),
Column4 = [1.0, 2.0, 3.0, 4.0, 5.0],
Expand All @@ -11,8 +11,8 @@
Xenc = MLJFlux.ordinal_encoder_transform(X, map)
@test map[2] == Dict('a' => 1, 'b' => 2, 'c' => 3, 'd' => 4, 'e' => 5)
@test map[3] == Dict("b" => 1, "c" => 2, "d" => 3)
@test Xenc.Column1 == [1, 2, 3, 4, 5]
@test Xenc.Column2 == [1, 2, 3, 4, 5]
@test Xenc.Column1 == [1.0, 2.0, 3.0, 4.0, 5.0]
@test Xenc.Column2 == [1.0, 2.0, 3.0, 4.0, 5.0]
@test Xenc.Column3 == [1, 2, 3]
@test Xenc.Column4 == [1.0, 2.0, 3.0, 4.0, 5.0]

Expand Down Expand Up @@ -42,7 +42,7 @@ end

@testset "embedding_transform works" begin
X = (
Column1 = [1, 2, 3, 4, 5],
Column1 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column2 = categorical(['a', 'b', 'c', 'd', 'e']),
Column3 = categorical(["b", "c", "d", "f", "f"]),
Column4 = [1.0, 2.0, 3.0, 4.0, 5.0],
Expand All @@ -61,7 +61,7 @@ end
X, _ = MLJFlux.ordinal_encoder_fit_transform(X; featinds = [2, 3])
Xenc = MLJFlux.embedding_transform(X, mapping_matrices)
@test Xenc == (
Column1 = [1, 2, 3, 4, 5],
Column1 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column2_1 = [1.0, 0.5, 0.7, 4.0, 5.0],
Column2_2 = [0.4, 2.0, 3.0, 0.9, 0.2],
Column2_3 = [0.1, 0.6, 0.8, 0.3, 0.4],
Expand Down
14 changes: 7 additions & 7 deletions test/entity_embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ entityprops = [
(index = 4, levels = 2, newdim = 2),
]

embedder = EntityEmbedder(entityprops, 4)
embedder = MLJFlux.EntityEmbedder(entityprops, 4)

output = embedder(batch)

Expand Down Expand Up @@ -68,7 +68,7 @@ end
]

cat_model = Chain(
EntityEmbedder(entityprops, 4),
MLJFlux.EntityEmbedder(entityprops, 4),
Dense(9 => (ind == 1) ? 10 : 1),
finalizer[ind],
)
Expand Down Expand Up @@ -143,7 +143,7 @@ end
@testset "Transparent when no categorical variables" begin
entityprops = []
numfeats = 4
embedder = EntityEmbedder(entityprops, 4)
embedder = MLJFlux.EntityEmbedder(entityprops, 4)
output = embedder(batch)
@test output == batch
end
Expand All @@ -158,7 +158,7 @@ end
]

X = (
Column1 = [1, 2, 3, 4, 5],
Column1 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column2 = categorical(['a', 'b', 'c', 'd', 'e']),
Column3 = categorical(["b", "c", "d", "f", "f"], ordered = true),
Column4 = [1.0, 2.0, 3.0, 4.0, 5.0],
Expand Down Expand Up @@ -230,7 +230,7 @@ end
]
# table case
X1 = (
Column1 = [1, 2, 3, 4, 5],
Column1 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column4 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column5 = randn(5),
)
Expand Down Expand Up @@ -275,7 +275,7 @@ end
]

X = (
Column1 = [1, 2, 3, 4, 5],
Column1 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column2 = categorical(['a', 'b', 'c', 'd', 'e']),
Column3 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column4 = randn(5),
Expand Down Expand Up @@ -337,7 +337,7 @@ end
]

X = (
Column1 = [1, 2, 3, 4, 5],
Column1 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column2 = categorical(['a', 'b', 'c', 'd', 'e']),
Column3 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column4 = randn(5),
Expand Down
25 changes: 18 additions & 7 deletions test/entity_embedding_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ end
featnames = [:a, :b, :c]
cat_inds = [1, 3]
specified_featinds = [1, 2, 3]
@test_throws ArgumentError MLJFlux.check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds)
@test_throws ArgumentError MLJFlux.check_mismatch_in_cat_feats(
featnames,
cat_inds,
specified_featinds,
)

# Test with empty specified_featinds
featnames = [:a, :b, :c]
Expand All @@ -49,7 +53,11 @@ end
featnames = [:a, :b, :c]
cat_inds = []
specified_featinds = [1, 2]
@test_throws ArgumentError MLJFlux.check_mismatch_in_cat_feats(featnames, cat_inds, specified_featinds)
@test_throws ArgumentError MLJFlux.check_mismatch_in_cat_feats(
featnames,
cat_inds,
specified_featinds,
)
end

@testset "Testing set_new_embedding_dims" begin
Expand All @@ -58,7 +66,7 @@ end
cat_inds = [1, 2]
num_levels = [3, 5]
embedding_dims = Dict("color" => 0.5, "size" => 2)

result = MLJFlux.set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims)
@test result == [2, 2] # Expected to be ceil(1.5) = 2 for "color", and exact 2 for "size"

Expand All @@ -70,12 +78,15 @@ end
# Test case 3: All embedding dimensions are unspecified, default for all
embedding_dims = Dict()
result = MLJFlux.set_new_embedding_dims(featnames, cat_inds, num_levels, embedding_dims)
@test result == [MLJFlux.set_default_new_embedding_dim(3), MLJFlux.set_default_new_embedding_dim(5)] # Default dimensions for both
@test result == [
MLJFlux.set_default_new_embedding_dim(3),
MLJFlux.set_default_new_embedding_dim(5),
] # Default dimensions for both
end

@testset "test get_cat_inds" begin
X = (
C1 = [1, 2, 3, 4, 5],
C1 = [1.0, 2.0, 3.0, 4.0, 5.0],
C2 = ['a', 'b', 'c', 'd', 'e'],
C3 = ["b", "c", "d", "e", "f"],
C4 = [1.0, 2.0, 3.0, 4.0, 5.0],
Expand All @@ -86,7 +97,7 @@ end

@testset "Number of levels" begin
X = (
C1 = [1, 2, 3, 4, 5],
C1 = [1.0, 2.0, 3.0, 4.0, 5.0],
C2 = ['a', 'b', 'c', 'd', 'e'],
C3 = ["b", "c", "d", "f", "f"],
C4 = [1.0, 2.0, 3.0, 4.0, 5.0],
Expand All @@ -98,7 +109,7 @@ end

@testset "Testing prepare_entityembs" begin
X = (
Column1 = [1, 2, 3, 4, 5],
Column1 = [1.0, 2.0, 3.0, 4.0, 5.0],
Column2 = categorical(['a', 'b', 'c', 'd', 'e']),
Column3 = categorical(["b", "c", "d"]),
Column4 = [1.0, 2.0, 3.0, 4.0, 5.0],
Expand Down

0 comments on commit 833b845

Please sign in to comment.