Skip to content

Commit

Permalink
Fix ordinal encoding float types
Browse files Browse the repository at this point in the history
  • Loading branch information
EssamWisam committed Sep 16, 2024
1 parent 7818959 commit 493d23b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/encoders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ function ordinal_encoder_fit(X; featinds)
feat_col = Tables.getcolumn(Tables.columns(X), i)
feat_levels = levels(feat_col)
# Check if feat levels is already ordinal encoded in which case we skip
(Set([float(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
(Set([Float32(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
# Compute the dict using the given feature_mapper function
mapping_matrix[i] =
Dict{Any, AbstractFloat}(
value => float(index) for (index, value) in enumerate(feat_levels)
value => Float32(index) for (index, value) in enumerate(feat_levels)
)
end
return mapping_matrix
Expand Down
2 changes: 1 addition & 1 deletion src/entity_embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end

# 2. Define the forward pass (i.e., calling an instance of the layer)
(m::EntityEmbedder)(x) =
Float32.(vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...))
(vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...))

# 3. Define the constructor which initializes the parameters and returns the instance
function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
Expand Down
4 changes: 2 additions & 2 deletions test/encoders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
@test map[2] == Dict('a' => 1, 'b' => 2, 'c' => 3, 'd' => 4, 'e' => 5)
@test map[3] == Dict("b" => 1, "c" => 2, "d" => 3)
@test Xenc.Column1 == [1.0, 2.0, 3.0, 4.0, 5.0]
@test Xenc.Column2 == [1.0, 2.0, 3.0, 4.0, 5.0]
@test Xenc.Column3 == [1, 2, 3]
@test Xenc.Column2 == Float32.([1.0, 2.0, 3.0, 4.0, 5.0])
@test Xenc.Column3 == Float32.([1, 2, 3])
@test Xenc.Column4 == [1.0, 2.0, 3.0, 4.0, 5.0]

X = coerce(X, :Column1 => Multiclass)
Expand Down
12 changes: 6 additions & 6 deletions test/entity_embedding.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""
See more functional tests in entity_embedding_utils.jl and mlj_model_interface.jl
"""

batch = [
batch = Float32.([
0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1;
1 2 3 4 5 6 7 8 9 10;
0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1
1 1 2 2 1 1 2 2 1 1
]
1 2 3 4 5 6 7 8 9 10;
0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1;
1 1 2 2 1 1 2 2 1 1
])


entityprops = [
(index = 2, levels = 10, newdim = 2),
Expand Down

0 comments on commit 493d23b

Please sign in to comment.