From 7818959d7c86289e4a7808ac636d93e99f6c9f23 Mon Sep 17 00:00:00 2001 From: Essam Date: Sat, 14 Sep 2024 18:18:56 -0500 Subject: [PATCH 1/4] =?UTF-8?q?=E2=9C=85=20Make=20embedding=20model=20outp?= =?UTF-8?q?ut=20Float32?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/entity_embedding.jl | 3 +-- test/entity_embedding.jl | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/entity_embedding.jl b/src/entity_embedding.jl index 3e16ea2..4009745 100644 --- a/src/entity_embedding.jl +++ b/src/entity_embedding.jl @@ -36,7 +36,6 @@ julia> output = embedder(batch) ``` """ # 1. Define layer struct to hold parameters struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer} - embedders::A1 modifiers::A2 # applied on the input before passing it to the embedder numfeats::I @@ -44,7 +43,7 @@ end # 2. Define the forward pass (i.e., calling an instance of the layer) (m::EntityEmbedder)(x) = - vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...) + Float32.(vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)) # 3. Define the constructor which initializes the parameters and returns the instance function EntityEmbedder(entityprops, numfeats; init = Flux.randn32) diff --git a/test/entity_embedding.jl b/test/entity_embedding.jl index 738351a..5cc1f23 100644 --- a/test/entity_embedding.jl +++ b/test/entity_embedding.jl @@ -145,7 +145,8 @@ end numfeats = 4 embedder = MLJFlux.EntityEmbedder(entityprops, 4) output = embedder(batch) - @test output == batch + @test output ≈ batch + @test eltype(output) == Float32 end From 493d23bbb7ab847df56bf9d17c6f0a9695f84cf9 Mon Sep 17 00:00:00 2001 From: Essam Date: Sun, 15 Sep 2024 21:49:09 -0500 Subject: [PATCH 2/4] Fix ordinal encoding float types --- src/encoders.jl | 4 ++-- src/entity_embedding.jl | 2 +- test/encoders.jl | 4 ++-- test/entity_embedding.jl | 12 ++++++------ 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/encoders.jl b/src/encoders.jl index de51714..a7eba0b 100644 --- a/src/encoders.jl +++ b/src/encoders.jl @@ -19,11 +19,11 @@ function ordinal_encoder_fit(X; featinds) feat_col = Tables.getcolumn(Tables.columns(X), i) feat_levels = levels(feat_col) # Check if feat levels is already ordinal encoded in which case we skip - (Set([float(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue + (Set([Float32(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue # Compute the dict using the given feature_mapper function mapping_matrix[i] = Dict{Any, AbstractFloat}( - value => float(index) for (index, value) in enumerate(feat_levels) + value => Float32(index) for (index, value) in enumerate(feat_levels) ) end return mapping_matrix diff --git a/src/entity_embedding.jl b/src/entity_embedding.jl index 4009745..313e3e6 100644 --- a/src/entity_embedding.jl +++ b/src/entity_embedding.jl @@ -43,7 +43,7 @@ end # 2. Define the forward pass (i.e., calling an instance of the layer) (m::EntityEmbedder)(x) = - Float32.(vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)) + (vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)) # 3. Define the constructor which initializes the parameters and returns the instance function EntityEmbedder(entityprops, numfeats; init = Flux.randn32) diff --git a/test/encoders.jl b/test/encoders.jl index 7bd5da9..50b93b6 100644 --- a/test/encoders.jl +++ b/test/encoders.jl @@ -12,8 +12,8 @@ @test map[2] == Dict('a' => 1, 'b' => 2, 'c' => 3, 'd' => 4, 'e' => 5) @test map[3] == Dict("b" => 1, "c" => 2, "d" => 3) @test Xenc.Column1 == [1.0, 2.0, 3.0, 4.0, 5.0] - @test Xenc.Column2 == [1.0, 2.0, 3.0, 4.0, 5.0] - @test Xenc.Column3 == [1, 2, 3] + @test Xenc.Column2 == Float32.([1.0, 2.0, 3.0, 4.0, 5.0]) + @test Xenc.Column3 == Float32.([1, 2, 3]) @test Xenc.Column4 == [1.0, 2.0, 3.0, 4.0, 5.0] X = coerce(X, :Column1 => Multiclass) diff --git a/test/entity_embedding.jl b/test/entity_embedding.jl index 5cc1f23..da0c89b 100644 --- a/test/entity_embedding.jl +++ b/test/entity_embedding.jl @@ -1,13 +1,13 @@ """ See more functional tests in entity_embedding_utils.jl and mlj_model_interface.jl """ - -batch = [ +batch = Float32.([ 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1; - 1 2 3 4 5 6 7 8 9 10; - 0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1 - 1 1 2 2 1 1 2 2 1 1 -] + 1 2 3 4 5 6 7 8 9 10; + 0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1; + 1 1 2 2 1 1 2 2 1 1 +]) + entityprops = [ (index = 2, levels = 10, newdim = 2), From c57870ba5139b591c3bec0303acf8ad815d7db96 Mon Sep 17 00:00:00 2001 From: Essam Date: Wed, 25 Sep 2024 20:11:33 -0500 Subject: [PATCH 3/4] Fix ordinal encoding ransform type --- src/encoders.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/encoders.jl b/src/encoders.jl index a7eba0b..4d00d2d 100644 --- a/src/encoders.jl +++ b/src/encoders.jl @@ -67,7 +67,7 @@ function ordinal_encoder_transform(X, mapping_matrix) test_levels = levels(col) check_unkown_levels(train_levels, test_levels) level2scalar = mapping_matrix[ind] - new_col = recode(col, level2scalar...) + new_col = recode(unwrap.(col), level2scalar...) push!(new_feats, new_col) else push!(new_feats, col) From 3ecf348dac53f24b14406a373c34719bf801b0ee Mon Sep 17 00:00:00 2001 From: Essam Date: Thu, 26 Sep 2024 19:19:36 -0500 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=9B=91=20=20Stop=20forcing=20abstract?= =?UTF-8?q?=20type?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/encoders.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/encoders.jl b/src/encoders.jl index 4d00d2d..60aac91 100644 --- a/src/encoders.jl +++ b/src/encoders.jl @@ -22,7 +22,7 @@ function ordinal_encoder_fit(X; featinds) (Set([Float32(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue # Compute the dict using the given feature_mapper function mapping_matrix[i] = - Dict{Any, AbstractFloat}( + Dict( value => Float32(index) for (index, value) in enumerate(feat_levels) ) end