Skip to content

Commit

Permalink
✅ Add CategoricalEmbedding layer with documentation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EssamWisam committed Aug 3, 2024
1 parent 70dff6e commit fb98c27
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 37 deletions.
2 changes: 2 additions & 0 deletions src/MLJFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ include("regressor.jl")
include("classifier.jl")
include("image.jl")
include("mlj_model_interface.jl")
include("entity_embedding.jl")

export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier
export CUDALibs, CPU1
export CategoricalEmbedder

include("deprecated.jl")

Expand Down
98 changes: 61 additions & 37 deletions src/entity_embedding.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,74 @@
# This is just some experimental code
# to implement EntityEmbeddings for purely
# categorical features
"""
A layer that implements entity embeddings as presented in 'Entity Embeddings of
Categorical Variables' by Cheng Guo and Felix Berkhahn. Expects a matrix of dimensions (numfeats, batchsize)
and applies an entity embedding to each specified categorical feature. Other features are left as is.
# Arguments
- `entityprops`: a vector of named tuples each of the form `(index=..., levels=..., newdim=...)` to
specify the feature index, the number of levels and the desired embeddings dimensionality for selected features of the input.
- `numfeats`: the number of features in the input.
using Flux
# Example
```julia
# Prepare a batch of four features where the 2nd and the 4th are categorical
batch = [
0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1;
1 2 3 4 5 6 7 8 9 10;
0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1
1 1 2 2 1 1 2 2 1 1;
]
mutable struct EmbeddingMatrix
e
levels
entityprops = [
(index=2, levels=10, newdim=2),
(index=4, levels=2, newdim=1)
]
numfeats = 4
function EmbeddingMatrix(levels; dim=4)
if dim <= 0
dimension = div(length(levels), 2)
else
dimension = min(length(levels), dim) # Dummy function for now
end
return new(Dense(length(levels), dimension), levels), dimension
end
# Run it through the categorical embedding layer
embedder = CategoricalEmbedder(entityprops, 4)
julia> output = embedder(batch)
5×10 Matrix{Float64}:
0.2 0.3 0.4 0.5 … 0.8 0.9 1.0 1.1
-1.27129 -0.417667 -1.40326 -0.695701 0.371741 1.69952 -1.40034 -2.04078
-0.166796 0.657619 -0.659249 -0.337757 -0.717179 -0.0176273 -1.2817 -0.0372752
0.9 0.1 0.4 0.5 0.8 0.9 1.0 1.1
-0.847354 -0.847354 -1.66261 -1.66261 -1.66261 -1.66261 -0.847354 -0.847354
```
"""

# 1. Define layer struct to hold parameters
# Container for the per-feature pieces of the categorical embedding layer:
#   - `embedders`: one callable per input feature, applied to the extracted feature
#   - `modifiers`: one callable per input feature, invoked as `modifier(x, i)` to
#     extract feature `i` from the input `x`
#   - `numfeats`:  how many features the forward pass iterates over
struct CategoricalEmbedder{E <: AbstractVector, M <: AbstractVector, N <: Integer}
    embedders::E
    modifiers::M
    numfeats::N
end

Flux.@treelike EmbeddingMatrix
# 2. Define the forward pass (i.e., calling an instance of the layer)
# Forward pass: for every feature index, extract that feature from `x` with its
# modifier, run it through its embedder, and stack all results vertically.
function (m::CategoricalEmbedder)(x)
    pieces = map(1:m.numfeats) do j
        selected = m.modifiers[j](x, j)
        m.embedders[j](selected)
    end
    return vcat(pieces...)
end

# One-hot encode `ip` against the stored levels, then feed it to the wrapped layer.
function (embed::EmbeddingMatrix)(ip)
    encoded = Flux.onehot(ip, embed.levels)
    return embed.e(encoded)
end

