diff --git a/perf/recurrent.jl b/perf/recurrent.jl
index ef00a8d9a5..1550009bd3 100644
--- a/perf/recurrent.jl
+++ b/perf/recurrent.jl
@@ -7,12 +7,10 @@ Flux.@functor RNNWrapper
 
 # Need to specialize for RNNWrapper.
 fw(r::RNNWrapper, X::Vector{<:AbstractArray}) = begin
-  Flux.reset!(r.rnn)
   [r.rnn(x) for x in X]
 end
 
 fw(r::RNNWrapper, X) = begin
-  Flux.reset!(r.rnn)
   r.rnn(X)
 end
 
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 6148894dbe..8a9b67501d 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -167,3 +167,8 @@ end
 # where `loss_mxy` accepts the model as its first argument.
 # """
 # ))
+
+function reset!(x)
+  Base.depwarn("reset!(m) is deprecated. You can remove this call as it is no longer needed.", :reset!)
+  return x
+end
diff --git a/test/ext_amdgpu/runtests.jl b/test/ext_amdgpu/runtests.jl
index 9dfbb41577..42972f285e 100644
--- a/test/ext_amdgpu/runtests.jl
+++ b/test/ext_amdgpu/runtests.jl
@@ -11,6 +11,6 @@ end
 end
 
 @testset "Recurrent" begin
-    BROKEN_TESTS = []
+    global BROKEN_TESTS = []
     include("../ext_common/recurrent_gpu_ad.jl")
 end
diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl
index b4eee41315..be02409077 100644
--- a/test/ext_cuda/runtests.jl
+++ b/test/ext_cuda/runtests.jl
@@ -23,7 +23,7 @@ end
   include("cudnn.jl")
 end
 @testset "Recurrent" begin
-  BROKEN_TESTS = []
+  global BROKEN_TESTS = []
   include("../ext_common/recurrent_gpu_ad.jl")
 end
 @testset "ctc" begin
diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl
index bae14fd246..aa04150cf0 100644
--- a/test/ext_enzyme/enzyme.jl
+++ b/test/ext_enzyme/enzyme.jl
@@ -93,7 +93,6 @@ end
 
 @testset "Models" begin
     function loss(model, x)
-        Flux.reset!(model)
         sum(model(x))
     end
 
@@ -126,7 +125,6 @@ end
 
 @testset "Recurrence Tests" begin
     function loss(model, x)
-        Flux.reset!(model)
         for i in 1:3
             x = model(x)
         end
diff --git a/test/ext_metal/runtests.jl b/test/ext_metal/runtests.jl
index 5bb34caeb7..86e1068cf3 100644
--- a/test/ext_metal/runtests.jl
+++ b/test/ext_metal/runtests.jl
@@ -33,7 +33,7 @@ end
 end
 
 @testset "Recurrent" begin
-    BROKEN_TESTS = [:lstm, :gru, :gruv3]
+    global BROKEN_TESTS = [:lstm, :gru, :gruv3]
     include("../ext_common/recurrent_gpu_ad.jl")
 end
 
diff --git a/test/utils.jl b/test/utils.jl
index 79eebded49..0236a3d636 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -251,19 +251,19 @@ end
 end
 
 @testset "Params" begin
-  m = Dense(10, 5)
+  m = Dense(10 => 5)
   @test size.(params(m)) == [(5, 10), (5,)]
 
-  m = RNN(10, 5)
-  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)]
+  m = RNN(10 => 5)
+  @test size.(params(m)) == [(5, 10), (5, 5), (5,)]
 
   # Layer duplicated in same chain, params just once pls.
   c = Chain(m, m)
-  @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)]
+  @test size.(params(c)) == [(5, 10), (5, 5), (5,)]
 
   # Self-referential array. Just want params, no stack overflow pls.
   r = Any[nothing,m]
   r[1] = r
-  @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)]
+  @test size.(params(r)) == [(5, 10), (5, 5), (5,)]
 
   # Ensure functor explores inside Transpose but not SubArray
   m = (x = view([1,2,3]pi, 1:2), y = transpose([4 5]pi))
@@ -273,7 +273,7 @@ end
 @testset "params gradient" begin
   m = (x=[1,2.0], y=[3.0]);
 
-  # Explicit -- was broken by #2054 / then fixed / now broken again on julia v1.11
+  # Explicit -- was broken by #2054
   gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1]
   @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159]
   @test gnew.y ≈ [1.0]
@@ -286,7 +286,7 @@ end
 
 @testset "Precision" begin
-  m = Chain(Dense(10, 5, relu; bias=false), Dense(5, 2))
+  m = Chain(Dense(10 => 5, relu; bias=false), Dense(5 => 2))
   x64 = rand(Float64, 10)
   x32 = rand(Float32, 10)
   i64 = rand(Int64, 10)
@@ -467,10 +467,10 @@ end
   @test modules[5] === m2
   @test modules[6] === m3
 
-  mod_par = Flux.modules(Parallel(Flux.Bilinear(2,2,2,cbrt), Dense(2,2,abs), Dense(2,2,abs2)))
+  mod_par = Flux.modules(Parallel(Flux.Bilinear(2,2,2,cbrt), Dense(2=>2,abs), Dense(2=>2,abs2)))
   @test length(mod_par) == 5
 
-  mod_rnn = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4)))
+  mod_rnn = Flux.modules(Chain(Dense(2=>3), BatchNorm(3), LSTM(3=>4)))
   @test length(mod_rnn) == 6
   @test mod_rnn[end] isa Flux.LSTMCell
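
Migration note (not part of the patch): a minimal sketch of what this change means for downstream code, assuming the new recurrent-layer semantics in which layers no longer carry mutable hidden state between calls; the model and sizes below are illustrative.

    using Flux

    m = RNN(10 => 5)          # Pair constructor, as in the updated tests

    # Before this patch, each sequence had to begin with Flux.reset!(m).
    # The call still works, but it now only emits a deprecation warning
    # and returns `m` unchanged, so it can simply be deleted.
    Flux.reset!(m)

    # The trainable initial state is gone as well, which is why the
    # expected parameter shapes in test/utils.jl lose the (5, 1) entry:
    size.(Flux.params(m))     # [(5, 10), (5, 5), (5,)]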
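
On the `global BROKEN_TESTS` annotations: the reasoning below is our reading of the change, not stated in the patch. `@testset` wraps its body in a local scope, while `include` evaluates the included file at the enclosing module's global scope, so without `global` the assignment would be invisible to `recurrent_gpu_ad.jl`. A condensed sketch:

    using Test

    @testset "Recurrent" begin
        # `global` makes the binding module-level; a plain assignment here
        # would be local to the testset, and the included file, which runs
        # at module scope, could not see it.
        global BROKEN_TESTS = [:lstm, :gru, :gruv3]
        include("../ext_common/recurrent_gpu_ad.jl")  # reads BROKEN_TESTS
    end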