remove params entirely
CarloLucibello committed Oct 13, 2024
1 parent 17db916 commit 2c3e257
Showing 8 changed files with 43 additions and 95 deletions.
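For orientation, the commit replaces the implicit-parameter API (Flux.params and zero-argument gradient closures) with explicit gradients and trainables throughout the source and tests. A minimal sketch of the two styles, using an illustrative Dense model not taken from the diff:

    using Flux
    model = Dense(3 => 2)
    x = randn(Float32, 3, 5)

    # Old, implicit style removed by this commit:
    #   gs = gradient(() -> sum(model(x)), Flux.params(model)); gs[model.weight]

    # New, explicit style used throughout the diff:
    grad = gradient(m -> sum(m(x)), model)[1]   # NamedTuple mirroring the model's fields
    ps = Flux.trainables(model)                 # vector of trainable arrays, replacing params(model)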
10 changes: 5 additions & 5 deletions src/layers/show.jl
@@ -90,21 +90,21 @@ function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
_str = isnothing(name) ? "" : "$name = "
str = _str * sprint(show, layer, context=io)
print(io, " "^indent, str, indent==0 ? "" : ",")
if !isempty(params(layer))
if !isempty(trainables(layer))
print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
printstyled(io, "# ", underscorise(sum(length, params(layer); init=0)), " parameters";
printstyled(io, "# ", underscorise(sum(length, trainables(layer); init=0)), " parameters";
color=:light_black)
nonparam = _childarray_sum(length, layer) - sum(length, params(layer), init=0)
nonparam = _childarray_sum(length, layer) - sum(length, trainables(layer), init=0)
if nonparam > 0
printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black)
end
_nan_show(io, params(layer))
_nan_show(io, trainables(layer))
end
indent==0 || println(io)
end

function _big_finale(io::IO, m)
ps = params(m)
ps = trainables(m)
if length(ps) > 2
pars = underscorise(sum(length, ps; init=0))
bytes = Base.format_bytes(Base.summarysize(m))
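The parameter counts printed by _layer_show and _big_finale now come from trainables rather than params. Roughly the same count can be reproduced directly, as in this sketch (the model here is only illustrative):

    using Flux
    model = Chain(Dense(3 => 4, relu), Dense(4 => 2))
    n = sum(length, Flux.trainables(model); init=0)   # 16 + 10 = 26 trainable parameters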
2 changes: 0 additions & 2 deletions src/outputsize.jl
@@ -302,8 +302,6 @@ function ChainRulesCore.rrule(::typeof(striplazy), m)
striplazy(m), _ -> error("striplazy should never be used within a gradient")
end

params!(p::Params, x::LazyLayer, seen = IdSet()) = error("LazyLayer should never be used within params(m). Call striplazy(m) first.")

Functors.functor(::Type{<:LazyLayer}, x) = error("LazyLayer should not be walked with Functors.jl, as the arrays which Flux.gpu wants to move may not exist yet.")

function Base.show(io::IO, l::LazyLayer)
6 changes: 4 additions & 2 deletions test/data.jl
@@ -82,7 +82,8 @@ using Random
X = zeros(2, 10)
loss(x) = sum((x .- θ).^2)
d = DataLoader(X)
Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
opt_state = Flux.setup(Descent(0.1), θ)
Flux.train!(loss, θ, ncycle(d, 10), opt_state)
@test norm(θ) < 1e-4

# test interaction with `train!`
@@ -91,7 +92,8 @@
Y = fill(2, 10)
loss(x, y) = sum((y - x'*θ).^2)
d = DataLoader((X, Y))
Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
opt_state = Flux.setup(Descent(0.1), θ)
Flux.train!(loss, θ, ncycle(d, 10), opt_state)
@test norm(θ .- 1) < 1e-10

# specify the rng
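The updated DataLoader tests follow the explicit training API: build the optimiser state with Flux.setup and pass the model and that state to train!. A minimal sketch of the pattern with a hypothetical model and loss (in the explicit API, train! calls the loss as loss(model, batch...)):

    using Flux
    model = Dense(1 => 1)
    data = [(randn(Float32, 1, 16), randn(Float32, 1, 16)) for _ in 1:10]
    loss(m, x, y) = Flux.mse(m(x), y)
    opt_state = Flux.setup(Descent(0.1), model)
    Flux.train!(loss, model, data, opt_state)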
14 changes: 1 addition & 13 deletions test/ext_cuda/curnn.jl
@@ -1,20 +1,8 @@
using Flux, CUDA, Test

@testset for R in [RNN, GRU, LSTM, GRUv3]
m = R(10, 5) |> gpu
x = gpu(rand(10))
(m̄,) = gradient(m -> sum(m(x)), m)
Flux.reset!(m)
θ = gradient(() -> sum(m(x)), params(m))
@test x isa CuArray
@test θ[m.cell.Wi] isa CuArray
@test collect(m̄.cell.Wi) == collect(θ[m.cell.Wi])
end

@testset "RNN" begin
@testset for R in [RNN, GRU, LSTM, GRUv3], batch_size in (1, 5)
rnn = R(10, 5)
curnn = fmap(gpu, rnn)
curnn = rnn |> gpu

Flux.reset!(rnn)
Flux.reset!(curnn)
62 changes: 20 additions & 42 deletions test/ext_cuda/layers.jl
@@ -110,17 +110,17 @@ end
l = cl((2,2), 1=>3, bias = false) |> gpu
ip = zeros(Float32, 28,28,1,1) |> gpu
@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.bias ∉ gs.params
gs = gradient(l -> sum(l(ip)), l)[1]
@test gs.bias === nothing
end

@testset "Dense without bias" begin
l = Dense(ones(Float32, 4, 3), false) |> gpu
ip = zeros(Float32, 3, 7) |> gpu

@test sum(l(ip)) ≈ 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.bias ∉ gs.params
gs = gradient(l -> sum(l(ip)), l)[1]
@test gs.bias === nothing
end

@testset "Extended BatchNorm" begin
@@ -133,13 +133,13 @@ end
μ_cpu = copy(m_cpu.μ)
m_cpu(x_cpu)
@test m_cpu.μ ≈ μ_cpu
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu)
@test !(m_cpu.μ ≈ μ_cpu)

μ_gpu = copy(m_gpu.μ)
m_gpu(x_gpu)
@test m_gpu.μ ≈ μ_gpu
gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu)
@test !(m_gpu.μ ≈ μ_gpu)

@test Array(m_gpu.μ) ≈ m_cpu.μ
@@ -149,14 +149,14 @@ end
μ_cpu = copy(m_cpu.μ)
m_cpu(x_cpu)
@test m_cpu.μ ≈ μ_cpu
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu)
@test m_cpu.μ ≈ μ_cpu

testmode!(m_gpu)
μ_gpu = copy(m_gpu.μ)
m_gpu(x_gpu)
@test m_gpu.μ ≈ μ_gpu
gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu)
@test m_gpu.μ ≈ μ_gpu

## In trainmode, always track statistics
@@ -165,52 +165,36 @@ end
m_cpu(x_cpu)
@test !(m_cpu.μ ≈ μ_cpu)
μ_cpu = copy(m_cpu.μ)
gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
gradient(m_cpu -> sum(m_cpu(x_cpu)), m_cpu)
@test !(m_cpu.μ ≈ μ_cpu)

trainmode!(m_gpu)
μ_gpu = copy(m_gpu.μ)
m_gpu(x_gpu)
@test !(m_gpu.μ ≈ μ_gpu)
μ_gpu = copy(m_gpu.μ)
gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
gradient(m_gpu -> sum(m_gpu(x_gpu)), m_gpu)
@test !(m_gpu.μ ≈ μ_gpu)

## No errors if input type mismatch
# x_cpu = rand(Float64, 3, 2, 2)
# x_gpu = x_cpu |> gpu
# m_cpu(x_cpu)
# gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu))
# m_gpu(x_gpu)
# gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu))
end

@testset "Two-streams Bilinear" begin
x = zeros(Float32,10,9) |> gpu
y = zeros(Float32,2,9) |> gpu
b = Flux.Bilinear(10, 2, 3) |> gpu
@test size(b(x,y)) == (3,9)
@test sum(abs2, b(x,y)) ≈ 0f0
gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b))
b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
for (pgpu, pcpu) in zip(params(b), params(b_cpu))
@test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu])
end
@test size(b(x, y)) == (3,9)
@test sum(abs2, b(x, y)) ≈ 0f0
test_gradients(b |> cpu, x |> cpu, y |> cpu,
test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o))
end

@testset "Two-streams Bilinear" begin
x = zeros(Float32,10,9) |> gpu
y = zeros(Float32,2,9) |> gpu
b = Flux.Bilinear(10, 2, 3) |> gpu
@test size(b(x,y)) == (3,9)
@test sum(abs2, b(x,y)) ≈ 0f0
gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b))
b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
for (pgpu, pcpu) in zip(params(b), params(b_cpu))
@test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu])
end
@test size(b(x, y)) == (3,9)
@test sum(abs2, b(x, y)) ≈ 0f0
test_gradients(b |> cpu, x |> cpu, y |> cpu,
test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o))
end

@testset "Parallel" begin
@@ -228,15 +212,9 @@ end
end

@testset "gradient" begin
input_cpu = randn(10, 10, 10, 10)
input_gpu = input_cpu |> gpu
layer_cpu = Parallel(+, x -> zero(x), identity)
layer_gpu = layer_cpu |> gpu
gs_cpu = gradient(() -> sum(abs2.(layer_cpu(input_cpu))), params(layer_cpu))
gs_gpu = gradient(() -> sum(abs2.(layer_gpu(input_gpu))), params(layer_gpu))
for (pgpu, pcpu) in zip(params(layer_cpu), params(layer_gpu))
@test gs_cpu[pcpu] ≈ gs_gpu[pgpu]
end
test_gradients(layer_cpu, randn(5, 5, 5, 5),
test_gpu=true, compare_finite_diff=false, loss=o -> mean(abs2, o))
end
end

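The manual CPU-versus-GPU gradient comparisons above are replaced by the test_gradients helper used in these tests. A rough equivalent of what such a check does, written with the explicit gradient API (an illustrative sketch, not the helper's implementation, assuming a functional CUDA device; without one, gpu is a no-op and the comparison is trivial):

    using Flux, Statistics
    m_cpu = Dense(3 => 2)
    x_cpu = randn(Float32, 3, 5)
    m_gpu, x_gpu = m_cpu |> gpu, x_cpu |> gpu
    g_cpu = gradient(m -> mean(abs2, m(x_cpu)), m_cpu)[1]
    g_gpu = gradient(m -> mean(abs2, m(x_gpu)), m_gpu)[1]
    @assert g_cpu.weight ≈ Array(g_gpu.weight)   # gradients agree across devices
    @assert g_cpu.bias ≈ Array(g_gpu.bias)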
11 changes: 5 additions & 6 deletions test/layers/basic.jl
@@ -196,7 +196,7 @@ using Flux: activations
x = randn(Float32,11,7)
b = Flux.Bilinear(11, 11, 3)
@test size(b(x)) == (3,7)
@test_nowarn gs = gradient(() -> sum(abs2.(b(x))), params(b))
test_gradients(b, x)
end

@testset "constructors" begin
@@ -436,16 +436,15 @@
@testset "gradients of Chain{Vector}" begin
m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
m1v = Chain([m1[1], m1[2]])
@test sum(length, params(m1)) == sum(length, params(m1v))
@test sum(length, trainables(m1)) == sum(length, trainables(m1v))

x1 = randn(Float32,3,5)
@test m1(x1) ≈ m1v(x1)

y1 = rand(Bool,2,5)
g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1))
g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v))
@test g1[m1[1].weight] ≈ g1v[m1v[1].weight]
@test g1[m1[2].bias] ≈ g1v[m1v[2].bias]
g1 = gradient(m1 -> Flux.logitcrossentropy(m1(x1), y1), m1)[1]
g1v = gradient(m1v -> Flux.logitcrossentropy(m1v(x1), y1), m1v)[1]
check_equal_leaves(g1, g1v)

@test Flux.destructure(m1)[1] ≈ Flux.destructure(m1v)[1]
z1 = rand(22);
5 changes: 2 additions & 3 deletions test/runtests.jl
@@ -1,6 +1,5 @@
using Flux
using Flux: OneHotArray, OneHotMatrix, OneHotVector
using Flux: params
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
@@ -11,9 +10,9 @@ using Functors: fmapstructure_with_path

## Uncomment below to change the default test settings
# ENV["FLUX_TEST_AMDGPU"] = "true"
ENV["FLUX_TEST_CUDA"] = "true"
# ENV["FLUX_TEST_CUDA"] = "true"
# ENV["FLUX_TEST_METAL"] = "true"
ENV["FLUX_TEST_CPU"] = "false"
# ENV["FLUX_TEST_CPU"] = "false"
# ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
# ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
ENV["FLUX_TEST_ENZYME"] = "false" # We temporarily disable Enzyme tests since they are failing
28 changes: 6 additions & 22 deletions test/utils.jl
@@ -2,7 +2,7 @@ using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
sparse_init, identity_init, unstack, batch, unbatch,
unsqueeze, params, loadmodel!
unsqueeze, loadmodel!
using MLUtils
using Statistics, LinearAlgebra
using Random
@@ -334,28 +334,12 @@ end
o = ones(s)
z = zeros(s)

@testset "Explicit" begin
gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...)
g = gfun(o, z)
@test gfun(o, false) == (g[1], nothing)
gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...)
g = gfun(o, z)
@test gfun(o, false) == (g[1], nothing)

g = gfun(z, o)
@test gfun(false, o) == (nothing, g[2])
end

@testset "Implicit" begin
gfun(args...) = gradient(() -> sum(op.(args...)), params(collect(args)))
g = gfun(o, z)

gres = gfun(o, false)
@test gres[o] == g[o]
@test false ∉ gres.params

g = gfun(z, o)
gres = gfun(false, o)
@test gres[o] == g[o]
@test false ∉ gres.params
end
g = gfun(z, o)
@test gfun(false, o) == (nothing, g[2])
end
end

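Only the explicit form of this gradient test survives; the behaviour being exercised is that a literal false acts as a strong zero, so its gradient slot comes back as nothing. A small sketch of the same behaviour with + as a representative op (relying on Zygote/ChainRules treating Bool as non-differentiable):

    using Flux   # gradient here is Zygote's, re-exported by Flux
    g = gradient((x, y) -> sum(x .+ y), ones(3), false)
    # g == ([1.0, 1.0, 1.0], nothing): the Bool argument gets no gradient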