mutable struct EntityEmbedding
embeddingmatrix
# 3. Define the constructor which initializes the parameters and returns the instance
# Build a `CategoricalEmbedder` for an input with `numfeats` feature rows.
#
# Arguments
# - `entityprops`: collection of named tuples `(index=..., levels=..., newdim=...)`, one per
#   categorical feature; `index` is the feature's row in the input, `levels` its number
#   of levels and `newdim` the embedding dimensionality. Entries may be given in any order.
# - `numfeats`: total number of feature rows in the input.
# - `init`: forwarded to `Flux.Embedding` as its `init` keyword.
#
# Every feature gets a (modifier, embedder) pair: categorical rows are converted to
# integer codes and sent through a `Flux.Embedding`; all other rows are sliced out and
# passed through unchanged.
function CategoricalEmbedder(entityprops, numfeats; init=Flux.randn32)
    embedders = []
    modifiers = []

    for i in 1:numfeats
        # Match on the declared `index` instead of relying on the position of the entry
        # within `entityprops`; a running counter would pair rows with the wrong
        # `levels`/`newdim` whenever `entityprops` is not sorted by index.
        pos = findfirst(prop -> prop.index == i, entityprops)
        if pos === nothing
            # Not categorical: keep the row as a one-row slice and pass it through as is.
            push!(embedders, feat -> feat)
            push!(modifiers, (x, i) -> x[i:i, :])
        else
            prop = entityprops[pos]
            push!(embedders, Flux.Embedding(prop.levels => prop.newdim; init=init))
            # `Flux.Embedding` is indexed with integer codes, so convert the row first.
            push!(modifiers, (x, i) -> Int.(x[i, :]))
        end
    end

    return CategoricalEmbedder(embedders, modifiers, numfeats)
end


# Q1. How should this be called in the API?
# nn = NeuralNetworkClassifier(builder=builder, optimiser = .., embeddingdimension = 5)
#
#
#
# 4. Register it as a layer with Flux
Flux.@layer CategoricalEmbedder
50 changes: 50 additions & 0 deletions src/entity_embedding_old.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# This is just some experimental code
# to implement EntityEmbeddings for purely
# categorical features


using Flux

# Experimental wrapper that stores a layer `e` together with the `levels` used for
# one-hot encoding its input.
mutable struct EmbeddingMatrix
e
levels

# Inner constructor: the embedding width is `div(length(levels), 2)` when `dim <= 0`,
# otherwise `dim` capped at `length(levels)`; `e` is a `Dense` layer of that shape.
# NOTE(review): the `return` below yields a 2-tuple `(instance, dimension)` rather than
# a bare instance — confirm that callers expect this.
function EmbeddingMatrix(levels; dim=4)
if dim <= 0
dimension = div(length(levels), 2)

Check warning on line 14 in src/entity_embedding_old.jl

View check run for this annotation

Codecov / codecov/patch

src/entity_embedding_old.jl#L12-L14

Added lines #L12 - L14 were not covered by tests
else
dimension = min(length(levels), dim) # Dummy function for now

Check warning on line 16 in src/entity_embedding_old.jl

View check run for this annotation

Codecov / codecov/patch

src/entity_embedding_old.jl#L16

Added line #L16 was not covered by tests
end
return new(Dense(length(levels), dimension), levels), dimension

Check warning on line 18 in src/entity_embedding_old.jl

View check run for this annotation

Codecov / codecov/patch

src/entity_embedding_old.jl#L18

Added line #L18 was not covered by tests
end

end

Flux.@treelike EmbeddingMatrix

# Calling an `EmbeddingMatrix`: one-hot encodes `ip` against `embed.levels` and passes
# the encoded value through the wrapped layer `embed.e`.
function (embed::EmbeddingMatrix)(ip)
return embed.e(Flux.onehot(ip, embed.levels))

Check warning on line 26 in src/entity_embedding_old.jl

View check run for this annotation

Codecov / codecov/patch

src/entity_embedding_old.jl#L25-L26

Added lines #L25 - L26 were not covered by tests
end

# Experimental container holding a collection of per-feature embedders.
mutable struct EntityEmbedding
embeddingmatrix

# Inner constructor: stores all positional arguments, as the tuple `a`, in
# `embeddingmatrix`.
function EntityEmbedding(a...)
return new(a)

