Skip to content

Commit

Permalink
✅ Make embedding model output Float32
Browse files Browse the repository at this point in the history
  • Loading branch information
EssamWisam committed Sep 14, 2024
1 parent 407edc5 commit 7818959
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 1 addition & 2 deletions src/entity_embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,14 @@ julia> output = embedder(batch)
```
""" # 1. Define layer struct to hold parameters
struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}

embedders::A1
modifiers::A2 # applied on the input before passing it to the embedder
numfeats::I
end

# 2. Define the forward pass (i.e., calling an instance of the layer)
(m::EntityEmbedder)(x) =
vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)
Float32.(vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...))

# 3. Define the constructor which initializes the parameters and returns the instance
function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
Expand Down
3 changes: 2 additions & 1 deletion test/entity_embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ end
numfeats = 4
embedder = MLJFlux.EntityEmbedder(entityprops, 4)
output = embedder(batch)
@test output == batch
@test output batch
@test eltype(output) == Float32
end


Expand Down

0 comments on commit 7818959

Please sign in to comment.