From e7b0fffa12102a822e9d153b5831c2e02b68949b Mon Sep 17 00:00:00 2001 From: Essam Date: Mon, 5 Aug 2024 20:56:30 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20Delete=20old=20file?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/entity_embedding_old.jl | 50 ------------------------------------- 1 file changed, 50 deletions(-) delete mode 100644 src/entity_embedding_old.jl diff --git a/src/entity_embedding_old.jl b/src/entity_embedding_old.jl deleted file mode 100644 index b87a79ba..00000000 --- a/src/entity_embedding_old.jl +++ /dev/null @@ -1,50 +0,0 @@ -# This is just some experimental code -# to implement EntityEmbeddings for purely -# categorical features - - -using Flux - -mutable struct EmbeddingMatrix - e - levels - - function EmbeddingMatrix(levels; dim=4) - if dim <= 0 - dimension = div(length(levels), 2) - else - dimension = min(length(levels), dim) # Dummy function for now - end - return new(Dense(length(levels), dimension), levels), dimension - end - -end - -Flux.@treelike EmbeddingMatrix - -function (embed::EmbeddingMatrix)(ip) - return embed.e(Flux.onehot(ip, embed.levels)) -end - -mutable struct EntityEmbedding - embeddingmatrix - - function EntityEmbedding(a...) - return new(a) - end -end - -Flux.@treelike EntityEmbedding - - -# ip is an array of tuples -function (embed::EntityEmbedding)(ip) - return hcat((vcat((embed.embeddingmatrix[i](ip[idx][i]) for i=1:length(ip[idx]))...) for idx =1:length(ip))...) -end - - -# Q1. How should this be called in the API? -# nn = NeuralNetworkClassifier(builder=builder, optimiser = .., embeddingdimension = 5) -# -# -#