Check warning on line 33 in src/entity_embedding_old.jl

View check run for this annotation

Codecov / codecov/patch

src/entity_embedding_old.jl#L32-L33

Added lines #L32 - L33 were not covered by tests
end
end

Flux.@treelike EntityEmbedding


# ip is an array of tuples
# Calling an `EntityEmbedding`: for each element `ip[idx]`, applies
# `embed.embeddingmatrix[i]` to its i-th entry and stacks those results vertically;
# the per-element stacks are then concatenated horizontally.
function (embed::EntityEmbedding)(ip)
return hcat((vcat((embed.embeddingmatrix[i](ip[idx][i]) for i=1:length(ip[idx]))...) for idx =1:length(ip))...)

Check warning on line 42 in src/entity_embedding_old.jl

View check run for this annotation

Codecov / codecov/patch

src/entity_embedding_old.jl#L41-L42

Added lines #L41 - L42 were not covered by tests
end


# Q1. How should this be called in the API?
# nn = NeuralNetworkClassifier(builder=builder, optimiser = .., embeddingdimension = 5)
#
#
#
142 changes: 142 additions & 0 deletions test/entity_embedding.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Shared fixture: a 4×10 batch with one feature per row. Rows 2 and 4 hold integer
# codes and are the rows declared categorical by the `entityprops` specifications.
batch = [
0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1;
1 2 3 4 5 6 7 8 9 10;
0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1
1 1 2 2 1 1 2 2 1 1;
]

# File-level embedding specification for `batch`: one named tuple per categorical row.
# NOTE(review): each testset below assigns its own `entityprops`, so this one may be
# unused — confirm before removing.
entityprops = [
(index=2, levels=10, newdim=2),
(index=4, levels=2, newdim=1)
]


# Checks that the layer's forward pass equals a hand-written one built from one-hot
# encodings multiplied by the layer's own embedding matrices.
@testset "Feedforward with Entity Embedder Works" begin
    ### Option 1: Use CategoricalEmbedder
    entityprops = [
        (index=2, levels=10, newdim=5),
        (index=4, levels=2, newdim=2)
    ]

    embedder = CategoricalEmbedder(entityprops, 4)

    output = embedder(batch)

    ### Option 2: Manual feedforward
    x1 = batch[1:1,:]
    z2 = Int.(batch[2,:])
    x3 = batch[3:3,:]
    z4 = Int.(batch[4,:])

    # extract matrices from categorical embedder
    EE1 = Flux.params(embedder.embedders[2])[1]     # (newdim, levels) = (5, 10)
    EE2 = Flux.params(embedder.embedders[4])[1]     # (newdim, levels) = (2, 2)

    ## One-hot encoding
    z2_hot = Flux.onehotbatch(z2, levels(z2))
    z4_hot = Flux.onehotbatch(z4, levels(z4))

    function feedforward(x1, z2_hot, x3, z4_hot)
        f_z2 = EE1 * z2_hot
        f_z4 = EE2 * z4_hot
        return vcat([x1, f_z2, x3, f_z4]...)
    end

    real_output = feedforward(x1, z2_hot, x3, z4_hot)
    # Embedding lookup and one-hot matrix product are different floating-point paths,
    # so compare approximately rather than exactly.
    @test output ≈ real_output
end


