
Commit

Merge remote-tracking branch 'tiemvanderdeure/convert_to_f32' into double-trouble
ablaom committed Sep 16, 2024
2 parents eac2974 + 5ae2fe5 commit 1cfcd1e
Showing 6 changed files with 32 additions and 18 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -33,6 +33,7 @@ julia = "1.9"
[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -42,4 +43,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
test = ["CUDA", "cuDNN", "LinearAlgebra", "Logging", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
15 changes: 11 additions & 4 deletions src/core.jl
@@ -276,15 +276,22 @@ input `X` and target `y` in the form required by
by `model.batch_size`.)
"""
function collate(model, X, y)
function collate(model, X, y, verbosity)
row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
Xmatrix = reformat(X)
Xmatrix = _f32(reformat(X), verbosity)
ymatrix = reformat(y)
return [_get(Xmatrix, b) for b in row_batches], [_get(ymatrix, b) for b in row_batches]
end
function collate(model::NeuralNetworkBinaryClassifier, X, y)
function collate(model::NeuralNetworkBinaryClassifier, X, y, verbosity)
row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
Xmatrix = reformat(X)
Xmatrix = _f32(reformat(X), verbosity)
yvec = (y .== classes(y)[2])' # convert to boolean
return [_get(Xmatrix, b) for b in row_batches], [_get(yvec, b) for b in row_batches]
end

_f32(x::AbstractArray{Float32}, verbosity) = x
function _f32(x::AbstractArray, verbosity)
verbosity > 0 && @info "MLJFlux: converting input data to Float32"
return Float32.(x)
end
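
For context, a minimal sketch of what the new `_f32` helper does, exercising only the definitions added above (the example arrays are illustrative, not taken from the test suite):

    using MLJFlux

    x32 = rand(Float32, 4, 2)
    x64 = rand(Float64, 4, 2)

    MLJFlux._f32(x32, 1)  # already Float32: returned as-is, no logging
    MLJFlux._f32(x64, 1)  # logs "MLJFlux: converting input data to Float32" and returns a Float32 array
    MLJFlux._f32(x64, 0)  # verbosity <= 0: converts silently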

10 changes: 5 additions & 5 deletions src/mlj_model_interface.jl
@@ -84,7 +84,7 @@ function MLJModelInterface.fit(model::MLJFluxModel,
featnames = Tables.schema(X).names
end

# entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i])
# entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i])
# for each categorical feature
default_embedding_dims = enable_entity_embs ? model.embedding_dims : Dict{Symbol, Real}()
entityprops, entityemb_output_dim =
@@ -103,8 +103,8 @@ function MLJModelInterface.fit(model::MLJFluxModel,
entityemb_output_dim,
)

# Format data as needed by Flux and move to GPU
data = move.(collate(model, X, y))
# Format data as needed by Flux and move to GPU
data = move.(collate(model, X, y, verbosity))

# Test chain works (as it may be custom)
x = data[1][1]
@@ -143,7 +143,7 @@ function MLJModelInterface.fit(model::MLJFluxModel,
featnames,
)

# Prepare fitresult
# Prepare fitresult
fitresult =
MLJFlux.fitresult(model, Flux.cpu(chain), y, ordinal_mappings, embedding_matrices)

@@ -216,7 +216,7 @@ function MLJModelInterface.update(model::MLJFluxModel,
chain = construct_model_chain(model, rng, shape, move)
end
# reset `optimiser_state`:
data = move.(collate(model, X, y))
data = move.(collate(model, X, y, verbosity))
regularized_optimiser, optimiser_state =
prepare_optimiser(data, model, chain)
epochs = model.epochs
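From the user's side, the effect of threading `verbosity` through `collate` is an info message whenever non-Float32 data reaches a fit. A hedged sketch (the data and hyperparameters here are made up for illustration):

    using MLJBase, MLJFlux

    X = MLJBase.table(rand(Float64, 100, 3))   # Float64 features
    y = rand(Float32, 100)

    mach = machine(MLJFlux.NeuralNetworkRegressor(epochs=1), X, y)
    fit!(mach, verbosity=1)              # expected to include "MLJFlux: converting input data to Float32"
    fit!(mach, verbosity=0, force=true)  # same conversion, but silent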
19 changes: 12 additions & 7 deletions test/core.jl
@@ -14,23 +14,28 @@ rowvec(y::Vector) = reshape(y, 1, length(y))
end

@testset "collate" begin
# NeuralNetworkRegressor:
Xmatrix = broadcast(x->round(x, sigdigits=2), rand(stable_rng, 10, 3))
Xmatrix = broadcast(x->round(x, sigdigits=2), rand(stable_rng, Float32, 10, 3))
Xmat_f64 = Float64.(Xmatrix)
# convert to a column table:
X = MLJBase.table(Xmatrix)
X_64 = MLJBase.table(Xmat_f64)

# NeuralNetworkRegressor:
y = rand(stable_rng, Float32, 10)
model = MLJFlux.NeuralNetworkRegressor()
model.batch_size= 3
@test MLJFlux.collate(model, X, y) ==
@test MLJFlux.collate(model, X, y, 1) == MLJFlux.collate(model, X_64, y, 1) ==
([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
rowvec.([y[1:3], y[4:6], y[7:9], y[10:10]]))
@test_logs (:info,) MLJFlux.collate(model, X_64, y, 1)
@test_logs min_level=Logging.Info MLJFlux.collate(model, X, y, 1)
@test_logs min_level=Logging.Info MLJFlux.collate(model, X, y, 0)

# NeuralNetworkClassifier:
y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a'])
model = MLJFlux.NeuralNetworkClassifier()
model.batch_size = 3
data = MLJFlux.collate(model, X, y)
data = MLJFlux.collate(model, X, y, 1)

@test data == ([Xmatrix'[:,1:3], Xmatrix'[:,4:6],
Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
@@ -42,13 +47,13 @@ end
y = MLJBase.table(ymatrix) # a rowaccess table
model = MLJFlux.NeuralNetworkRegressor()
model.batch_size= 3
@test MLJFlux.collate(model, X, y) ==
@test MLJFlux.collate(model, X, y, 1) ==
([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
rowvec.([ymatrix'[:,1:3], ymatrix'[:,4:6], ymatrix'[:,7:9],
ymatrix'[:,10:10]]))

y = Tables.columntable(y) # try a columnaccess table
@test MLJFlux.collate(model, X, y) ==
@test MLJFlux.collate(model, X, y, 1) ==
([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
rowvec.([ymatrix'[:,1:3], ymatrix'[:,4:6],
ymatrix'[:,7:9], ymatrix'[:,10:10]]))
@@ -58,7 +63,7 @@ end
y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a'])
model = MLJFlux.ImageClassifier(batch_size=2)

data = MLJFlux.collate(model, Xmatrix, y)
data = MLJFlux.collate(model, Xmatrix, y, 1)
@test first.(data) == (Float32.(cat(Xmatrix[1], Xmatrix[2], dims=4)),
rowvec.([1 0;0 1]))

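For readers unfamiliar with the logging assertions used above: with a pattern, `@test_logs` requires that exactly the given messages are emitted; with only `min_level`, it requires that nothing at or above that level is emitted. A small standalone illustration:

    using Test, Logging

    @test_logs (:info,) (@info "converting")     # passes: exactly one Info message
    @test_logs min_level=Logging.Info sum(1:3)   # passes: no Info-or-above messages emitted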
2 changes: 1 addition & 1 deletion test/mlj_model_interface.jl
@@ -94,7 +94,7 @@ end
# (2) manually train for one epoch explicitly adding a loss penalty:
chain = MLJFlux.build(builder, StableRNG(123), 3, 1);
penalty = Penalizer(lambda, alpha); # defined in test_utils.jl
X, y = MLJFlux.collate(model, Xuser, yuser);
X, y = MLJFlux.collate(model, Xuser, yuser, 0);
loss = model.loss;
n_batches = div(nobservations, batch_size)
optimiser_state = Optimisers.setup(optimiser, chain);
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -14,6 +14,7 @@ using StableRNGs
using CUDA, cuDNN
import StatisticalMeasures
import Optimisers
import Logging

using ComputationalResources
using ComputationalResources: CPU1, CUDALibs
