Skip to content

Commit

Permalink
add tests and fix existing ones
Browse files Browse the repository at this point in the history
  • Loading branch information
tiemvanderdeure committed Sep 9, 2024
1 parent 56838bf commit 5ae2fe5
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ julia = "1.9"
[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -42,4 +43,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
test = ["CUDA", "cuDNN", "LinearAlgebra", "Logging", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
19 changes: 12 additions & 7 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,28 @@ rowvec(y::Vector) = reshape(y, 1, length(y))
end

@testset "collate" begin
# NeuralNetworRegressor:
Xmatrix = broadcast(x->round(x, sigdigits=2), rand(stable_rng, 10, 3))
Xmatrix = broadcast(x->round(x, sigdigits=2), rand(stable_rng, Float32, 10, 3))
Xmat_f64 = Float64.(Xmatrix)
# convert to a column table:
X = MLJBase.table(Xmatrix)
X_64 = MLJBase.table(Xmat_f64)

# NeuralNetworRegressor:
y = rand(stable_rng, Float32, 10)
model = MLJFlux.NeuralNetworkRegressor()
model.batch_size= 3
@test MLJFlux.collate(model, X, y) ==
@test MLJFlux.collate(model, X, y, 1) == MLJFlux.collate(model, X_64, y, 1) ==
([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
rowvec.([y[1:3], y[4:6], y[7:9], y[10:10]]))
@test_logs (:info,) MLJFlux.collate(model, X_64, y, 1)
@test_logs min_level=Logging.Info MLJFlux.collate(model, X, y, 1)
@test_logs min_level=Logging.Info MLJFlux.collate(model, X, y, 0)

# NeuralNetworClassifier:
y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a'])
model = MLJFlux.NeuralNetworkClassifier()
model.batch_size = 3
data = MLJFlux.collate(model, X, y)
data = MLJFlux.collate(model, X, y, 1)

@test data == ([Xmatrix'[:,1:3], Xmatrix'[:,4:6],
Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
Expand All @@ -42,13 +47,13 @@ end
y = MLJBase.table(ymatrix) # a rowaccess table
model = MLJFlux.NeuralNetworkRegressor()
model.batch_size= 3
@test MLJFlux.collate(model, X, y) ==
@test MLJFlux.collate(model, X, y, 1) ==
([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
rowvec.([ymatrix'[:,1:3], ymatrix'[:,4:6], ymatrix'[:,7:9],
ymatrix'[:,10:10]]))

y = Tables.columntable(y) # try a columnaccess table
@test MLJFlux.collate(model, X, y) ==
@test MLJFlux.collate(model, X, y, 1) ==
([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
rowvec.([ymatrix'[:,1:3], ymatrix'[:,4:6],
ymatrix'[:,7:9], ymatrix'[:,10:10]]))
Expand All @@ -58,7 +63,7 @@ end
y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a'])
model = MLJFlux.ImageClassifier(batch_size=2)

data = MLJFlux.collate(model, Xmatrix, y)
data = MLJFlux.collate(model, Xmatrix, y, 1)
@test first.(data) == (Float32.(cat(Xmatrix[1], Xmatrix[2], dims=4)),
rowvec.([1 0;0 1]))

Expand Down
2 changes: 1 addition & 1 deletion test/mlj_model_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ end
# (2) manually train for one epoch explicitly adding a loss penalty:
chain = MLJFlux.build(builder, StableRNG(123), 3, 1);
penalty = Penalizer(lambda, alpha); # defined in test_utils.jl
X, y = MLJFlux.collate(model, Xuser, yuser);
X, y = MLJFlux.collate(model, Xuser, yuser, 0);
loss = model.loss;
n_batches = div(nobservations, batch_size)
optimiser_state = Optimisers.setup(optimiser, chain);
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using StableRNGs
using CUDA, cuDNN
import StatisticalMeasures
import Optimisers
import Logging

using ComputationalResources
using ComputationalResources: CPU1, CUDALibs
Expand Down

0 comments on commit 5ae2fe5

Please sign in to comment.