Commit
fix some tests
ablaom committed Sep 16, 2024
1 parent 1cfcd1e commit 09b17da
Showing 3 changed files with 17 additions and 18 deletions.
src/mlj_model_interface.jl (16 changes: 8 additions & 8 deletions)

@@ -66,35 +66,35 @@ function MLJModelInterface.fit(model::MLJFluxModel,
     X,
     y)
     # GPU and rng related variables
-    move = Mover(model.acceleration)
+    move = MLJFlux.Mover(model.acceleration)
     rng = true_rng(model)

     # Get input properties
     shape = MLJFlux.shape(model, X, y)
-    cat_inds = get_cat_inds(X)
+    cat_inds = MLJFlux.get_cat_inds(X)
     pure_continuous_input = isempty(cat_inds)

     # Decide whether to enable entity embeddings (e.g., ImageClassifier won't)
-    enable_entity_embs = is_embedding_enabled(model) && !pure_continuous_input
+    enable_entity_embs = MLJFlux.is_embedding_enabled(model) && !pure_continuous_input

     # Prepare entity embeddings inputs and encode X if entity embeddings enabled
     featnames = []
     if enable_entity_embs
-        X = convert_to_table(X)
+        X = MLJFlux.convert_to_table(X)
         featnames = Tables.schema(X).names
     end

     # entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i])
     # for each categorical feature
     default_embedding_dims = enable_entity_embs ? model.embedding_dims : Dict{Symbol, Real}()
     entityprops, entityemb_output_dim =
-        prepare_entityembs(X, featnames, cat_inds, default_embedding_dims)
-    X, ordinal_mappings = ordinal_encoder_fit_transform(X; featinds = cat_inds)
+        MLJFlux.prepare_entityembs(X, featnames, cat_inds, default_embedding_dims)
+    X, ordinal_mappings = MLJFlux.ordinal_encoder_fit_transform(X; featinds = cat_inds)

     ## Construct model chain
     chain =
         (!enable_entity_embs) ? construct_model_chain(model, rng, shape, move) :
-        construct_model_chain_with_entityembs(
+        MLJFlux.construct_model_chain_with_entityembs(
             model,
             rng,
             shape,
@@ -104,7 +104,7 @@ function MLJModelInterface.fit(model::MLJFluxModel,
     )

     # Format data as needed by Flux and move to GPU
-    data = move.(collate(model, X, y, verbosity))
+    data = move.(MLJFlux.collate(model, X, y, verbosity))

     # Test chain works (as it may be custom)
     x = data[1][1]
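For context on the `MLJFlux.`-qualification half of this commit: in Julia, an unqualified call resolves in whatever module the code is evaluated in, and the name becomes unresolvable when two modules in scope export it. Qualifying the helpers pins each call to the intended method regardless of what the test harness loads. Below is a minimal sketch of that failure mode, with hypothetical module and function names (nothing here is MLJFlux API):

module A
export helper
helper(x) = "A.helper"
end

module B
export helper
helper(x) = "B.helper"
end

using .A, .B

# helper(1)    # UndefVarError: `helper` is exported by both A and B, so the
#              # unqualified name cannot be resolved in this module
A.helper(1)    # "A.helper" -- a module-qualified call is always unambiguous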
test/classifier.jl (11 changes: 5 additions & 6 deletions)

@@ -4,18 +4,17 @@ seed!(1234)
 N = 300
 Xm = MLJBase.table(randn(Float32, N, 5)); # purely numeric
 X = (; Tables.columntable(Xm)...,
-    Column1 = repeat([1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
+    Column1 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
     Column2 = categorical(repeat(['a', 'b', 'c', 'd', 'e'], Int(N / 5))),
     Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true),
-    Column4 = repeat([1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
-    Column5 = randn(N),
+    Column4 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
+    Column5 = randn(Float32, N),
     Column6 = categorical(
         repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)),
     ),
 )

-
-ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)
+ycont = 2 * X.x1 - X.x3 + 0.1 * rand(Float32, N)
 m, M = minimum(ycont), maximum(ycont)
 _, a, b, _ = range(m, stop = M, length = 4) |> collect
 y = map(ycont) do η
@@ -126,7 +125,7 @@ end
 seed!(1234)
 N = 300
 X = MLJBase.table(rand(Float32, N, 4));
-ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)
+ycont = Float32.(2 * X.x1 - X.x3 + 0.1 * rand(N))
 m, M = minimum(ycont), maximum(ycont)
 _, a, _ = range(m, stop = M, length = 3) |> collect
 y = map(ycont) do η
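For context on the Float32 edits in these fixtures: Flux layers carry Float32 parameters by default, and mixing Float64 columns into the pipeline either promotes the arithmetic to Float64 or trips Flux's eltype-mismatch warning (depending on Flux version), either of which can fail tests that check element types or exact values. A minimal sketch of the promotion rule and of the round-trip pattern used above (plain Julia, no MLJFlux; names are illustrative):

W = rand(Float32, 2, 3)   # stand-in for Flux-style Float32 weights
x32 = rand(Float32, 3)
x64 = rand(3)             # Float64

eltype(W * x32)           # Float32: uniform precision throughout
eltype(W * x64)           # Float64: one Float64 operand silently promotes everything

# The fixture pattern: build columns as Float32, and wrap any expression that
# mixes in a Float64 term (here the literal 0.1 and rand) back to Float32.
col = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], 4)
y32 = Float32.(2 .* col .+ 0.1 .* rand(length(col)))
eltype(y32)               # Float32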
test/regressor.jl (8 changes: 4 additions & 4 deletions)

@@ -3,11 +3,11 @@ Random.seed!(123)
 N = 200
 Xm = MLJBase.table(randn(Float32, N, 5)); # purely numeric
 X = (; Tables.columntable(Xm)...,
-    Column1 = repeat([1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
+    Column1 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
     Column2 = categorical(repeat(['a', 'b', 'c', 'd', 'e'], Int(N / 5))),
     Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true),
-    Column4 = repeat([1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
-    Column5 = randn(N),
+    Column4 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
+    Column5 = randn(Float32, N),
     Column6 = categorical(
         repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)),
     ),
@@ -17,7 +17,7 @@ builder = MLJFlux.Short(σ = identity)
 optimiser = Optimisers.Adam()

 Random.seed!(123)
-y = 1 .+ X.x1 - X.x2 .- 2X.x4 + X.x5
+y = Float32(1) .+ X.x1 - X.x2 .- 2X.x4 + X.x5
 train, test = MLJBase.partition(1:N, 0.7)

 @testset_accelerated "NeuralNetworkRegressor" accel begin
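A hedged note on the `Float32(1)` edit just above: under Julia's promotion rules an integer literal already narrows to a Float32 array's element type, while a Float64 literal forces promotion, so the explicit conversion mainly makes the intent unmistakable. A quick illustration (plain Julia, names illustrative):

x = rand(Float32, 3)
eltype(1 .+ x)           # Float32: integer literals narrow to the array eltype
eltype(1.0 .+ x)         # Float64: a Float64 literal forces promotion
eltype(Float32(1) .+ x)  # Float32: explicit, and robust either way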