From 2c3e25783bb8fcdfe9c277f7f70228383802cf02 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sun, 13 Oct 2024 23:51:48 +0200 Subject: [PATCH] remove params entirely --- src/layers/show.jl | 10 +++---- src/outputsize.jl | 2 -- test/data.jl | 6 ++-- test/ext_cuda/curnn.jl | 14 +--------- test/ext_cuda/layers.jl | 62 +++++++++++++---------------------------- test/layers/basic.jl | 11 ++++---- test/runtests.jl | 5 ++-- test/utils.jl | 28 ++++--------------- 8 files changed, 43 insertions(+), 95 deletions(-) diff --git a/src/layers/show.jl b/src/layers/show.jl index a03ddf3754..95e7d8746b 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -90,21 +90,21 @@ function _layer_show(io::IO, layer, indent::Int=0, name=nothing) _str = isnothing(name) ? "" : "$name = " str = _str * sprint(show, layer, context=io) print(io, " "^indent, str, indent==0 ? "" : ",") - if !isempty(params(layer)) + if !isempty(trainables(layer)) print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) - printstyled(io, "# ", underscorise(sum(length, params(layer); init=0)), " parameters"; + printstyled(io, "# ", underscorise(sum(length, trainables(layer); init=0)), " parameters"; color=:light_black) - nonparam = _childarray_sum(length, layer) - sum(length, params(layer), init=0) + nonparam = _childarray_sum(length, layer) - sum(length, trainables(layer), init=0) if nonparam > 0 printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black) end - _nan_show(io, params(layer)) + _nan_show(io, trainables(layer)) end indent==0 || println(io) end function _big_finale(io::IO, m) - ps = params(m) + ps = trainables(m) if length(ps) > 2 pars = underscorise(sum(length, ps; init=0)) bytes = Base.format_bytes(Base.summarysize(m)) diff --git a/src/outputsize.jl b/src/outputsize.jl index 5d6132d059..c413405048 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -302,8 +302,6 @@ function ChainRulesCore.rrule(::typeof(striplazy), m) striplazy(m), _ -> error("striplazy should never be used within a gradient") end -params!(p::Params, x::LazyLayer, seen = IdSet()) = error("LazyLayer should never be used within params(m). 
Call striplazy(m) first.") - Functors.functor(::Type{<:LazyLayer}, x) = error("LazyLayer should not be walked with Functors.jl, as the arrays which Flux.gpu wants to move may not exist yet.") function Base.show(io::IO, l::LazyLayer) diff --git a/test/data.jl b/test/data.jl index b97c4dae80..8fd2507f3e 100644 --- a/test/data.jl +++ b/test/data.jl @@ -82,7 +82,8 @@ using Random X = zeros(2, 10) loss(x) = sum((x .- θ).^2) d = DataLoader(X) - Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1)) + opt_state = Flux.setup(Descent(0.1), θ) + Flux.train!(loss, θ, ncycle(d, 10), opt_state) @test norm(θ) < 1e-4 # test interaction with `train!` @@ -91,7 +92,8 @@ using Random Y = fill(2, 10) loss(x, y) = sum((y - x'*θ).^2) d = DataLoader((X, Y)) - Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1)) + opt_state = Flux.setup(Descent(0.1), θ) + Flux.train!(loss, θ, ncycle(d, 10), opt_state) @test norm(θ .- 1) < 1e-10 # specify the rng diff --git a/test/ext_cuda/curnn.jl b/test/ext_cuda/curnn.jl index 5c460d2aa4..385514e5ad 100644 --- a/test/ext_cuda/curnn.jl +++ b/test/ext_cuda/curnn.jl @@ -1,20 +1,8 @@ -using Flux, CUDA, Test - -@testset for R in [RNN, GRU, LSTM, GRUv3] - m = R(10, 5) |> gpu - x = gpu(rand(10)) - (m̄,) = gradient(m -> sum(m(x)), m) - Flux.reset!(m) - θ = gradient(() -> sum(m(x)), params(m)) - @test x isa CuArray - @test θ[m.cell.Wi] isa CuArray - @test collect(m̄.cell.Wi) == collect(θ[m.cell.Wi]) -end @testset "RNN" begin @testset for R in [RNN, GRU, LSTM, GRUv3], batch_size in (1, 5) rnn = R(10, 5) - curnn = fmap(gpu, rnn) + curnn = rnn |> gpu Flux.reset!(rnn) Flux.reset!(curnn) diff --git a/test/ext_cuda/layers.jl b/test/ext_cuda/layers.jl index cba95cee75..d8f33b3806 100644 --- a/test/ext_cuda/layers.jl +++ b/test/ext_cuda/layers.jl @@ -110,8 +110,8 @@ end l = cl((2,2), 1=>3, bias = false) |> gpu ip = zeros(Float32, 28,28,1,1) |> gpu @test sum(l(ip)) ≈ 0.f0 - gs = gradient(() -> sum(l(ip)), Flux.params(l)) - @test l.bias ∉ gs.params + gs = gradient(l -> sum(l(ip)), l)[1] + @test gs.bias === nothing end @testset "Dense without bias" begin @@ -119,8 +119,8 @@ end ip = zeros(Float32, 3, 7) |> gpu @test sum(l(ip)) ≈ 0.f0 - gs = gradient(() -> sum(l(ip)), Flux.params(l)) - @test l.bias ∉ gs.params + gs = gradient(l -> sum(l(ip)), l)[1] + @test gs.bias === nothing end @testset "Extended BatchNorm" begin @@ -133,13 +133,13 @@ end μ_cpu = copy(m_cpu.μ) m_cpu(x_cpu) @test m_cpu.μ ≈ μ_cpu - gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) + gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu) @test !(m_cpu.μ ≈ μ_cpu) μ_gpu = copy(m_gpu.μ) m_gpu(x_gpu) @test m_gpu.μ ≈ μ_gpu - gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) + gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu) @test !(m_gpu.μ ≈ μ_gpu) @test Array(m_gpu.μ) ≈ m_cpu.μ @@ -149,14 +149,14 @@ end μ_cpu = copy(m_cpu.μ) m_cpu(x_cpu) @test m_cpu.μ ≈ μ_cpu - gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) + gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu) @test m_cpu.μ ≈ μ_cpu testmode!(m_gpu) μ_gpu = copy(m_gpu.μ) m_gpu(x_gpu) @test m_gpu.μ ≈ μ_gpu - gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) + gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu) @test m_gpu.μ ≈ μ_gpu ## In trainmode, always track statistics @@ -165,7 +165,7 @@ end m_cpu(x_cpu) @test !(m_cpu.μ ≈ μ_cpu) μ_cpu = copy(m_cpu.μ) - gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) + gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu) @test !(m_cpu.μ ≈ μ_cpu) trainmode!(m_gpu) @@ -173,44 +173,28 @@ end m_gpu(x_gpu) @test !(m_gpu.μ ≈ μ_gpu) μ_gpu = copy(m_gpu.μ) - 
gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) + gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu) @test !(m_gpu.μ ≈ μ_gpu) - - ## No errors if input type mistmatch - # x_cpu = rand(Float64, 3, 2, 2) - # x_gpu = x_cpu |> gpu - # m_cpu(x_cpu) - # gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) - # m_gpu(x_gpu) - # gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) end @testset "Two-streams Bilinear" begin x = zeros(Float32,10,9) |> gpu y = zeros(Float32,2,9) |> gpu b = Flux.Bilinear(10, 2, 3) |> gpu - @test size(b(x,y)) == (3,9) - @test sum(abs2, b(x,y)) ≈ 0f0 - gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b)) - b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu - gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu)) - for (pgpu, pcpu) in zip(params(b), params(b_cpu)) - @test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu]) - end + @test size(b(x, y)) == (3,9) + @test sum(abs2, b(x, y)) ≈ 0f0 + test_gradients(b |> cpu, x |> cpu, y |> cpu, + test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o)) end @testset "Two-streams Bilinear" begin x = zeros(Float32,10,9) |> gpu y = zeros(Float32,2,9) |> gpu b = Flux.Bilinear(10, 2, 3) |> gpu - @test size(b(x,y)) == (3,9) - @test sum(abs2, b(x,y)) ≈ 0f0 - gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b)) - b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu - gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu)) - for (pgpu, pcpu) in zip(params(b), params(b_cpu)) - @test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu]) - end + @test size(b(x, y)) == (3,9) + @test sum(abs2, b(x, y)) ≈ 0f0 + test_gradients(b |> cpu, x |> cpu, y |> cpu, + test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o)) end @testset "Parallel" begin @@ -228,15 +212,9 @@ end end @testset "gradient" begin - input_cpu = randn(10, 10, 10, 10) - input_gpu = input_cpu |> gpu layer_cpu = Parallel(+, x -> zero(x), identity) - layer_gpu = layer_cpu |> gpu - gs_cpu = gradient(() -> sum(abs2.(layer_cpu(input_cpu))), params(layer_cpu)) - gs_gpu = gradient(() -> sum(abs2.(layer_gpu(input_gpu))), params(layer_gpu)) - for (pgpu, pcpu) in zip(params(layer_cpu), params(layer_gpu)) - @test gs_cpu[pcpu] ≈ gs_gpu[pgpu] - end + test_gradients(layer_cpu, randn(5, 5, 5, 5), + test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o)) end end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index e4f8b23ea9..1ef9e71f51 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -196,7 +196,7 @@ using Flux: activations x = randn(Float32,11,7) b = Flux.Bilinear(11, 11, 3) @test size(b(x)) == (3,7) - @test_nowarn gs = gradient(() -> sum(abs2.(b(x))), params(b)) + test_gradients(b, x) end @testset "constructors" begin @@ -436,16 +436,15 @@ end @testset "gradients of Chain{Vector}" begin m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2)) m1v = Chain([m1[1], m1[2]]) - @test sum(length, params(m1)) == sum(length, params(m1v)) + @test sum(length, trainables(m1)) == sum(length, trainables(m1v)) x1 = randn(Float32,3,5) @test m1(x1) ≈ m1v(x1) y1 = rand(Bool,2,5) - g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1)) - g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v)) - @test g1[m1[1].weight] ≈ g1v[m1v[1].weight] - @test g1[m1[2].bias] ≈ g1v[m1v[2].bias] + g1 = gradient(m1 -> Flux.logitcrossentropy(m1(x1), y1), m1)[1] + g1v = gradient(m1v -> Flux.logitcrossentropy(m1v(x1), y1), m1v)[1] + check_equal_leaves(g1, g1v) @test Flux.destructure(m1)[1] ≈ Flux.destructure(m1v)[1] z1 = rand(22); diff --git 
a/test/runtests.jl b/test/runtests.jl index 01e56ee1e4..6f5a2e7d84 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,5 @@ using Flux using Flux: OneHotArray, OneHotMatrix, OneHotVector -using Flux: params using Test using Random, Statistics, LinearAlgebra using IterTools: ncycle @@ -11,9 +10,9 @@ using Functors: fmapstructure_with_path ## Uncomment below to change the default test settings # ENV["FLUX_TEST_AMDGPU"] = "true" -ENV["FLUX_TEST_CUDA"] = "true" +# ENV["FLUX_TEST_CUDA"] = "true" # ENV["FLUX_TEST_METAL"] = "true" -ENV["FLUX_TEST_CPU"] = "false" +# ENV["FLUX_TEST_CPU"] = "false" # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true" # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true" ENV["FLUX_TEST_ENZYME"] = "false" # We temporarily disable Enzyme tests since they are failing diff --git a/test/utils.jl b/test/utils.jl index 8cdbac2daf..8c92728865 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -2,7 +2,7 @@ using Flux using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, orthogonal, truncated_normal, sparse_init, identity_init, unstack, batch, unbatch, - unsqueeze, params, loadmodel! + unsqueeze, loadmodel! using MLUtils using Statistics, LinearAlgebra using Random @@ -334,28 +334,12 @@ end o = ones(s) z = zeros(s) - @testset "Explicit" begin - gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...) - g = gfun(o, z) - @test gfun(o, false) == (g[1], nothing) + gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...) + g = gfun(o, z) + @test gfun(o, false) == (g[1], nothing) - g = gfun(z, o) - @test gfun(false, o) == (nothing, g[2]) - end - - @testset "Implicit" begin - gfun(args...) = gradient(() -> sum(op.(args...)), params(collect(args))) - g = gfun(o, z) - - gres = gfun(o, false) - @test gres[o] == g[o] - @test false ∉ gres.params - - g = gfun(z, o) - gres = gfun(false, o) - @test gres[o] == g[o] - @test false ∉ gres.params - end + g = gfun(z, o) + @test gfun(false, o) == (nothing, g[2]) end end
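
For reference, a minimal sketch of the explicit-gradient style that these tests migrate to, assuming a recent Flux release where `trainables`, `Flux.setup`, and the model-first `Flux.train!` are available; the model, data, and loss below are made up for illustration and are not part of the patch.

using Flux
using Flux: trainables

model = Chain(Dense(3 => 4, tanh), Dense(4 => 2))
x, y = randn(Float32, 3, 8), randn(Float32, 2, 8)

# Differentiate with respect to the model itself, replacing the old
# implicit `gradient(() -> loss(...), params(model))` pattern.
grads = gradient(m -> Flux.mse(m(x), y), model)[1]

# Parameter counting and inspection go through `trainables` instead of `params`.
nparams = sum(length, trainables(model); init = 0)

# Training pairs an optimiser state from `setup` with the model-first `train!`,
# as in the updated test/data.jl tests above.
opt_state = Flux.setup(Descent(0.1), model)
Flux.train!((m, xb, yb) -> Flux.mse(m(xb), yb), model, [(x, y)], opt_state)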