Skip to content

Commit

Permalink
simplify: use MLP builder instead of custom Short2 in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed May 3, 2024
1 parent ffb20bd commit b3b41ac
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 19 deletions.
9 changes: 5 additions & 4 deletions test/classifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ y = map(ycont) do η
end
end |> categorical;

# In the tests below we want to check GPU and CPU give similar results. We use Short2
# instead of Short because Dropout in Short does not appear to behave the same on GPU as
# on a CPU, even when we use `default_rng()` for both.
builder = Short2()
# In the tests below we want to check GPU and CPU give similar results. We use the `MLP`
# builer instead of the default `Short()` because `Dropout()` in `Short()` does not appear
# to behave the same on GPU as on a CPU, even when we use `default_rng()` for both.

builder = MLJFlux.MLP(hidden=(8,))
optimiser = Optimisers.Adam(0.03)

losses = []
Expand Down
15 changes: 0 additions & 15 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,6 @@ MLJFlux.gpu_isdead() && push!(EXCLUDED_RESOURCE_TYPES, CUDALibs)
"these types, as unavailable:\n$EXCLUDED_RESOURCE_TYPES\n"*
"Excluded tests marked as \"broken\"."

# Alternative version of `Short` builder with no dropout:
mutable struct Short2 <: MLJFlux.Builder
n_hidden::Int # if zero use geometric mean of input/output
σ
end
Short2(; n_hidden=0, σ=Flux.sigmoid) = Short2(n_hidden, σ)
function MLJFlux.build(builder::Short2, rng, n, m)
n_hidden =
builder.n_hidden == 0 ? round(Int, sqrt(n*m)) : builder.n_hidden
init = Flux.glorot_uniform(rng)
return Flux.Chain(
Flux.Dense(n, n_hidden, builder.σ, init=init),
Flux.Dense(n_hidden, m, init=init))
end

seed!(123)

include("test_utils.jl")
Expand Down

0 comments on commit b3b41ac

Please sign in to comment.