Skip to content

Commit

Permalink
✅ Get rid of case distinction
Browse files Browse the repository at this point in the history
  • Loading branch information
EssamWisam committed Sep 1, 2024
1 parent e5d8141 commit f60f789
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 69 deletions.
2 changes: 0 additions & 2 deletions src/classifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
n_input = Tables.schema(X).names |> length
return (n_input, n_output)
end
is_embedding_enabled_type(::NeuralNetworkClassifier) = true

# builds the end-to-end Flux chain needed, given the `model` and `shape`:
MLJFlux.build(
Expand Down Expand Up @@ -60,7 +59,6 @@ function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y)
n_input = Tables.schema(X).names |> length
return (n_input, 1) # n_output is always 1 for a binary classifier
end
is_embedding_enabled_type(::NeuralNetworkBinaryClassifier) = true

function MLJModelInterface.predict(
model::NeuralNetworkBinaryClassifier,
Expand Down
2 changes: 1 addition & 1 deletion src/encoders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ each level in each categorical columns using the columns of the matrix.
This is used with the embedding matrices of the entity embedding layer in entity enabled models to implement entity embeddings.
"""
function embedding_transform(X, mapping_matrices)
isnothing(mapping_matrices) && return X
(isempty(mapping_matrices)) && return X
feat_names = Tables.schema(X).names
new_feat_names = Symbol[]
new_cols = []
Expand Down
9 changes: 6 additions & 3 deletions src/entity_embedding_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ Returns the number of levels in each categorical column in the table `X`.
function get_num_levels(X, cat_inds)
num_levels = []
for i in cat_inds
num_levels = push!(num_levels, length(levels(Tables.getcolumn(Tables.columns(X), i))))
num_levels =
push!(num_levels, length(levels(Tables.getcolumn(Tables.columns(X), i))))
end
return num_levels
end
Expand All @@ -79,7 +80,8 @@ function prepare_entityembs(X, featnames, cat_inds, embedding_dims)
i in eachindex(cat_inds)
]
# 2. Compute entityemb_output_dim
entityemb_output_dim = sum(newdims) + numfeats - length(cat_inds)
sum_newdims = length(newdims) == 0 ? 0 : sum(newdims)
entityemb_output_dim = sum_newdims + numfeats - length(cat_inds)
return entityprops, entityemb_output_dim
end

Expand Down Expand Up @@ -125,7 +127,8 @@ function MLJModelInterface.transform(
fitresult,
Xnew,
)
is_embedding_enabled_type(transformer) || return Xnew
# if it doesn't have the property its not an entity-enabled model
hasproperty(transformer, :embedding_dims) || return Xnew
ordinal_mappings, embedding_matrices = fitresult[3:4]
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
Xnew_transf = embedding_transform(Xnew, embedding_matrices)
Expand Down
3 changes: 1 addition & 2 deletions src/image.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@ function shape(model::ImageClassifier, X, y)
end
return (n_input, n_output, n_channels)
end
is_embedding_enabled_type(::ImageClassifier) = false


build(model::ImageClassifier, rng, shape) =
Flux.Chain(build(model.builder, rng, shape...),
model.finaliser)

fitresult(model::ImageClassifier, chain, y) =
fitresult(model::ImageClassifier, chain, y, ordinal_mappings=nothing, embedding_matrices=nothing) =
(chain, MLJModelInterface.classes(y[1]))

function MLJModelInterface.predict(model::ImageClassifier, fitresult, Xnew)
Expand Down
82 changes: 39 additions & 43 deletions src/mlj_model_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ end

# # FIT AND UPDATE

const ERR_BUILDER =
"Builder does not appear to build an architecture compatible with supplied data. "
const ERR_BUILDER = "Builder does not appear to build an architecture compatible with supplied data. "

true_rng(model) = model.rng isa Integer ? Random.Xoshiro(model.rng) : model.rng

Expand All @@ -76,19 +75,22 @@ function MLJModelInterface.fit(model::MLJFluxModel,
pure_continuous_input = isempty(cat_inds)

# Decide whether to enable entity embeddings (e.g., ImageClassifier won't)
enable_entity_embs = is_embedding_enabled_type(model) && !pure_continuous_input
enable_entity_embs = hasproperty(model, :embedding_dims) && !pure_continuous_input

# Prepare entity embeddings inputs and encode X if entity embeddings enabled
featnames = []
if enable_entity_embs
X = convert_to_table(X)
featnames = Tables.schema(X).names
# entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i])
# for each categorical feature
entityprops, entityemb_output_dim =
prepare_entityembs(X, featnames, cat_inds, model.embedding_dims)
X, ordinal_mappings = ordinal_encoder_fit_transform(X; featinds = cat_inds)
end

# entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i])
# for each categorical feature
default_embedding_dims = enable_entity_embs ? model.embedding_dims : Dict{Symbol, Real}()
entityprops, entityemb_output_dim =
prepare_entityembs(X, featnames, cat_inds, default_embedding_dims)
X, ordinal_mappings = ordinal_encoder_fit_transform(X; featinds = cat_inds)

## Construct model chain
chain =
(!enable_entity_embs) ? construct_model_chain(model, rng, shape, move) :
Expand Down Expand Up @@ -122,6 +124,9 @@ function MLJModelInterface.fit(model::MLJFluxModel,
data[2],
)

# Extract embedding matrices
embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames)

# Prepare cache for potential warm restarts
cache = (
deepcopy(model),
Expand All @@ -132,27 +137,19 @@ function MLJModelInterface.fit(model::MLJFluxModel,
optimiser_state,
deepcopy(rng),
move,
entityprops,
entityemb_output_dim,
ordinal_mappings,
featnames,
)

# Extract embedding matrices
enable_entity_embs &&
(embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames))

# Prepare fitresult
fitresult_args = (model, Flux.cpu(chain), y)
fitresult =
MLJFlux.fitresult(model, Flux.cpu(chain), y, ordinal_mappings, embedding_matrices)

# Prepare report
report = (training_losses = history,)

# Modify cache and fitresult if entity embeddings enabled
if enable_entity_embs
cache = (cache..., entityprops, entityemb_output_dim, ordinal_mappings, featnames)
fitresult =
MLJFlux.fitresult(fitresult_args..., ordinal_mappings, embedding_matrices)
else
fitresult = MLJFlux.fitresult(fitresult_args...,)
end

return fitresult, cache, report
end

Expand All @@ -165,15 +162,22 @@ function MLJModelInterface.update(model::MLJFluxModel,
# Decide whether to enable entity embeddings (e.g., ImageClassifier won't)
cat_inds = get_cat_inds(X)
pure_continuous_input = (length(cat_inds) == 0)
enable_entity_embs = is_embedding_enabled_type(model) && !pure_continuous_input
enable_entity_embs = hasproperty(model, :embedding_dims) && !pure_continuous_input

# Unpack cache from previous fit
old_model, data, old_history, shape, regularized_optimiser, optimiser_state, rng, move =
old_cache[1:8]
if enable_entity_embs
entityprops, entityemb_output_dim, ordinal_mappings, featnames = old_cache[9:12]
cat_inds = [prop.index for prop in entityprops]
end
old_model,
data,
old_history,
shape,
regularized_optimiser,
optimiser_state,
rng,
move,
entityprops,
entityemb_output_dim,
ordinal_mappings,
featnames = old_cache
cat_inds = [prop.index for prop in entityprops]

# Extract chain
old_chain = old_fitresult[1]
Expand All @@ -196,6 +200,8 @@ function MLJModelInterface.update(model::MLJFluxModel,
else
move = Mover(model.acceleration)
rng = true_rng(model)
X = convert_to_table(X)
X = ordinal_encoder_transform(X, ordinal_mappings)
if enable_entity_embs
chain =
construct_model_chain_with_entityembs(
Expand All @@ -206,8 +212,6 @@ function MLJModelInterface.update(model::MLJFluxModel,
entityprops,
entityemb_output_dim,
)
X = convert_to_table(X)
X = ordinal_encoder_transform(X, ordinal_mappings)
else
chain = construct_model_chain(model, rng, shape, move)
end
Expand Down Expand Up @@ -237,8 +241,7 @@ function MLJModelInterface.update(model::MLJFluxModel,
end

# Extract embedding matrices
enable_entity_embs &&
(embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames))
embedding_matrices = get_embedding_matrices(chain, cat_inds, featnames)

# Prepare cache, fitresult, and report
cache = (
Expand All @@ -250,17 +253,10 @@ function MLJModelInterface.update(model::MLJFluxModel,
optimiser_state,
deepcopy(rng),
move,
entityprops, entityemb_output_dim, ordinal_mappings, featnames,
)

fitresult_args = (model, Flux.cpu(chain), y)
if enable_entity_embs
cache = (cache..., entityprops, entityemb_output_dim, ordinal_mappings, featnames)
fitresult =
MLJFlux.fitresult(fitresult_args..., ordinal_mappings, embedding_matrices)
else
fitresult = MLJFlux.fitresult(fitresult_args...)
end

fitresult =
MLJFlux.fitresult(model, Flux.cpu(chain), y, ordinal_mappings, embedding_matrices)
report = (training_losses = history,)

return fitresult, cache, report
Expand Down
2 changes: 0 additions & 2 deletions src/regressor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ function shape(model::NeuralNetworkRegressor, X, y)
return (n_input, 1)
end

is_embedding_enabled_type(::NeuralNetworkRegressor) = true

build(model::NeuralNetworkRegressor, rng, shape) =
build(model.builder, rng, shape...)
Expand Down Expand Up @@ -50,7 +49,6 @@ A private method that returns the shape of the input and output of the model for
data `X` and `y`.
"""
shape(model::MultitargetNeuralNetworkRegressor, X, y) = (ncols(X), ncols(y))
is_embedding_enabled_type(::MultitargetNeuralNetworkRegressor) = true

build(model::MultitargetNeuralNetworkRegressor, rng, shape) =
build(model.builder, rng, shape...)
Expand Down
16 changes: 0 additions & 16 deletions test/entity_embedding_utils.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,4 @@

@testset "Embedding Enabled Types" begin
clf = MLJFlux.NeuralNetworkClassifier(
builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2),
optimiser = Optimisers.Adam(0.01),
batch_size = 8,
epochs = 100,
)
@test MLJFlux.is_embedding_enabled_type(clf)

clf = MLJFlux.ImageClassifier(
batch_size = 50,
epochs = 10,
rng = 123,
)
@test !MLJFlux.is_embedding_enabled_type(clf)
end


@testset "set_default_new_embedding_dim" begin
Expand Down

0 comments on commit f60f789

Please sign in to comment.