diff --git a/src/entity_embedding_old.jl b/src/entity_embedding_old.jl deleted file mode 100644 index b87a79b..0000000 --- a/src/entity_embedding_old.jl +++ /dev/null @@ -1,50 +0,0 @@ -# This is just some experimental code -# to implement EntityEmbeddings for purely -# categorical features - - -using Flux - -mutable struct EmbeddingMatrix - e - levels - - function EmbeddingMatrix(levels; dim=4) - if dim <= 0 - dimension = div(length(levels), 2) - else - dimension = min(length(levels), dim) # Dummy function for now - end - return new(Dense(length(levels), dimension), levels), dimension - end - -end - -Flux.@treelike EmbeddingMatrix - -function (embed::EmbeddingMatrix)(ip) - return embed.e(Flux.onehot(ip, embed.levels)) -end - -mutable struct EntityEmbedding - embeddingmatrix - - function EntityEmbedding(a...) - return new(a) - end -end - -Flux.@treelike EntityEmbedding - - -# ip is an array of tuples -function (embed::EntityEmbedding)(ip) - return hcat((vcat((embed.embeddingmatrix[i](ip[idx][i]) for i=1:length(ip[idx]))...) for idx =1:length(ip))...) -end - - -# Q1. How should this be called in the API? -# nn = NeuralNetworkClassifier(builder=builder, optimiser = .., embeddingdimension = 5) -# -# -#