-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✅ Add CategoricalEmbedding layer with documentation and tests
- Loading branch information
1 parent
70dff6e
commit fb98c27
Showing
5 changed files
with
259 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,50 +1,74 @@ | ||
# This is just some experimental code | ||
# to implement EntityEmbeddings for purely | ||
# categorical features | ||
""" | ||
A layer that implements entity embedding layers as presented in 'Entity Embeddings of | ||
Categorical Variables by Cheng Guo, Felix Berkhahn'. Expects a matrix of dimensions (numfeats, batchsize) | ||
and applies entity embeddings to each specified categorical feature. Other features will be left as is. | ||
# Arguments | ||
- `entityprops`: a vector of named tuples each of the form `(index=..., levels=..., newdim=...)` to | ||
specify the feature index, the number of levels and the desired embeddings dimensionality for selected features of the input. | ||
- `numfeats`: the number of features in the input. | ||
using Flux | ||
# Example | ||
```julia | ||
# Prepare a batch of four features where the 2nd and the 4th are categorical | ||
batch = [ | ||
0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1; | ||
1 2 3 4 5 6 7 8 9 10; | ||
0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1 | ||
1 1 2 2 1 1 2 2 1 1; | ||
] | ||
mutable struct EmbeddingMatrix | ||
e | ||
levels | ||
entityprops = [ | ||
(index=2, levels=10, newdim=2), | ||
(index=4, levels=2, newdim=1) | ||
] | ||
numfeats = 4 | ||
function EmbeddingMatrix(levels; dim=4) | ||
if dim <= 0 | ||
dimension = div(length(levels), 2) | ||
else | ||
dimension = min(length(levels), dim) # Dummy function for now | ||
end | ||
return new(Dense(length(levels), dimension), levels), dimension | ||
end | ||
# Run it through the categorical embedding layer | ||
embedder = CategoricalEmbedder(entityprops, 4) | ||
julia> output = embedder(batch) | ||
5×10 Matrix{Float64}: | ||
0.2 0.3 0.4 0.5 … 0.8 0.9 1.0 1.1 | ||
-1.27129 -0.417667 -1.40326 -0.695701 0.371741 1.69952 -1.40034 -2.04078 | ||
-0.166796 0.657619 -0.659249 -0.337757 -0.717179 -0.0176273 -1.2817 -0.0372752 | ||
0.9 0.1 0.4 0.5 0.8 0.9 1.0 1.1 | ||
-0.847354 -0.847354 -1.66261 -1.66261 -1.66261 -1.66261 -0.847354 -0.847354 | ||
``` | ||
""" | ||
|
||
# 1. Define layer struct to hold parameters | ||
struct CategoricalEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer} | ||
embedders::A1 | ||
modifiers::A2 | ||
numfeats::I | ||
end | ||
|
||
Flux.@treelike EmbeddingMatrix | ||
# 2. Define the forward pass (i.e., calling an instance of the layer) | ||
(m::CategoricalEmbedder)(x) = vcat([ m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...) | ||
|
||
function (embed::EmbeddingMatrix)(ip) | ||
return embed.e(Flux.onehot(ip, embed.levels)) | ||
end | ||
|
||
mutable struct EntityEmbedding | ||
embeddingmatrix | ||
# 3. Define the constructor which initializes the parameters and returns the instance | ||
function CategoricalEmbedder(entityprops, numfeats; init=Flux.randn32) | ||
embedders = [] | ||
modifiers = [] | ||
|
||
function EntityEmbedding(a...) | ||
return new(a) | ||
cat_inds = [entityprop.index for entityprop in entityprops] | ||
levels_per_feat = [entityprop.levels for entityprop in entityprops] | ||
newdims = [entityprop.newdim for entityprop in entityprops] | ||
|
||
c = 1 | ||
for i in 1:numfeats | ||
if i in cat_inds | ||
push!(embedders, Flux.Embedding(levels_per_feat[c] => newdims[c], init=init)) | ||
push!(modifiers, (x, i) -> Int.(x[i, :])) | ||
c += 1 | ||
else | ||
push!(embedders, feat->feat) | ||
push!(modifiers, (x, i) -> x[i:i, :]) | ||
end | ||
end | ||
end | ||
|
||
Flux.@treelike EntityEmbedding | ||
|
||
|
||
# ip is an array of tuples | ||
function (embed::EntityEmbedding)(ip) | ||
return hcat((vcat((embed.embeddingmatrix[i](ip[idx][i]) for i=1:length(ip[idx]))...) for idx =1:length(ip))...) | ||
CategoricalEmbedder(embedders, modifiers, numfeats) | ||
end | ||
|
||
|
||
# Q1. How should this be called in the API? | ||
# nn = NeuralNetworkClassifier(builder=builder, optimiser = .., embeddingdimension = 5) | ||
# | ||
# | ||
# | ||
# 4. Register it as layer with Flux | ||
Flux.@layer CategoricalEmbedder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# This is just some experimental code | ||
# to implement EntityEmbeddings for purely | ||
# categorical features | ||
|
||
|
||
using Flux | ||
|
||
mutable struct EmbeddingMatrix | ||
e | ||
levels | ||
|
||
function EmbeddingMatrix(levels; dim=4) | ||
if dim <= 0 | ||
dimension = div(length(levels), 2) | ||
else | ||
dimension = min(length(levels), dim) # Dummy function for now | ||
end | ||
return new(Dense(length(levels), dimension), levels), dimension | ||
end | ||
|
||
end | ||
|
||
Flux.@treelike EmbeddingMatrix | ||
|
||
function (embed::EmbeddingMatrix)(ip) | ||
return embed.e(Flux.onehot(ip, embed.levels)) | ||
end | ||
|
||
mutable struct EntityEmbedding | ||
embeddingmatrix | ||
|
||
function EntityEmbedding(a...) | ||
return new(a) | ||
end | ||
end | ||
|
||
Flux.@treelike EntityEmbedding | ||
|
||
|
||
# ip is an array of tuples | ||
function (embed::EntityEmbedding)(ip) | ||
return hcat((vcat((embed.embeddingmatrix[i](ip[idx][i]) for i=1:length(ip[idx]))...) for idx =1:length(ip))...) | ||
end | ||
|
||
|
||
# Q1. How should this be called in the API? | ||
# nn = NeuralNetworkClassifier(builder=builder, optimiser = .., embeddingdimension = 5) | ||
# | ||
# | ||
# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
batch = [ | ||
0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1; | ||
1 2 3 4 5 6 7 8 9 10; | ||
0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1 | ||
1 1 2 2 1 1 2 2 1 1; | ||
] | ||
|
||
entityprops = [ | ||
(index=2, levels=10, newdim=2), | ||
(index=4, levels=2, newdim=1) | ||
] | ||
|
||
|
||
@testset "Feedforward with Entity Embedder Works" begin | ||
### Option 1: Use CategoricalEmbedder | ||
entityprops = [ | ||
(index=2, levels=10, newdim=5), | ||
(index=4, levels=2, newdim=2) | ||
] | ||
|
||
embedder = CategoricalEmbedder(entityprops, 4) | ||
|
||
output = embedder(batch) | ||
|
||
### Option 2: Manual feedforward | ||
x1 = batch[1:1,:] | ||
z2 = Int.(batch[2,:]) | ||
x3 = batch[3:3,:] | ||
z4 = Int.(batch[4,:]) | ||
|
||
# extract matrices from categorical embedder | ||
EE1 = Flux.params(embedder.embedders[2])[1] # (newdim, levels) = (5, 10) | ||
EE2 = Flux.params(embedder.embedders[4])[1] # (newdim, levels) = (2, 2) | ||
|
||
## One-hot encoding | ||
z2_hot = Flux.onehotbatch(z2, levels(z2)) | ||
z4_hot = Flux.onehotbatch(z4, levels(z4)) | ||
|
||
function feedforward(x1, z2_hot, x3, z4_hot) | ||
f_z2 = EE1 * z2_hot | ||
f_z4 = EE2 * z4_hot | ||
return vcat([x1, f_z2, x3, f_z4]...) | ||
end | ||
|
||
real_output = feedforward(x1, z2_hot, x3, z4_hot) | ||
@test output ≈ real_output | ||
end | ||
|
||
|
||
@testset "Feedforward and Backward Pass with Entity Embedder Works" begin | ||
y_batch_reg = [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0] # Regression | ||
y_batch_cls = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # Classification | ||
y_batch_cls_o = Flux.onehotbatch(y_batch_cls, 1:10) | ||
|
||
losses = [Flux.crossentropy, Flux.mse] | ||
targets = [y_batch_cls_o, y_batch_reg] | ||
finalizer = [softmax, relu] | ||
|
||
for ind in 1:2 | ||
### Option 1: Feedforward with CategoricalEmbedder in the network | ||
entityprops = [ | ||
(index=2, levels=10, newdim=5), | ||
(index=4, levels=2, newdim=2) | ||
] | ||
|
||
cat_model = Chain( | ||
CategoricalEmbedder(entityprops, 4), | ||
Dense(9 => (ind == 1) ? 10 : 1), | ||
finalizer[ind], | ||
) | ||
|
||
EE1_before = Flux.params(cat_model.layers[1].embedders[2])[1] | ||
EE2_before = Flux.params(cat_model.layers[1].embedders[4])[1] | ||
W_before = Flux.params(cat_model.layers[2])[1] | ||
|
||
### Test with obvious equivalent feedforward | ||
x1 = batch[1:1,:] | ||
z2 = Int.(batch[2,:]) | ||
x3 = batch[3:3,:] | ||
z4 = Int.(batch[4,:]) | ||
|
||
z2_hot = Flux.onehotbatch(z2, levels(z2)) | ||
z4_hot = Flux.onehotbatch(z4, levels(z4)) | ||
|
||
### Option 2: Manual feedforward | ||
function feedforward(x1, z2_hot, x3, z4_hot, W, EE1, EE2) | ||
f_z2 = EE1 * z2_hot | ||
f_z4 = EE2 * z4_hot | ||
return finalizer[ind](W * vcat([x1, f_z2, x3, f_z4]...)) | ||
end | ||
|
||
struct ObviousNetwork | ||
W | ||
EE1 | ||
EE2 | ||
end | ||
|
||
(m::ObviousNetwork)(x1, z2_hot, x3, z4_hot) = feedforward(x1, z2_hot, x3, z4_hot, m.W, m.EE1, m.EE2) | ||
Flux.@layer ObviousNetwork | ||
|
||
W_before_cp, EE1_before_cp, EE2_before_cp = deepcopy(W_before), deepcopy(EE1_before), deepcopy(EE2_before) | ||
net = ObviousNetwork(W_before_cp, EE1_before_cp, EE2_before_cp) | ||
|
||
@test feedforward(x1, z2_hot, x3, z4_hot, W_before, EE1_before, EE2_before) ≈ cat_model(batch) | ||
|
||
## Option 1: Backward with CategoricalEmbedder in the network | ||
loss, grads = Flux.withgradient(cat_model) do m | ||
y_pred_cls = m(batch) | ||
losses[ind](y_pred_cls, targets[ind]) | ||
end | ||
optim = Flux.setup(Flux.Adam(10), cat_model) | ||
new_params = Flux.update!(optim, cat_model, grads[1]) | ||
|
||
EE1_after = Flux.params(new_params[1].layers[1].embedders[2].weight)[1] | ||
EE2_after = Flux.params(new_params[1].layers[1].embedders[4].weight)[1] | ||
W_after = Flux.params(new_params[1].layers[2].weight)[1] | ||
|
||
## Option 2: Backward with ObviousNetwork | ||
loss, grads = Flux.withgradient(net) do m | ||
y_pred_cls = m(x1, z2_hot, x3, z4_hot) | ||
losses[ind](y_pred_cls, targets[ind]) | ||
end | ||
|
||
optim = Flux.setup(Flux.Adam(10), net) | ||
z = Flux.update!(optim, net, grads[1]) | ||
EE1_after_cp = Flux.params(z[1].EE1)[1] | ||
EE2_after_cp = Flux.params(z[1].EE2)[1] | ||
W_after_cp =Flux.params(z[1].W)[1] | ||
@test EE1_after_cp ≈ EE1_after | ||
@test EE2_after_cp ≈ EE2_after | ||
@test W_after_cp ≈ W_after | ||
end | ||
end | ||
|
||
|
||
@testset "Transparent when no categorical variables" begin | ||
entityprops = [] | ||
numfeats = 4 | ||
embedder = CategoricalEmbedder(entityprops, 4) | ||
output = embedder(batch) | ||
@test output == batch | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters