From d18fe13b4cc479ee166556eff6685bed5e78237f Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 18 Oct 2024 04:55:02 +0200 Subject: [PATCH] fix tests --- test/utils.jl | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index affb54c085..9221297fd4 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -251,9 +251,9 @@ end end @testset "Params" begin - m = Dense(10, 5) + m = Dense(10 => 5) @test size.(params(m)) == [(5, 10), (5,)] - m = RNN(10, 5) + m = RNN(10 => 5) @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)] # Layer duplicated in same chain, params just once pls. @@ -273,13 +273,11 @@ end @testset "params gradient" begin m = (x=[1,2.0], y=[3.0]); - @test begin - # Explicit -- was broken by #2054 / then fixed / now broken again on julia v1.11 - gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1] - @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159] - @test gnew.y ≈ [1.0] - true - end broken = VERSION >= v"1.11" + # Explicit -- was broken by #2054 + gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1] + @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159] + @test gnew.y ≈ [1.0] + # Implicit gold = gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m)) @@ -288,7 +286,7 @@ end end @testset "Precision" begin - m = Chain(Dense(10, 5, relu; bias=false), Dense(5, 2)) + m = Chain(Dense(10 => 5, relu; bias=false), Dense(5 => 2)) x64 = rand(Float64, 10) x32 = rand(Float32, 10) i64 = rand(Int64, 10)