# Checks, for a classification and a regression head, that a network containing a
# `CategoricalEmbedder` produces the same forward output and the same post-update
# quantities as an equivalent hand-written network on one-hot inputs.
@testset "Feedforward and Backward Pass with Entity Embedder Works" begin
    y_batch_reg = [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0]   # Regression
    y_batch_cls = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]             # Classification
    y_batch_cls_o = Flux.onehotbatch(y_batch_cls, 1:10)

    losses = [Flux.crossentropy, Flux.mse]
    targets = [y_batch_cls_o, y_batch_reg]
    finalizer = [softmax, relu]

    # ind == 1: classification (10 outputs, softmax); ind == 2: regression (1 output, relu)
    for ind in 1:2
        ### Option 1: Feedforward with CategoricalEmbedder in the network
        entityprops = [
            (index=2, levels=10, newdim=5),
            (index=4, levels=2, newdim=2)
        ]

        # Input width of the dense layer: 1 + 5 + 1 + 2 = 9 rows after embedding
        cat_model = Chain(
            CategoricalEmbedder(entityprops, 4),
            Dense(9 => ((ind == 1) ? 10 : 1)),
            finalizer[ind],
        )

        EE1_before = Flux.params(cat_model.layers[1].embedders[2])[1]
        EE2_before = Flux.params(cat_model.layers[1].embedders[4])[1]
        W_before = Flux.params(cat_model.layers[2])[1]

        ### Test with obvious equivalent feedforward
        x1 = batch[1:1,:]
        z2 = Int.(batch[2,:])
        x3 = batch[3:3,:]
        z4 = Int.(batch[4,:])

        z2_hot = Flux.onehotbatch(z2, levels(z2))
        z4_hot = Flux.onehotbatch(z4, levels(z4))

        ### Option 2: Manual feedforward
        function feedforward(x1, z2_hot, x3, z4_hot, W, EE1, EE2)
            f_z2 = EE1 * z2_hot
            f_z4 = EE2 * z4_hot
            return finalizer[ind](W * vcat([x1, f_z2, x3, f_z4]...))
        end

        # The reference network is a plain NamedTuple of parameters. A `struct` cannot be
        # declared here: Julia only allows type definitions at top level, not inside a
        # `for` loop or a `@testset` body. Copies are used so that updating `cat_model`
        # in place does not also change the reference parameters.
        W_before_cp, EE1_before_cp, EE2_before_cp =
            deepcopy(W_before), deepcopy(EE1_before), deepcopy(EE2_before)
        net = (W=W_before_cp, EE1=EE1_before_cp, EE2=EE2_before_cp)

        @test feedforward(x1, z2_hot, x3, z4_hot, W_before, EE1_before, EE2_before) ≈
              cat_model(batch)

        ## Option 1: Backward with CategoricalEmbedder in the network
        loss, grads = Flux.withgradient(cat_model) do m
            y_pred_cls = m(batch)
            losses[ind](y_pred_cls, targets[ind])
        end
        optim = Flux.setup(Flux.Adam(10), cat_model)
        new_params = Flux.update!(optim, cat_model, grads[1])

        EE1_after = Flux.params(new_params[1].layers[1].embedders[2].weight)[1]
        EE2_after = Flux.params(new_params[1].layers[1].embedders[4].weight)[1]
        W_after = Flux.params(new_params[1].layers[2].weight)[1]

        ## Option 2: Backward with the hand-written reference network
        loss, grads = Flux.withgradient(net) do m
            y_pred_cls = feedforward(x1, z2_hot, x3, z4_hot, m.W, m.EE1, m.EE2)
            losses[ind](y_pred_cls, targets[ind])
        end

        optim = Flux.setup(Flux.Adam(10), net)
        z = Flux.update!(optim, net, grads[1])
        EE1_after_cp = Flux.params(z[1].EE1)[1]
        EE2_after_cp = Flux.params(z[1].EE2)[1]
        W_after_cp = Flux.params(z[1].W)[1]
        @test EE1_after_cp ≈ EE1_after
        @test EE2_after_cp ≈ EE2_after
        @test W_after_cp ≈ W_after
    end
end


# With no categorical features declared, the layer must return its input unchanged.
@testset "Transparent when no categorical variables" begin
    entityprops = []
    numfeats = 4
    # Pass the named feature count rather than repeating the literal
    embedder = CategoricalEmbedder(entityprops, numfeats)
    output = embedder(batch)
    @test output == batch
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,7 @@ end
@conditional_testset "integration" begin
include("integration.jl")
end

# Run the tests in test/entity_embedding.jl.
# NOTE(review): `@conditional_testset` is defined outside this view; presumably it runs
# the block only when this test group is selected — confirm.
@conditional_testset "entity_embedding.jl" begin
include("entity_embedding.jl")
end

0 comments on commit fb98c27

Please sign in to comment.