Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/entity-embeddings' into double-t…
Browse files Browse the repository at this point in the history
…rouble
  • Loading branch information
ablaom committed Sep 16, 2024
2 parents 945016d + 493d23b commit eac2974
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/encoders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ function ordinal_encoder_fit(X; featinds)
feat_col = Tables.getcolumn(Tables.columns(X), i)
feat_levels = levels(feat_col)
# Check if feat levels is already ordinal encoded in which case we skip
(Set([float(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
(Set([Float32(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
# Compute the dict using the given feature_mapper function
mapping_matrix[i] =
Dict{Any, AbstractFloat}(
value => float(index) for (index, value) in enumerate(feat_levels)
value => Float32(index) for (index, value) in enumerate(feat_levels)
)
end
return mapping_matrix
Expand Down
3 changes: 1 addition & 2 deletions src/entity_embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,14 @@ julia> output = embedder(batch)
```
""" # 1. Define layer struct to hold parameters
struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}

embedders::A1
modifiers::A2 # applied on the input before passing it to the embedder
numfeats::I
end

# 2. Define the forward pass (i.e., calling an instance of the layer)
(m::EntityEmbedder)(x) =
vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)
(vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...))

# 3. Define the constructor which initializes the parameters and returns the instance
function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
Expand Down
4 changes: 2 additions & 2 deletions test/encoders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
@test map[2] == Dict('a' => 1, 'b' => 2, 'c' => 3, 'd' => 4, 'e' => 5)
@test map[3] == Dict("b" => 1, "c" => 2, "d" => 3)
@test Xenc.Column1 == [1.0, 2.0, 3.0, 4.0, 5.0]
@test Xenc.Column2 == [1.0, 2.0, 3.0, 4.0, 5.0]
@test Xenc.Column3 == [1, 2, 3]
@test Xenc.Column2 == Float32.([1.0, 2.0, 3.0, 4.0, 5.0])
@test Xenc.Column3 == Float32.([1, 2, 3])
@test Xenc.Column4 == [1.0, 2.0, 3.0, 4.0, 5.0]

X = coerce(X, :Column1 => Multiclass)
Expand Down
15 changes: 8 additions & 7 deletions test/entity_embedding.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""
See more functional tests in entity_embedding_utils.jl and mlj_model_interface.jl
"""

batch = [
batch = Float32.([
0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1;
1 2 3 4 5 6 7 8 9 10;
0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1
1 1 2 2 1 1 2 2 1 1
]
1 2 3 4 5 6 7 8 9 10;
0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1;
1 1 2 2 1 1 2 2 1 1
])


entityprops = [
(index = 2, levels = 10, newdim = 2),
Expand Down Expand Up @@ -145,7 +145,8 @@ end
numfeats = 4
embedder = MLJFlux.EntityEmbedder(entityprops, 4)
output = embedder(batch)
@test output == batch
@test output batch
@test eltype(output) == Float32
end


Expand Down

0 comments on commit eac2974

Please sign in to comment.