From afab1e050d73e4a69e4ce050a697b1401856729a Mon Sep 17 00:00:00 2001 From: Essam Date: Mon, 5 Aug 2024 18:29:34 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9D=9A=AB=20Change=20name=20of=20embedding?= =?UTF-8?q?=20layer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/entity_embedding.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/entity_embedding.jl b/src/entity_embedding.jl index 806f0e4..d672c0c 100644 --- a/src/entity_embedding.jl +++ b/src/entity_embedding.jl @@ -25,7 +25,7 @@ entityprops = [ numfeats = 4 # Run it through the categorical embedding layer -embedder = CategoricalEmbedder(entityprops, 4) +embedder = EntityEmbedder(entityprops, 4) julia> output = embedder(batch) 5×10 Matrix{Float64}: 0.2 0.3 0.4 0.5 … 0.8 0.9 1.0 1.1 @@ -37,38 +37,39 @@ julia> output = embedder(batch) """ # 1. Define layer struct to hold parameters -struct CategoricalEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer} +struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer} embedders::A1 - modifiers::A2 + modifiers::A2 # applied on the input before passing it to the embedder numfeats::I end # 2. Define the forward pass (i.e., calling an instance of the layer) -(m::CategoricalEmbedder)(x) = vcat([ m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...) +(m::EntityEmbedder)(x) = + vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...) # 3. Define the constructor which initializes the parameters and returns the instance -function CategoricalEmbedder(entityprops, numfeats; init=Flux.randn32) +function EntityEmbedder(entityprops, numfeats; init = Flux.randn32) embedders = [] modifiers = [] - + # Setup entityprops cat_inds = [entityprop.index for entityprop in entityprops] levels_per_feat = [entityprop.levels for entityprop in entityprops] newdims = [entityprop.newdim for entityprop in entityprops] - + c = 1 for i in 1:numfeats if i in cat_inds - push!(embedders, Flux.Embedding(levels_per_feat[c] => newdims[c], init=init)) + push!(embedders, Flux.Embedding(levels_per_feat[c] => newdims[c], init = init)) push!(modifiers, (x, i) -> Int.(x[i, :])) c += 1 else - push!(embedders, feat->feat) + push!(embedders, feat -> feat) push!(modifiers, (x, i) -> x[i:i, :]) end end - CategoricalEmbedder(embedders, modifiers, numfeats) + EntityEmbedder(embedders, modifiers, numfeats) end # 4. Register it as layer with Flux -Flux.@layer CategoricalEmbedder \ No newline at end of file +Flux.@layer EntityEmbedder \ No newline at end of file