From 9c2f285005909a65f470ccd9b2d9a7a8557c7fdf Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sun, 20 Oct 2024 10:03:57 +0200 Subject: [PATCH] cleanup --- test/runtests.jl | 4 ++-- test/train.jl | 21 +++++++++------------ test/utils.jl | 12 +++++------- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index f44c4b7758..ff6660be14 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,9 +11,9 @@ using Functors: fmapstructure_with_path ## Uncomment below to change the default test settings # ENV["FLUX_TEST_AMDGPU"] = "true" -ENV["FLUX_TEST_CUDA"] = "true" +# ENV["FLUX_TEST_CUDA"] = "true" # ENV["FLUX_TEST_METAL"] = "true" -ENV["FLUX_TEST_CPU"] = "false" +# ENV["FLUX_TEST_CPU"] = "false" # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true" # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true" ENV["FLUX_TEST_ENZYME"] = "false" # We temporarily disable Enzyme tests since they are failing diff --git a/test/train.jl b/test/train.jl index 38338c19b9..5a1fd0592e 100644 --- a/test/train.jl +++ b/test/train.jl @@ -155,18 +155,15 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme")) pen2(x::AbstractArray) = sum(abs2, x)/2 opt = Flux.setup(Adam(0.1), model) - @test begin - trainfn!(model, data, opt) do m, x, y - err = Flux.mse(m(x), y) - l2 = sum(pen2, Flux.params(m)) - err + 0.33 * l2 - end - - diff2 = model.weight .- init_weight - @test diff1 ≈ diff2 - - true - end broken = VERSION >= v"1.11" + trainfn!(model, data, opt) do m, x, y + err = Flux.mse(m(x), y) + l2 = sum(pen2, Flux.params(m)) + err + 0.33 * l2 + end + + diff2 = model.weight .- init_weight + @test diff1 ≈ diff2 + end # Take 3: using WeightDecay instead. Need the /2 above, to match exactly. diff --git a/test/utils.jl b/test/utils.jl index b526b63286..79eebded49 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -273,13 +273,11 @@ end @testset "params gradient" begin m = (x=[1,2.0], y=[3.0]); - @test begin - # Explicit -- was broken by #2054 / then fixed / now broken again on julia v1.11 - gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1] - @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159] - @test gnew.y ≈ [1.0] - true - end broken = VERSION >= v"1.11" + # Explicit -- was broken by #2054 / then fixed / now broken again on julia v1.11 + gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1] + @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159] + @test gnew.y ≈ [1.0] + # Implicit gold = gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m))