Skip to content

Commit

Permalink
𝚫 Change name of embedding layer
Browse files Browse the repository at this point in the history
  • Loading branch information
EssamWisam committed Aug 5, 2024
1 parent 41ca8e0 commit afab1e0
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/entity_embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ entityprops = [
numfeats = 4
# Run it through the categorical embedding layer
embedder = CategoricalEmbedder(entityprops, 4)
embedder = EntityEmbedder(entityprops, 4)
julia> output = embedder(batch)
5×10 Matrix{Float64}:
0.2 0.3 0.4 0.5 … 0.8 0.9 1.0 1.1
Expand All @@ -37,38 +37,39 @@ julia> output = embedder(batch)
"""

# 1. Define layer struct to hold parameters
struct CategoricalEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}
struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}
embedders::A1
modifiers::A2
modifiers::A2 # applied on the input before passing it to the embedder
numfeats::I
end

# 2. Define the forward pass (i.e., calling an instance of the layer)
(m::CategoricalEmbedder)(x) = vcat([ m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)
(m::EntityEmbedder)(x) =
vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)

# 3. Define the constructor which initializes the parameters and returns the instance
function CategoricalEmbedder(entityprops, numfeats; init=Flux.randn32)
function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
embedders = []
modifiers = []

# Setup entityprops
cat_inds = [entityprop.index for entityprop in entityprops]
levels_per_feat = [entityprop.levels for entityprop in entityprops]
newdims = [entityprop.newdim for entityprop in entityprops]

c = 1
for i in 1:numfeats
if i in cat_inds
push!(embedders, Flux.Embedding(levels_per_feat[c] => newdims[c], init=init))
push!(embedders, Flux.Embedding(levels_per_feat[c] => newdims[c], init = init))
push!(modifiers, (x, i) -> Int.(x[i, :]))
c += 1
else
push!(embedders, feat->feat)
push!(embedders, feat -> feat)
push!(modifiers, (x, i) -> x[i:i, :])
end
end

CategoricalEmbedder(embedders, modifiers, numfeats)
EntityEmbedder(embedders, modifiers, numfeats)
end

# 4. Register it as layer with Flux
Flux.@layer CategoricalEmbedder
Flux.@layer EntityEmbedder

0 comments on commit afab1e0

Please sign in to comment.