From b3b41acc9788eea309a8014c78c1b6a33274317b Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 3 May 2024 15:34:54 +1200 Subject: [PATCH] simplify: use MLP builder instead of custom Short2 in tests --- test/classifier.jl | 9 +++++---- test/runtests.jl | 15 --------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/test/classifier.jl b/test/classifier.jl index 972c6656..81ca2023 100644 --- a/test/classifier.jl +++ b/test/classifier.jl @@ -16,10 +16,11 @@ y = map(ycont) do η end end |> categorical; -# In the tests below we want to check GPU and CPU give similar results. We use Short2 -# instead of Short because Dropout in Short does not appear to behave the same on GPU as -# on a CPU, even when we use `default_rng()` for both. -builder = Short2() +# In the tests below we want to check GPU and CPU give similar results. We use the `MLP` +# builer instead of the default `Short()` because `Dropout()` in `Short()` does not appear +# to behave the same on GPU as on a CPU, even when we use `default_rng()` for both. + +builder = MLJFlux.MLP(hidden=(8,)) optimiser = Optimisers.Adam(0.03) losses = [] diff --git a/test/runtests.jl b/test/runtests.jl index 2e0d4fd1..b7b11d66 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,21 +27,6 @@ MLJFlux.gpu_isdead() && push!(EXCLUDED_RESOURCE_TYPES, CUDALibs) "these types, as unavailable:\n$EXCLUDED_RESOURCE_TYPES\n"* "Excluded tests marked as \"broken\"." -# Alternative version of `Short` builder with no dropout: -mutable struct Short2 <: MLJFlux.Builder - n_hidden::Int # if zero use geometric mean of input/output - σ -end -Short2(; n_hidden=0, σ=Flux.sigmoid) = Short2(n_hidden, σ) -function MLJFlux.build(builder::Short2, rng, n, m) - n_hidden = - builder.n_hidden == 0 ? round(Int, sqrt(n*m)) : builder.n_hidden - init = Flux.glorot_uniform(rng) - return Flux.Chain( - Flux.Dense(n, n_hidden, builder.σ, init=init), - Flux.Dense(n_hidden, m, init=init)) -end - seed!(123) include("test_utils.jl")