diff --git a/Project.toml b/Project.toml index 048736f66d..dcdcdc0682 100644 --- a/Project.toml +++ b/Project.toml @@ -77,15 +77,14 @@ Optimisers = "0.3" Pkg = "1.9" PrecompileTools = "1.2" Random = "1.9" +ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" -SafeTestsets = "0.1" Setfield = "1" SparseArrays = "1.9" StableRNGs = "1" Statistics = "1.9" Test = "1.9" -TestSetExtensions = "3" Tracker = "0.2.31" TruncatedStacktraces = "1.1" WeightInitializers = "0.1.4" @@ -108,13 +107,13 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ChainRulesCore", "ComponentArrays", "Flux", "Functors", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "Optimisers", "Random", "SafeTestsets", "Setfield", "StableRNGs", "Statistics", "Test", "TestSetExtensions", "Zygote"] +test = ["ADTypes", "Aqua", "ChainRulesCore", "ComponentArrays", "Flux", "Functors", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "Optimisers", "Random", "Setfield", "StableRNGs", "Statistics", "Test", "Zygote", "ReTestItems", "Reexport", "Logging"] diff --git a/docs/tutorials.jl b/docs/tutorials.jl index fcf531cd53..3bcfb4d825 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -2,8 +2,8 @@ using Distributed addprocs(parse(Int, get(ENV, "LUX_DOCUMENTATION_NWORKERS", "1"))) -@everywhere const LUX_DOCUMENTATION_NWORKERS = parse( - Int, get(ENV, "LUX_DOCUMENTATION_NWORKERS", "1")) +@everywhere const LUX_DOCUMENTATION_NWORKERS = parse(Int, + get(ENV, "LUX_DOCUMENTATION_NWORKERS", "1")) @info "Lux Tutorial Build Running tutorials with $(LUX_DOCUMENTATION_NWORKERS) workers." 
@everywhere const CUDA_MEMORY_LIMIT = 100 ÷ LUX_DOCUMENTATION_NWORKERS @@ -30,8 +30,7 @@ pmap(enumerate(TUTORIALS)) do (i, (d, p)) name = "$(i)_$(first(rsplit(p, "/")))" p_ = get_example_path(p) OUTPUT = joinpath(@__DIR__, "src", "tutorials") - return Literate.markdown( - p_, joinpath(OUTPUT, d); execute=true, name, documenter=true, - preprocess=Base.Fix1(preprocess, p_)) + return Literate.markdown(p_, joinpath(OUTPUT, d); execute=true, name, + documenter=true, preprocess=Base.Fix1(preprocess, p_)) end end diff --git a/src/Lux.jl b/src/Lux.jl index 0898055147..e200c853f9 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -12,8 +12,7 @@ PrecompileTools.@recompile_invalidations begin import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, initialstates, parameterlength, statelength, - update_state, - trainmode, testmode, setup, apply, display_name + update_state, trainmode, testmode, setup, apply, display_name import LuxDeviceUtils: AbstractLuxDevice, AbstractLuxGPUDevice, AbstractLuxDeviceAdaptor end diff --git a/src/chainrules.jl b/src/chainrules.jl index c75a4b84c3..021032201b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -25,9 +25,7 @@ function CRC.rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) wh dnt2 = NamedTuple((f2 => getproperty(dy, f2) for f2 in F2)) return (NoTangent(), dnt1, dnt2) end - function ∇merge(dy::Union{NoTangent, ZeroTangent}) - return (NoTangent(), NoTangent(), NoTangent()) - end + ∇merge(::Union{NoTangent, ZeroTangent}) = (NoTangent(), NoTangent(), NoTangent()) return y, ∇merge end diff --git a/src/deprecated.jl b/src/deprecated.jl index c894476613..d9745cba7a 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -11,9 +11,8 @@ This function has been deprecated. Use [`cpu_device`](@ref) instead. ::: """ function cpu(x) - Base.depwarn( - "`cpu` has been deprecated and will be removed in v0.6. Use `cpu_device` instead.", - :cpu) + Base.depwarn("`cpu` has been deprecated and will be removed in v0.6. Use \ + `cpu_device` instead.", :cpu) return (cpu_device())(x) end @@ -30,10 +29,11 @@ inside performance critical code will cause massive slowdowns due to type infere ::: """ function gpu(x) - @warn "Using `gpu` inside performance critical code will cause massive slowdowns due to type inference failure. Please update your code to use `gpu_device` API." maxlog=1 + @warn "Using `gpu` inside performance critical code will cause massive slowdowns due \ + to type inference failure. Please update your code to use `gpu_device` \ + API." maxlog=1 - Base.depwarn( - "`gpu` has been deprecated and will be removed in v0.6. Use `gpu_device` instead.", - :gpu) + Base.depwarn("`gpu` has been deprecated and will be removed in v0.6. Use \ + `gpu_device` instead.", :gpu) return (gpu_device())(x) end diff --git a/src/stacktraces.jl b/src/stacktraces.jl index 3a21cfbdc1..ce54894bf5 100644 --- a/src/stacktraces.jl +++ b/src/stacktraces.jl @@ -13,46 +13,6 @@ function disable_stacktrace_truncation!(; disable::Bool=true) end end -# NamedTuple / Tuples -- Lux uses them quite frequenty (states) making the error messages -# too verbose -@static if VERSION ≤ v"1.9" && !TruncatedStacktraces.DISABLE - function Base.show(io::IO, t::Type{<:Tuple}) - if (TruncatedStacktraces.VERBOSE[] || - !hasfield(t, :parameters) || - length(t.parameters) == 0) - invoke(show, Tuple{IO, Type}, io, t) - else - try - fields = t.parameters - fields_truncated = length(fields) > 2 ? "$(fields[1]),$(fields[2]),…" : - (length(fields) == 2 ? 
"$(fields[1]),$(fields[2])" : - (length(fields) == 1 ? "$(fields[1])" : "")) - print(io, "Tuple{$fields_truncated}") - catch - invoke(show, Tuple{IO, Type}, io, t) - end - end - end - - function Base.show(io::IO, t::Type{<:NamedTuple}) - if (TruncatedStacktraces.VERBOSE[] || - !hasfield(t, :parameters) || - length(t.parameters) == 0) - invoke(show, Tuple{IO, Type}, io, t) - else - try - fields = first(t.parameters) - fields_truncated = length(fields) > 2 ? "$(fields[1]),$(fields[2]),…" : - (length(fields) == 2 ? "$(fields[1]),$(fields[2])" : - (length(fields) == 1 ? "$(fields[1])" : "")) - print(io, "NamedTuple{($fields_truncated),…}") - catch - invoke(show, Tuple{IO, Type}, io, t) - end - end - end -end - # Lux Layers ## layers/basic.jl @truncate_stacktrace WrappedFunction diff --git a/test/aqua.jl b/test/aqua.jl deleted file mode 100644 index c41191af89..0000000000 --- a/test/aqua.jl +++ /dev/null @@ -1,16 +0,0 @@ -using Aqua, ChainRulesCore, Lux, Test - -@testset "All Tests (except Ambiguity & Piracy)" begin - Aqua.test_all(Lux; ambiguities=false, piracies=false) -end - -@testset "Ambiguity Tests" begin - # The exclusions are due to CRC.@non_differentiable - Aqua.test_ambiguities(Lux; exclude=[ChainRulesCore.frule, Core.kwcall]) -end - -@testset "Piracy Tests" begin - # The exclusions are due to CRC.@non_differentiable - Aqua.test_piracies(Lux; - treat_as_own=[ChainRulesCore.frule, ChainRulesCore.rrule, Core.kwcall]) -end diff --git a/test/aqua_tests.jl b/test/aqua_tests.jl new file mode 100644 index 0000000000..4d3059285b --- /dev/null +++ b/test/aqua_tests.jl @@ -0,0 +1,7 @@ +@testitem "Aqua: Quality Assurance" begin + using Aqua, ChainRulesCore + + Aqua.test_all(Lux; piracies=false) + Aqua.test_piracies(Lux; + treat_as_own=[ChainRulesCore.frule, ChainRulesCore.rrule, Core.kwcall]) +end diff --git a/test/contrib/compact.jl b/test/contrib/compact.jl deleted file mode 100644 index 0edc2dbb1a..0000000000 --- a/test/contrib/compact.jl +++ /dev/null @@ -1,268 +0,0 @@ -using ComponentArrays, Lux, Test -import Lux.Experimental: @compact - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -function similar_strings(s₁::String, s₂::String) - if s₁ != s₂ - println(stderr, "s₁: ", s₁) - println(stderr, "s₂: ", s₂) - end - return s₁ == s₂ -end - -function get_model_string(model) - io = IOBuffer() - show(io, MIME"text/plain"(), model) - return String(take!(io)) -end - -@testset "$mode: @compact" for (mode, aType, device, ongpu) in MODES - @testset "Linear Layer" begin - r = @compact(w=[1, 5, 10]) do x - return sum(w .* x) - end - ps, st = Lux.setup(rng, r) |> device - - @test ps.w == ([1, 5, 10] |> device) - @test st == NamedTuple() - - x = [1, 1, 1] |> aType - @test first(r(x, ps, st)) == 1 + 5 + 10 - - x = [1, 2, 3] |> aType - @test first(r(x, ps, st)) == 1 + 2 * 5 + 3 * 10 - - x = ones(3, 3) |> aType - @test first(r(x, ps, st)) == 3 * (1 + 5 + 10) - - @jet r(x, ps, st) - - # Test gradients: - x = [1, 1, 1] |> aType - @test Zygote.gradient(x -> sum(first(r(x, ps, st))), x)[1] == ps.w - end - - @testset "Linear Layer with Activation" begin - d_in = 5 - d_out = 7 - d = @compact(W=randn(d_out, d_in), b=zeros(d_out),act=relu) do x - y = W * x - return act.(y .+ b) - end - - ps, st = Lux.setup(rng, d) |> device - @test size(ps.W) == (7, 5) - @test size(ps.b) == (7,) - @test st.act == relu - - x = ones(5, 10) |> aType - @test size(first(d(x, ps, st))) == (7, 10) - - x = randn(rng, 5, 10) |> aType - @test all(≥(0), first(d(x, ps, st))) - - @jet d(x, ps, st) - - # Test gradients: - y, ∇ = 
Zygote.withgradient(ps) do ps - input = randn(5, 32) |> aType - desired_output = randn(7, 32) |> aType - prediction = first(d(input, ps, st)) - return sum(abs2, prediction .- desired_output) - end - - @test y isa AbstractFloat - grads = ∇[1] - @test length(grads) == 2 - @test Set(size.(values(grads))) == Set([(7, 5), (7,)]) - - # Test equivalence to Dense layer: - ps_dense = (; weight=ps.W, bias=ps.b) - st_dense = NamedTuple() - dense = Dense(d_in => d_out, relu) - - x = [1, 2, 3, 4, 5] |> aType - @test first(d(x, ps, st)) ≈ first(dense(x, ps_dense, st_dense)) - end - - @testset "MLP" begin - n_in = 1 - n_out = 1 - nlayers = 3 - - model = @compact(w1=Dense(n_in, 128), - w2=[Dense(128, 128) for i in 1:nlayers], - w3=Dense(128, n_out), - act=relu) do x - embed = act.(w1(x)) - for w in w2 - embed = act.(w(embed)) - end - out = w3(embed) - return out - end - - ps, st = Lux.setup(rng, model) |> device - - @test size(ps.w1.weight) == (128, 1) - @test size(ps.w1.bias) == (128, 1) - @test length(ps.w2) == nlayers - for i in 1:nlayers - @test size(ps.w2[i].weight) == (128, 128) - @test size(ps.w2[i].bias) == (128, 1) - end - @test size(ps.w3.weight) == (1, 128) - @test size(ps.w3.bias) == (1, 1) - - x = randn(n_in, 32) |> aType - - @test size(first(model(x, ps, st))) == (1, 32) - - ps2 = ps |> cpu_device() |> ComponentArray |> device - - @test size(first(model(x, ps2, st))) == (1, 32) - - @jet model(x, ps, st) - - __f = (x, ps) -> sum(first(model(x, ps, st))) - - @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 - end - - @testset "String Representations" begin - model = @compact(w=Dense(32 => 32)) do (x, y) - tmp = sum(w(x)) - return tmp + y - end - expected_string = """@compact( - w = Dense(32 => 32), # 1_056 parameters - ) do (x, y) - tmp = sum(w(x)) - return tmp + y - end # Total: 1_056 parameters, - # plus 0 states.""" - - @test similar_strings(get_model_string(model), expected_string) - end - - @testset "Custom Naming" begin - model = @compact(w=Dense(32, 32), name="Linear(...)") do (x, y) - tmp = sum(w(x)) - return tmp + y - end - expected_string = "Linear(...)() # 1_056 parameters" - @test similar_strings(get_model_string(model), expected_string) - end - - @testset "Hierarchical Models" begin - model1 = @compact(w1=Dense(32 => 32, relu), w2=Dense(32 => 32, relu)) do x - return w2(w1(x)) - end - model2 = @compact(w1=model1, w2=Dense(32 => 32, relu)) do x - return w2(w1(x)) - end - expected_string = """@compact( - w1 = @compact( - w1 = Dense(32 => 32, relu), # 1_056 parameters - w2 = Dense(32 => 32, relu), # 1_056 parameters - ) do x - return w2(w1(x)) - end, - w2 = Dense(32 => 32, relu), # 1_056 parameters - ) do x - return w2(w1(x)) - end # Total: 3_168 parameters, - # plus 0 states.""" - @test similar_strings(get_model_string(model2), expected_string) - end - - @testset "Array Parameters" begin - model = @compact(x=randn(32), w=Dense(32 => 32)) do s - return w(x .* s) - end - expected_string = """@compact( - x = randn(32), - w = Dense(32 => 32), # 1_056 parameters - ) do s - return w(x .* s) - end # Total: 1_088 parameters, - # plus 0 states.""" - @test similar_strings(get_model_string(model), expected_string) - end - - @testset "Hierarchy with Inner Model Named" begin - model = @compact(w1=@compact(w1=randn(32, 32), name="Model(32)") do x - return w1 * x - end, - w2=randn(32, 32), - w3=randn(32),) do x - return w2 * w1(x) - end - expected_string = """@compact( - w1 = Model(32)(), # 1_024 parameters - w2 = randn(32, 32), - w3 = randn(32), - ) do x - return w2 * 
w1(x) - end # Total: 2_080 parameters, - # plus 0 states.""" - @test similar_strings(get_model_string(model), expected_string) - end - - @testset "Hierarchy with Outer Model Named" begin - model = @compact(w1=@compact(w1=randn(32, 32)) do x - return w1 * x - end, - w2=randn(32, 32), - w3=randn(32), - name="Model(32)") do x - return w2 * w1(x) - end - expected_string = """Model(32)() # 2_080 parameters""" - @test similar_strings(get_model_string(model), expected_string) - end - - @testset "Dependent Initializations" begin - # Test that initialization lines cannot depend on each other - @test_throws UndefVarError @compact(y₁=3, z=y₁^2) do x - return y₁ + z + x - end - end - - @testset "Keyword Argument Syntax" begin - _a = 3 - _b = 4 - c = 5 - model = @compact(a=_a; b=_b,c) do x - return a + b * x + c * x^2 - end - ps, st = Lux.setup(rng, model) |> device - @test first(model(2, ps, st)) == _a + _b * 2 + c * 2^2 - end - - @testset "Keyword Arguments with Anonymous Function" begin - model = @test_nowarn @compact(x->x + a + b; a=1,b=2) - ps, st = Lux.setup(rng, model) |> device - @test first(model(3, ps, st)) == 1 + 2 + 3 - expected_string = """@compact( - a = 1, - b = 2, - ) do x - x + a + b - end # Total: 0 parameters, - # plus 2 states.""" - @test similar_strings(get_model_string(model), expected_string) - end - - @testset "Scoping of Parameter Arguments" begin - model = @compact(w1=3, w2=5) do a - g(w1, w2) = 2 * w1 * w2 - return (w1 + w2) * g(a, a) - end - ps, st = Lux.setup(rng, model) |> device - @test first(model(2, ps, st)) == (3 + 5) * 2 * 2 * 2 - end -end diff --git a/test/contrib/compact_tests.jl b/test/contrib/compact_tests.jl new file mode 100644 index 0000000000..e3f141f4ef --- /dev/null +++ b/test/contrib/compact_tests.jl @@ -0,0 +1,268 @@ +@testitem "@compact" setup=[SharedTestSetup] begin + using ComponentArrays + import Lux.Experimental: @compact + + rng = get_stable_rng(12345) + + function similar_strings(s₁::String, s₂::String) + if s₁ != s₂ + println(stderr, "s₁: ", s₁) + println(stderr, "s₂: ", s₂) + end + return s₁ == s₂ + end + + function get_model_string(model) + io = IOBuffer() + show(io, MIME"text/plain"(), model) + return String(take!(io)) + end + + @testset "$mode: @compact" for (mode, aType, device, ongpu) in MODES + @testset "Linear Layer" begin + r = @compact(w=[1, 5, 10]) do x + return sum(w .* x) + end + ps, st = Lux.setup(rng, r) |> device + + @test ps.w == ([1, 5, 10] |> device) + @test st == NamedTuple() + + x = [1, 1, 1] |> aType + @test first(r(x, ps, st)) == 1 + 5 + 10 + + x = [1, 2, 3] |> aType + @test first(r(x, ps, st)) == 1 + 2 * 5 + 3 * 10 + + x = ones(3, 3) |> aType + @test first(r(x, ps, st)) == 3 * (1 + 5 + 10) + + @jet r(x, ps, st) + + # Test gradients: + x = [1, 1, 1] |> aType + @test Zygote.gradient(x -> sum(first(r(x, ps, st))), x)[1] == ps.w + end + + @testset "Linear Layer with Activation" begin + d_in = 5 + d_out = 7 + d = @compact(W=randn(d_out, d_in), b=zeros(d_out),act=relu) do x + y = W * x + return act.(y .+ b) + end + + ps, st = Lux.setup(rng, d) |> device + @test size(ps.W) == (7, 5) + @test size(ps.b) == (7,) + @test st.act == relu + + x = ones(5, 10) |> aType + @test size(first(d(x, ps, st))) == (7, 10) + + x = randn(rng, 5, 10) |> aType + @test all(≥(0), first(d(x, ps, st))) + + @jet d(x, ps, st) + + # Test gradients: + y, ∇ = Zygote.withgradient(ps) do ps + input = randn(5, 32) |> aType + desired_output = randn(7, 32) |> aType + prediction = first(d(input, ps, st)) + return sum(abs2, prediction .- desired_output) + end + + @test 
y isa AbstractFloat + grads = ∇[1] + @test length(grads) == 2 + @test Set(size.(values(grads))) == Set([(7, 5), (7,)]) + + # Test equivalence to Dense layer: + ps_dense = (; weight=ps.W, bias=ps.b) + st_dense = NamedTuple() + dense = Dense(d_in => d_out, relu) + + x = [1, 2, 3, 4, 5] |> aType + @test first(d(x, ps, st)) ≈ first(dense(x, ps_dense, st_dense)) + end + + @testset "MLP" begin + n_in = 1 + n_out = 1 + nlayers = 3 + + model = @compact(w1=Dense(n_in, 128), + w2=[Dense(128, 128) for i in 1:nlayers], + w3=Dense(128, n_out), + act=relu) do x + embed = act.(w1(x)) + for w in w2 + embed = act.(w(embed)) + end + out = w3(embed) + return out + end + + ps, st = Lux.setup(rng, model) |> device + + @test size(ps.w1.weight) == (128, 1) + @test size(ps.w1.bias) == (128, 1) + @test length(ps.w2) == nlayers + for i in 1:nlayers + @test size(ps.w2[i].weight) == (128, 128) + @test size(ps.w2[i].bias) == (128, 1) + end + @test size(ps.w3.weight) == (1, 128) + @test size(ps.w3.bias) == (1, 1) + + x = randn(n_in, 32) |> aType + + @test size(first(model(x, ps, st))) == (1, 32) + + ps2 = ps |> cpu_device() |> ComponentArray |> device + + @test size(first(model(x, ps2, st))) == (1, 32) + + @jet model(x, ps, st) + + __f = (x, ps) -> sum(first(model(x, ps, st))) + + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 + end + + @testset "String Representations" begin + model = @compact(w=Dense(32 => 32)) do (x, y) + tmp = sum(w(x)) + return tmp + y + end + expected_string = """@compact( + w = Dense(32 => 32), # 1_056 parameters + ) do (x, y) + tmp = sum(w(x)) + return tmp + y + end # Total: 1_056 parameters, + # plus 0 states.""" + + @test similar_strings(get_model_string(model), expected_string) + end + + @testset "Custom Naming" begin + model = @compact(w=Dense(32, 32), name="Linear(...)") do (x, y) + tmp = sum(w(x)) + return tmp + y + end + expected_string = "Linear(...)() # 1_056 parameters" + @test similar_strings(get_model_string(model), expected_string) + end + + @testset "Hierarchical Models" begin + model1 = @compact(w1=Dense(32 => 32, relu), w2=Dense(32 => 32, relu)) do x + return w2(w1(x)) + end + model2 = @compact(w1=model1, w2=Dense(32 => 32, relu)) do x + return w2(w1(x)) + end + expected_string = """@compact( + w1 = @compact( + w1 = Dense(32 => 32, relu), # 1_056 parameters + w2 = Dense(32 => 32, relu), # 1_056 parameters + ) do x + return w2(w1(x)) + end, + w2 = Dense(32 => 32, relu), # 1_056 parameters + ) do x + return w2(w1(x)) + end # Total: 3_168 parameters, + # plus 0 states.""" + @test similar_strings(get_model_string(model2), expected_string) + end + + @testset "Array Parameters" begin + model = @compact(x=randn(32), w=Dense(32 => 32)) do s + return w(x .* s) + end + expected_string = """@compact( + x = randn(32), + w = Dense(32 => 32), # 1_056 parameters + ) do s + return w(x .* s) + end # Total: 1_088 parameters, + # plus 0 states.""" + @test similar_strings(get_model_string(model), expected_string) + end + + @testset "Hierarchy with Inner Model Named" begin + model = @compact(w1=@compact(w1=randn(32, 32), name="Model(32)") do x + return w1 * x + end, + w2=randn(32, 32), + w3=randn(32),) do x + return w2 * w1(x) + end + expected_string = """@compact( + w1 = Model(32)(), # 1_024 parameters + w2 = randn(32, 32), + w3 = randn(32), + ) do x + return w2 * w1(x) + end # Total: 2_080 parameters, + # plus 0 states.""" + @test similar_strings(get_model_string(model), expected_string) + end + + @testset "Hierarchy with Outer Model Named" begin + model = 
@compact(w1=@compact(w1=randn(32, 32)) do x + return w1 * x + end, + w2=randn(32, 32), + w3=randn(32), + name="Model(32)") do x + return w2 * w1(x) + end + expected_string = """Model(32)() # 2_080 parameters""" + @test similar_strings(get_model_string(model), expected_string) + end + + @testset "Dependent Initializations" begin + # Test that initialization lines cannot depend on each other + @test_throws UndefVarError @compact(y₁=3, z=y₁^2) do x + return y₁ + z + x + end + end + + @testset "Keyword Argument Syntax" begin + _a = 3 + _b = 4 + c = 5 + model = @compact(a=_a; b=_b,c) do x + return a + b * x + c * x^2 + end + ps, st = Lux.setup(rng, model) |> device + @test first(model(2, ps, st)) == _a + _b * 2 + c * 2^2 + end + + @testset "Keyword Arguments with Anonymous Function" begin + model = @test_nowarn @compact(x->x + a + b; a=1,b=2) + ps, st = Lux.setup(rng, model) |> device + @test first(model(3, ps, st)) == 1 + 2 + 3 + expected_string = """@compact( + a = 1, + b = 2, + ) do x + x + a + b + end # Total: 0 parameters, + # plus 2 states.""" + @test similar_strings(get_model_string(model), expected_string) + end + + @testset "Scoping of Parameter Arguments" begin + model = @compact(w1=3, w2=5) do a + g(w1, w2) = 2 * w1 * w2 + return (w1 + w2) * g(a, a) + end + ps, st = Lux.setup(rng, model) |> device + @test first(model(2, ps, st)) == (3 + 5) * 2 * 2 * 2 + end + end +end diff --git a/test/contrib/debug.jl b/test/contrib/debug.jl deleted file mode 100644 index 0f409516b5..0000000000 --- a/test/contrib/debug.jl +++ /dev/null @@ -1,139 +0,0 @@ -using Lux, Test, Logging, Zygote, ChainRulesCore -const CRC = ChainRulesCore - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode: Dimension Mismatch" for (mode, aType, device, ongpu) in MODES - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) - - ps, st = Lux.setup(rng, model) |> device - x = randn(rng, Float32, 1, 5) |> aType - - @test_throws DimensionMismatch model(x, ps, st) - - model_debug = Lux.Experimental.@debug_mode model - - @test_throws DimensionMismatch model_debug(x, ps, st) - @test_logs (:info,) (:error, - "Layer Dense(1 => 1) failed!! 
This layer is present at location model.layers.layer_2.layers.layer_2") match_mode=:any try - model_debug(x, ps, st) - catch - end - - model_debug = Lux.Experimental.@debug_mode model error_check=false - - @test_throws DimensionMismatch model_debug(x, ps, st) - @test_logs min_level=Logging.Error try - model_debug(x, ps, st) - catch - end - - model_fixed = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) - - ps, st = Lux.setup(rng, model_fixed) |> device - - @test_nowarn model_fixed(x, ps, st) - - model_fixed_debug = Lux.Experimental.@debug_mode model_fixed - - @test_logs min_level=Logging.Error Zygote.gradient(sum ∘ first ∘ model_fixed_debug, - x, ps, st) -end - -offending_layer(x) = 2 .* x - -function CRC.rrule(::typeof(offending_layer), x) - y = offending_layer(x) - function ∇offending_layer(Δ) - Δ[1:1] .= NaN - return NoTangent(), Δ - end - return y, ∇offending_layer -end - -@testset "$mode: NaN Debugging" for (mode, aType, device, ongpu) in MODES - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) - - x = randn(rng, Float32, 1, 5) |> aType - ps, st = Lux.setup(rng, model) |> device - - model_debug = Lux.Experimental.@debug_mode model nan_check=:both - - ps.layer_2.layer_2.weight .*= NaN32 - - @test any(isnan, first(model(x, ps, st)) |> Array) - - @test_throws DomainError model_debug(x, ps, st) - - model_debug2 = Lux.Experimental.@debug_mode model nan_check=:forward - - @test_throws DomainError model_debug2(x, ps, st) - - model_debug3 = Lux.Experimental.@debug_mode model nan_check=:backward - @test any(isnan, first(model_debug3(x, ps, st)) |> Array) - - model_debug4 = Lux.Experimental.@debug_mode model nan_check=:none - @test any(isnan, first(model_debug4(x, ps, st)) |> Array) - - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), - BatchNorm(1); disable_optimizations=true) - - ps, st = Lux.setup(rng, model) |> device - - @test !any(isnan, first(model(x, ps, st)) |> Array) - - gs = only(Zygote.gradient(ps -> sum(first(model(x, ps, st))), ps)) - @test any(isnan, gs.layer_1.weight) - @test any(isnan, gs.layer_1.bias) - @test any(isnan, gs.layer_2.layer_1.weight) - @test any(isnan, gs.layer_2.layer_1.bias) - @test gs.layer_2.layer_2 === nothing - @test !any(isnan, gs.layer_3.scale) - @test !any(isnan, gs.layer_3.bias) - - model_debug = Lux.Experimental.@debug_mode model nan_check=:both - - @test_logs min_level=Logging.Error model_debug(x, ps, st) - - @test_throws DomainError only(Zygote.gradient(ps -> sum(first(model_debug(x, ps, st))), - ps)) - - model_debug2 = Lux.Experimental.@debug_mode model nan_check=:forward - - @test_logs min_level=Logging.Error model_debug2(x, ps, st) - - gs = only(Zygote.gradient(ps -> sum(first(model_debug2(x, ps, st))), ps)) - @test any(isnan, gs.layer_1.weight) - @test any(isnan, gs.layer_1.bias) - @test any(isnan, gs.layer_2.layer_1.weight) - @test any(isnan, gs.layer_2.layer_1.bias) - @test gs.layer_2.layer_2 === nothing - @test !any(isnan, gs.layer_3.scale) - @test !any(isnan, gs.layer_3.bias) - - model_debug3 = Lux.Experimental.@debug_mode model nan_check=:backward - - @test_logs min_level=Logging.Error model_debug3(x, ps, st) - - @test_throws DomainError only(Zygote.gradient( - ps -> sum(first(model_debug3(x, ps, st))), - ps)) - - model_debug4 = Lux.Experimental.@debug_mode model nan_check=:none - - @test_logs min_level=Logging.Error model_debug4(x, ps, st) - - gs = only(Zygote.gradient(ps -> 
sum(first(model_debug4(x, ps, st))), ps)) - @test any(isnan, gs.layer_1.weight) - @test any(isnan, gs.layer_1.bias) - @test any(isnan, gs.layer_2.layer_1.weight) - @test any(isnan, gs.layer_2.layer_1.bias) - @test gs.layer_2.layer_2 === nothing - @test !any(isnan, gs.layer_3.scale) - @test !any(isnan, gs.layer_3.bias) -end diff --git a/test/contrib/debug_tests.jl b/test/contrib/debug_tests.jl new file mode 100644 index 0000000000..2ded610b65 --- /dev/null +++ b/test/contrib/debug_tests.jl @@ -0,0 +1,150 @@ +@testitem "Debugging Tools: DimensionMismatch" setup=[SharedTestSetup] begin + using Logging + + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), + BatchNorm(1); disable_optimizations=true) + + ps, st = Lux.setup(rng, model) |> device + x = randn(rng, Float32, 1, 5) |> aType + + @test_throws DimensionMismatch model(x, ps, st) + + model_debug = Lux.Experimental.@debug_mode model + + @test_throws DimensionMismatch model_debug(x, ps, st) + @test_logs (:info,) (:error, + "Layer Dense(1 => 1) failed!! This layer is present at location model.layers.layer_2.layers.layer_2") match_mode=:any try + model_debug(x, ps, st) + catch + end + + model_debug = Lux.Experimental.@debug_mode model error_check=false + + @test_throws DimensionMismatch model_debug(x, ps, st) + @test_logs min_level=Logging.Error try + model_debug(x, ps, st) + catch + end + + model_fixed = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), + BatchNorm(1); disable_optimizations=true) + + ps, st = Lux.setup(rng, model_fixed) |> device + + @test_nowarn model_fixed(x, ps, st) + + model_fixed_debug = Lux.Experimental.@debug_mode model_fixed + + @test_logs min_level=Logging.Error Zygote.gradient(sum ∘ first ∘ model_fixed_debug, + x, ps, st) + end +end + +@testitem "Debugging Tools: NaN" setup=[SharedTestSetup] begin + using Logging, ChainRulesCore + import ChainRulesCore as CRC + + rng = get_stable_rng(12345) + + offending_layer(x) = 2 .* x + + function CRC.rrule(::typeof(offending_layer), x) + y = offending_layer(x) + function ∇offending_layer(Δ) + Δ[1:1] .= NaN + return NoTangent(), Δ + end + return y, ∇offending_layer + end + + @testset "$mode: NaN Debugging" for (mode, aType, device, ongpu) in MODES + model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), + BatchNorm(1); disable_optimizations=true) + + x = randn(rng, Float32, 1, 5) |> aType + ps, st = Lux.setup(rng, model) |> device + + model_debug = Lux.Experimental.@debug_mode model nan_check=:both + + ps.layer_2.layer_2.weight .*= NaN32 + + @test any(isnan, first(model(x, ps, st)) |> Array) + + @test_throws DomainError model_debug(x, ps, st) + + model_debug2 = Lux.Experimental.@debug_mode model nan_check=:forward + + @test_throws DomainError model_debug2(x, ps, st) + + model_debug3 = Lux.Experimental.@debug_mode model nan_check=:backward + @test any(isnan, first(model_debug3(x, ps, st)) |> Array) + + model_debug4 = Lux.Experimental.@debug_mode model nan_check=:none + @test any(isnan, first(model_debug4(x, ps, st)) |> Array) + + model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), + BatchNorm(1); disable_optimizations=true) + + ps, st = Lux.setup(rng, model) |> device + + @test !any(isnan, first(model(x, ps, st)) |> Array) + + gs = only(Zygote.gradient(ps -> sum(first(model(x, ps, st))), ps)) + @test any(isnan, gs.layer_1.weight) + @test any(isnan, gs.layer_1.bias) + @test any(isnan, 
gs.layer_2.layer_1.weight) + @test any(isnan, gs.layer_2.layer_1.bias) + @test gs.layer_2.layer_2 === nothing + @test !any(isnan, gs.layer_3.scale) + @test !any(isnan, gs.layer_3.bias) + + model_debug = Lux.Experimental.@debug_mode model nan_check=:both + + @test_logs min_level=Logging.Error model_debug(x, ps, st) + + @test_throws DomainError only(Zygote.gradient( + ps -> sum(first(model_debug(x, + ps, + st))), + ps)) + + model_debug2 = Lux.Experimental.@debug_mode model nan_check=:forward + + @test_logs min_level=Logging.Error model_debug2(x, ps, st) + + gs = only(Zygote.gradient(ps -> sum(first(model_debug2(x, ps, st))), ps)) + @test any(isnan, gs.layer_1.weight) + @test any(isnan, gs.layer_1.bias) + @test any(isnan, gs.layer_2.layer_1.weight) + @test any(isnan, gs.layer_2.layer_1.bias) + @test gs.layer_2.layer_2 === nothing + @test !any(isnan, gs.layer_3.scale) + @test !any(isnan, gs.layer_3.bias) + + model_debug3 = Lux.Experimental.@debug_mode model nan_check=:backward + + @test_logs min_level=Logging.Error model_debug3(x, ps, st) + + @test_throws DomainError only(Zygote.gradient( + ps -> sum(first(model_debug3(x, + ps, + st))), + ps)) + + model_debug4 = Lux.Experimental.@debug_mode model nan_check=:none + + @test_logs min_level=Logging.Error model_debug4(x, ps, st) + + gs = only(Zygote.gradient(ps -> sum(first(model_debug4(x, ps, st))), ps)) + @test any(isnan, gs.layer_1.weight) + @test any(isnan, gs.layer_1.bias) + @test any(isnan, gs.layer_2.layer_1.weight) + @test any(isnan, gs.layer_2.layer_1.bias) + @test gs.layer_2.layer_2 === nothing + @test !any(isnan, gs.layer_3.scale) + @test !any(isnan, gs.layer_3.bias) + end +end diff --git a/test/contrib/freeze.jl b/test/contrib/freeze.jl deleted file mode 100644 index da42385b7f..0000000000 --- a/test/contrib/freeze.jl +++ /dev/null @@ -1,84 +0,0 @@ -using ComponentArrays, Lux, Test - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode: All Parameters Freezing" for (mode, aType, device, ongpu) in MODES - @testset "NamedTuple" begin - d = Dense(5 => 5) - psd, std = Lux.setup(rng, d) .|> device - - fd, ps, st = Lux.Experimental.freeze(d, psd, std, nothing) - @test length(keys(ps)) == 0 - @test length(keys(st)) == 2 - @test sort([keys(st)...]) == [:frozen_params, :states] - @test sort([keys(st.frozen_params)...]) == [:bias, :weight] - - x = randn(rng, Float32, 5, 1) |> aType - - @test d(x, psd, std)[1] == fd(x, ps, st)[1] - - @jet fd(x, ps, st) - __f = (x, ps) -> sum(first(fd(x, ps, st))) - - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "ComponentArray" begin - m = Chain(Lux.Experimental.freeze(Dense(1 => 3, tanh)), Dense(3 => 1)) - ps, st = Lux.setup(rng, m) - st = st |> device - ps_c = ComponentVector(ps) |> device - ps = ps |> device - x = randn(rng, Float32, 1, 2) |> aType - - @test m(x, ps, st)[1] == m(x, ps_c, st)[1] - - @jet m(x, ps_c, st) - __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps_c atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "LuxDL/Lux.jl#427" begin - m = Dense(1 => 1) - ps, st = Lux.setup(rng, m) - st = st |> device - ps_c = ComponentVector(ps) |> device - ps = ps |> device - - fd, psf, stf = Lux.Experimental.freeze(m, ps, st) - - @test fd isa Lux.Experimental.FrozenLayer - @test psf isa NamedTuple{} - @test sort([keys(stf)...]) == [:frozen_params, :states] - @test sort([keys(stf.frozen_params)...]) == [:bias, :weight] - - fd, psf, stf = Lux.Experimental.freeze(m, ps_c, st) - - @test fd isa 
Lux.Experimental.FrozenLayer - @test psf isa NamedTuple{} - @test sort([keys(stf)...]) == [:frozen_params, :states] - @test sort([keys(stf.frozen_params)...]) == [:bias, :weight] - end -end - -@testset "$mode: Partial Freezing" for (mode, aType, device, ongpu) in MODES - d = Dense(5 => 5) - psd, std = Lux.setup(rng, d) .|> device - - fd, ps, st = Lux.Experimental.freeze(d, psd, std, (:weight,)) - @test length(keys(ps)) == 1 - @test length(keys(st)) == 2 - @test sort([keys(st)...]) == [:frozen_params, :states] - @test sort([keys(st.frozen_params)...]) == [:weight] - @test sort([keys(ps)...]) == [:bias] - - x = randn(rng, Float32, 5, 1) |> aType - - @test d(x, psd, std)[1] == fd(x, ps, st)[1] - - @jet fd(x, ps, st) - __f = (x, ps) -> sum(first(fd(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu -end diff --git a/test/contrib/freeze_tests.jl b/test/contrib/freeze_tests.jl new file mode 100644 index 0000000000..828331d614 --- /dev/null +++ b/test/contrib/freeze_tests.jl @@ -0,0 +1,86 @@ +@testitem "All Parameter Freezing" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "NamedTuple" begin + d = Dense(5 => 5) + psd, std = Lux.setup(rng, d) .|> device + + fd, ps, st = Lux.Experimental.freeze(d, psd, std, nothing) + @test length(keys(ps)) == 0 + @test length(keys(st)) == 2 + @test sort([keys(st)...]) == [:frozen_params, :states] + @test sort([keys(st.frozen_params)...]) == [:bias, :weight] + + x = randn(rng, Float32, 5, 1) |> aType + + @test d(x, psd, std)[1] == fd(x, ps, st)[1] + + @jet fd(x, ps, st) + __f = (x, ps) -> sum(first(fd(x, ps, st))) + + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + @testset "ComponentArray" begin + m = Chain(Lux.Experimental.freeze(Dense(1 => 3, tanh)), Dense(3 => 1)) + ps, st = Lux.setup(rng, m) + st = st |> device + ps_c = ComponentVector(ps) |> device + ps = ps |> device + x = randn(rng, Float32, 1, 2) |> aType + + @test m(x, ps, st)[1] == m(x, ps_c, st)[1] + + @jet m(x, ps_c, st) + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps_c atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + @testset "LuxDL/Lux.jl#427" begin + m = Dense(1 => 1) + ps, st = Lux.setup(rng, m) + st = st |> device + ps_c = ComponentVector(ps) |> device + ps = ps |> device + + fd, psf, stf = Lux.Experimental.freeze(m, ps, st) + + @test fd isa Lux.Experimental.FrozenLayer + @test psf isa NamedTuple{} + @test sort([keys(stf)...]) == [:frozen_params, :states] + @test sort([keys(stf.frozen_params)...]) == [:bias, :weight] + + fd, psf, stf = Lux.Experimental.freeze(m, ps_c, st) + + @test fd isa Lux.Experimental.FrozenLayer + @test psf isa NamedTuple{} + @test sort([keys(stf)...]) == [:frozen_params, :states] + @test sort([keys(stf.frozen_params)...]) == [:bias, :weight] + end + end +end + +@testitem "Partial Freezing" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + d = Dense(5 => 5) + psd, std = Lux.setup(rng, d) .|> device + + fd, ps, st = Lux.Experimental.freeze(d, psd, std, (:weight,)) + @test length(keys(ps)) == 1 + @test length(keys(st)) == 2 + @test sort([keys(st)...]) == [:frozen_params, :states] + @test sort([keys(st.frozen_params)...]) == [:weight] + @test sort([keys(ps)...]) == [:bias] + + x = randn(rng, Float32, 5, 1) |> aType + + @test d(x, psd, std)[1] == fd(x, ps, st)[1] + + @jet fd(x, ps, st) + __f = (x, ps) -> 
sum(first(fd(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end +end diff --git a/test/contrib/map.jl b/test/contrib/map.jl deleted file mode 100644 index 8a02441d7f..0000000000 --- a/test/contrib/map.jl +++ /dev/null @@ -1,71 +0,0 @@ -using Lux, Setfield, Test - -include("../test_utils.jl") - -function zero_dense_params_1(l, ps, st, name) - if l isa Dense && occursin("model.layers.chain", name) - @set! ps.weight = zero.(ps.weight) - @set! ps.bias = zero.(ps.bias) - end - return l, ps, st -end - -function zero_dense_params_2(l, ps, st, name) - if l isa Dense && occursin("c.layers.chain", name) - @set! ps.weight = zero.(ps.weight) - @set! ps.bias = zero.(ps.bias) - end - return l, ps, st -end - -function zero_dense_params_3(l, ps, st, name) - if l isa Dense - @set! ps.weight = zero.(ps.weight) - @set! ps.bias = zero.(ps.bias) - end - return l, ps, st -end - -@testset "$mode" for (mode, aType, device, ongpu) in MODES - c = Parallel(+; - chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)), - dense_3=Dense(5 => 1)) - - rng = get_stable_rng(12345) - ps, st = Lux.setup(rng, c) .|> device - - c_, ps_, st_ = Lux.Experimental.layer_map(zero_dense_params_1, c, ps, st) - - @test all(iszero, ps_.chain.dense_1.weight) - @test all(iszero, ps_.chain.dense_1.bias) - @test all(iszero, ps_.chain.dense_2.weight) - @test all(iszero, ps_.chain.dense_2.bias) - @test !all(iszero, ps_.dense_3.weight) - @test all(iszero, ps_.dense_3.bias) - - c_, ps_, st_ = Lux.Experimental.@layer_map zero_dense_params_2 c ps st - - @test all(iszero, ps_.chain.dense_1.weight) - @test all(iszero, ps_.chain.dense_1.bias) - @test all(iszero, ps_.chain.dense_2.weight) - @test all(iszero, ps_.chain.dense_2.bias) - @test !all(iszero, ps_.dense_3.weight) - @test all(iszero, ps_.dense_3.bias) - - # Custom Layers -- See https://github.com/avik-pal/Lux.jl/issues/187 - struct SimpleCustom{L1, L2} <: Lux.AbstractExplicitContainerLayer{(:dense, :conv)} - dense::L1 - conv::L2 - end - - l = SimpleCustom(Dense(3 => 2), Conv((3,), 3 => 2)) - - ps, st = Lux.setup(rng, l) .|> device - - l_, ps_, st_ = Lux.Experimental.@layer_map zero_dense_params_3 l ps st - - @test all(iszero, ps_.dense.weight) - @test all(iszero, ps_.dense.bias) - @test !all(iszero, ps_.conv.weight) - @test all(iszero, ps_.conv.bias) -end diff --git a/test/contrib/map_tests.jl b/test/contrib/map_tests.jl new file mode 100644 index 0000000000..86268d811f --- /dev/null +++ b/test/contrib/map_tests.jl @@ -0,0 +1,71 @@ +@testitem "Layer Map" setup=[SharedTestSetup] begin + using Setfield + + function zero_dense_params_1(l, ps, st, name) + if l isa Dense && occursin("model.layers.chain", name) + @set! ps.weight = zero.(ps.weight) + @set! ps.bias = zero.(ps.bias) + end + return l, ps, st + end + + function zero_dense_params_2(l, ps, st, name) + if l isa Dense && occursin("c.layers.chain", name) + @set! ps.weight = zero.(ps.weight) + @set! ps.bias = zero.(ps.bias) + end + return l, ps, st + end + + function zero_dense_params_3(l, ps, st, name) + if l isa Dense + @set! ps.weight = zero.(ps.weight) + @set! 
ps.bias = zero.(ps.bias) + end + return l, ps, st + end + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + c = Parallel(+; + chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)), + dense_3=Dense(5 => 1)) + + rng = get_stable_rng(12345) + ps, st = Lux.setup(rng, c) .|> device + + c_, ps_, st_ = Lux.Experimental.layer_map(zero_dense_params_1, c, ps, st) + + @test all(iszero, ps_.chain.dense_1.weight) + @test all(iszero, ps_.chain.dense_1.bias) + @test all(iszero, ps_.chain.dense_2.weight) + @test all(iszero, ps_.chain.dense_2.bias) + @test !all(iszero, ps_.dense_3.weight) + @test all(iszero, ps_.dense_3.bias) + + c_, ps_, st_ = Lux.Experimental.@layer_map zero_dense_params_2 c ps st + + @test all(iszero, ps_.chain.dense_1.weight) + @test all(iszero, ps_.chain.dense_1.bias) + @test all(iszero, ps_.chain.dense_2.weight) + @test all(iszero, ps_.chain.dense_2.bias) + @test !all(iszero, ps_.dense_3.weight) + @test all(iszero, ps_.dense_3.bias) + + # Custom Layers -- See https://github.com/LuxDL/Lux.jl/issues/187 + struct SimpleCustom{L1, L2} <: Lux.AbstractExplicitContainerLayer{(:dense, :conv)} + dense::L1 + conv::L2 + end + + l = SimpleCustom(Dense(3 => 2), Conv((3,), 3 => 2)) + + ps, st = Lux.setup(rng, l) .|> device + + l_, ps_, st_ = Lux.Experimental.@layer_map zero_dense_params_3 l ps st + + @test all(iszero, ps_.dense.weight) + @test all(iszero, ps_.dense.bias) + @test !all(iszero, ps_.conv.weight) + @test all(iszero, ps_.conv.bias) + end +end diff --git a/test/contrib/share_parameters.jl b/test/contrib/share_parameters.jl deleted file mode 100644 index 733bd24ed7..0000000000 --- a/test/contrib/share_parameters.jl +++ /dev/null @@ -1,64 +0,0 @@ -using ComponentArrays, Lux, Test - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode" for (mode, aType, device, ongpu) in MODES - model = Chain(; - d1=Dense(2 => 4, tanh), - d2=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)), - d3=Dense(4 => 2)) - - ps, st = Lux.setup(rng, model) .|> device - - sharing = (("d2.l2", "d1"), ("d3", "d2.l1")) - - ps_1 = Lux.Experimental.share_parameters(ps, sharing) - - @test ps_1.d2.l2.weight == ps_1.d1.weight - @test ps_1.d2.l2.bias == ps_1.d1.bias - @test ps_1.d3.weight == ps_1.d2.l1.weight - @test ps_1.d3.bias == ps_1.d2.l1.bias - - ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4, 1)) |> - device - ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> - device - - ps_2 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) - - @test ps_2.d2.l2.weight == ps_new_1.weight == ps_2.d1.weight - @test ps_2.d2.l2.bias == ps_new_1.bias == ps_2.d1.bias - @test ps_2.d3.weight == ps_new_2.weight == ps_2.d2.l1.weight - @test ps_2.d3.bias == ps_new_2.bias == ps_2.d2.l1.bias - - # Mix in ComponentArray - ps_new_ca_1 = ComponentArray(ps_new_1 |> LuxCPUDevice()) |> device - - ps_3 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) - - @test ps_3.d2.l2.weight == ps_new_ca_1.weight == ps_3.d1.weight - @test ps_3.d2.l2.bias == ps_new_ca_1.bias == ps_3.d1.bias - @test ps_3.d3.weight == ps_new_2.weight == ps_3.d2.l1.weight - @test ps_3.d3.bias == ps_new_2.bias == ps_3.d2.l1.bias - - # Input Checks - non_disjoint_sharing = (("d2.l2", "d1"), ("d1", "d2.l1")) - @test_throws ArgumentError Lux.Experimental.share_parameters(ps, non_disjoint_sharing) - @test_throws ArgumentError Lux.Experimental.share_parameters(ps, sharing, (ps_new_1,)) - - # Parameter Structure Mismatch - ps_new_1 = (; 
weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4, 1)) |> - device - ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> - device - - @test_throws ArgumentError Lux.Experimental.share_parameters(ps, sharing, - (ps_new_1, ps_new_2)) - - ps_new_ca_1 = ComponentArray(ps_new_1 |> LuxCPUDevice()) |> device - - @test_throws ArgumentError Lux.Experimental.share_parameters(ps, sharing, - (ps_new_ca_1, ps_new_2)) -end diff --git a/test/contrib/share_parameters_tests.jl b/test/contrib/share_parameters_tests.jl new file mode 100644 index 0000000000..0a171cecea --- /dev/null +++ b/test/contrib/share_parameters_tests.jl @@ -0,0 +1,64 @@ +@testitem "Parameter Sharing" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + model = Chain(; + d1=Dense(2 => 4, tanh), + d2=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)), + d3=Dense(4 => 2)) + + ps, st = Lux.setup(rng, model) .|> device + + sharing = (("d2.l2", "d1"), ("d3", "d2.l1")) + + ps_1 = Lux.Experimental.share_parameters(ps, sharing) + + @test ps_1.d2.l2.weight == ps_1.d1.weight + @test ps_1.d2.l2.bias == ps_1.d1.bias + @test ps_1.d3.weight == ps_1.d2.l1.weight + @test ps_1.d3.bias == ps_1.d2.l1.bias + + ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4, 1)) |> + device + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> + device + + ps_2 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) + + @test ps_2.d2.l2.weight == ps_new_1.weight == ps_2.d1.weight + @test ps_2.d2.l2.bias == ps_new_1.bias == ps_2.d1.bias + @test ps_2.d3.weight == ps_new_2.weight == ps_2.d2.l1.weight + @test ps_2.d3.bias == ps_new_2.bias == ps_2.d2.l1.bias + + # Mix in ComponentArray + ps_new_ca_1 = ComponentArray(ps_new_1 |> LuxCPUDevice()) |> device + + ps_3 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) + + @test ps_3.d2.l2.weight == ps_new_ca_1.weight == ps_3.d1.weight + @test ps_3.d2.l2.bias == ps_new_ca_1.bias == ps_3.d1.bias + @test ps_3.d3.weight == ps_new_2.weight == ps_3.d2.l1.weight + @test ps_3.d3.bias == ps_new_2.bias == ps_3.d2.l1.bias + + # Input Checks + non_disjoint_sharing = (("d2.l2", "d1"), ("d1", "d2.l1")) + @test_throws ArgumentError Lux.Experimental.share_parameters(ps, + non_disjoint_sharing) + @test_throws ArgumentError Lux.Experimental.share_parameters(ps, sharing, + (ps_new_1,)) + + # Parameter Structure Mismatch + ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4, 1)) |> + device + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> + device + + @test_throws ArgumentError Lux.Experimental.share_parameters(ps, sharing, + (ps_new_1, ps_new_2)) + + ps_new_ca_1 = ComponentArray(ps_new_1 |> LuxCPUDevice()) |> device + + @test_throws ArgumentError Lux.Experimental.share_parameters(ps, sharing, + (ps_new_ca_1, ps_new_2)) + end +end diff --git a/test/contrib/training.jl b/test/contrib/training.jl deleted file mode 100644 index ed8389da7d..0000000000 --- a/test/contrib/training.jl +++ /dev/null @@ -1,56 +0,0 @@ -using ADTypes, Lux, Optimisers, Random, Test - -include("../test_utils.jl") - -function _loss_function(model, ps, st, data) - y, st = model(data, ps, st) - return sum(y), st, () -end - -@testset "$mode: TrainState" for (mode, aType, device, ongpu) in MODES - rng = get_stable_rng(12345) - - model = Dense(3, 2) - opt = Adam(0.01f0) - - tstate = Lux.Experimental.TrainState(Lux.replicate(rng), 
model, opt; - transform_variables=device) - - x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType - - ps, st = Lux.setup(Lux.replicate(rng), model) .|> device - opt_st = Optimisers.setup(opt, tstate.parameters) - - @test check_approx(tstate.model, model) - @test check_approx(tstate.parameters, ps) - @test check_approx(tstate.states, st) - @test check_approx(tstate.optimizer_state, opt_st) - @test tstate.step == 0 -end - -@testset "$mode: AbstractADTypes" for (mode, aType, device, ongpu) in MODES - rng = get_stable_rng(12345) - - model = Dense(3, 2) - opt = Adam(0.01f0) - - tstate = Lux.Experimental.TrainState(Lux.replicate(rng), model, opt; - transform_variables=device) - - x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType - - @testset "NotImplemented $(string(ad))" for ad in (AutoEnzyme(), AutoReverseDiff()) - @test_throws ArgumentError Lux.Experimental.compute_gradients( - ad, _loss_function, x, - tstate) - end - - for ad in (AutoZygote(), AutoTracker()) - grads, _, _, _ = @test_nowarn Lux.Experimental.compute_gradients( - ad, _loss_function, - x, tstate) - tstate_ = @test_nowarn Lux.Experimental.apply_gradients(tstate, grads) - @test tstate_.step == 1 - @test tstate != tstate_ - end -end diff --git a/test/contrib/training_tests.jl b/test/contrib/training_tests.jl new file mode 100644 index 0000000000..aaaab1c487 --- /dev/null +++ b/test/contrib/training_tests.jl @@ -0,0 +1,52 @@ +@testitem "TrainState" setup=[SharedTestSetup] begin + using Optimisers + + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + model = Dense(3, 2) + opt = Adam(0.01f0) + + tstate = Lux.Experimental.TrainState(Lux.replicate(rng), model, opt) + + x = randn(Lux.replicate(rng), Float32, (3, 1)) + + ps, st = Lux.setup(Lux.replicate(rng), model) + opt_st = Optimisers.setup(opt, tstate.parameters) + + @test check_approx(tstate.model, model) + @test check_approx(tstate.parameters, ps) + @test check_approx(tstate.states, st) + @test check_approx(tstate.optimizer_state, opt_st) + @test tstate.step == 0 + end +end + +@testitem "AbstractADTypes" setup=[SharedTestSetup] begin + using ADTypes, Optimisers + + function _loss_function(model, ps, st, data) + y, st = model(data, ps, st) + return sum(y), st, () + end + + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + model = Dense(3, 2) + opt = Adam(0.01f0) + + tstate = Lux.Experimental.TrainState(Lux.replicate(rng), model, opt; + transform_variables=device) + + x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType + + for ad in (AutoZygote(), AutoTracker()) + grads, _, _, _ = Lux.Experimental.compute_gradients(ad, _loss_function, + x, tstate) + tstate_ = Lux.Experimental.apply_gradients(tstate, grads) + @test tstate_.step == 1 + @test tstate != tstate_ + end + end +end diff --git a/test/core.jl b/test/core_tests.jl similarity index 75% rename from test/core.jl rename to test/core_tests.jl index 0243379589..79950763a5 100644 --- a/test/core.jl +++ b/test/core_tests.jl @@ -1,10 +1,8 @@ -using Functors, Lux, Test +@testitem "Functors Compatibility" setup=[SharedTestSetup] begin + using Functors -include("test_utils.jl") + rng = get_stable_rng(12345) -rng = get_stable_rng(12345) - -@testset "Functors Compatibility" begin c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), dense_2=Dense(3 => 5)), dense_3=Dense(5 => 1)) diff --git a/test/ext/LuxComponentArraysExt.jl b/test/ext/LuxComponentArraysExt.jl deleted file mode 100644 index a482a716d0..0000000000 --- 
a/test/ext/LuxComponentArraysExt.jl +++ /dev/null @@ -1,14 +0,0 @@ -using ComponentArrays, Lux, Test, Zygote - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "LuxComponentArraysExt" begin - # Ref: https://github.com/avik-pal/Lux.jl/issues/243 - nn = Chain(Dense(4, 3), Dense(3, 2)) - ps, st = Lux.setup(rng, nn) - - l2reg(p) = sum(abs2, ComponentArray(p)) - @test_nowarn gradient(l2reg, ps) -end diff --git a/test/ext/LuxFluxTransformExt.jl b/test/ext/LuxFluxTransformExt.jl deleted file mode 100644 index 94bb8bd15b..0000000000 --- a/test/ext/LuxFluxTransformExt.jl +++ /dev/null @@ -1,451 +0,0 @@ -import Flux -using Lux, Test - -fdevice(::Lux.LuxCPUDevice) = Flux.cpu -fdevice(::Lux.LuxCUDADevice) = Base.Fix1(Flux.gpu, Flux.FluxCUDAAdaptor()) -fdevice(::Lux.LuxAMDGPUDevice) = Base.Fix1(Flux.gpu, Flux.FluxAMDAdaptor()) - -include("../test_utils.jl") - -@testset "$mode: LuxFluxTransformExt" for (mode, aType, device, ongpu) in MODES - @testset "Containers" begin - @testset "Chain" begin - model = Flux.Chain(Flux.Dense(2 => 5), Flux.Dense(5 => 1)) |> fdevice(device) - x = rand(Float32, 2, 1) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1]) == (1, 1) - end - - @testset "Maxout" begin - model = Flux.Maxout(() -> Flux.Dense(2 => 5), 4) |> fdevice(device) - x = rand(Float32, 2, 1) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1]) == (5, 1) - end - - @testset "Skip Connection" begin - model = Flux.SkipConnection(Flux.Dense(2 => 2), +) |> fdevice(device) - x = rand(Float32, 2, 1) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1]) == (2, 1) - end - - @testset "Parallel" begin - model = Flux.Parallel(+, Flux.Dense(2 => 2), Flux.Dense(2 => 2)) |> - fdevice(device) - x = rand(Float32, 2, 1) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1]) == (2, 1) - end - - @testset "Pairwise Fusion" begin - model = Flux.PairwiseFusion(+, Flux.Dense(2 => 2), Flux.Dense(2 => 2)) |> - fdevice(device) - x = (rand(Float32, 2, 1), rand(Float32, 2, 1)) .|> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test all(model(x) .≈ model_lux(x, ps, st)[1]) - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test all(size.(model_lux(x, ps, st)[1]) .== ((2, 1),)) - end - end - - @testset "Linear" begin - @testset "Dense" begin - for model in [ - Flux.Dense(2 => 4) |> fdevice(device), - Flux.Dense(2 => 4; bias=false) |> fdevice(device) - ] - x = 
randn(Float32, 2, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1]) == size(model(x)) - end - end - - @testset "Scale" begin - for model in [ - Flux.Scale(2) |> fdevice(device), - Flux.Scale(2; bias=false) |> fdevice(device) - ] - x = randn(Float32, 2, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1]) == size(model(x)) - end - end - - @testset "Bilinear" begin - for model in [ - Flux.Bilinear((2, 3) => 5) |> fdevice(device), - Flux.Bilinear((2, 3) => 5; bias=false) |> fdevice(device) - ] - x = randn(Float32, 2, 4) |> aType - y = randn(Float32, 3, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x, y) ≈ model_lux((x, y), ps, st)[1] - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux((x, y), ps, st)[1]) == size(model(x, y)) - end - end - - @testset "Embedding" begin - model = Flux.Embedding(16 => 4) |> fdevice(device) - x = rand(1:16, 2, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1]) == (4, 2, 4) - end - end - - @testset "Convolutions" begin - @testset "Conv" begin - model = Flux.Conv((3, 3), 1 => 2) |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model = Flux.Conv((3, 3), 1 => 2; pad=Flux.SamePad()) |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "CrossCor" begin - model = Flux.CrossCor((3, 3), 1 => 2) |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model = Flux.CrossCor((3, 3), 1 => 2; pad=Flux.SamePad()) |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "ConvTranspose" begin - model = Flux.ConvTranspose((3, 3), 1 => 2) |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model = Flux.ConvTranspose((3, 3), 1 => 2; pad=Flux.SamePad()) |> - fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; 
preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - end - - @testset "Pooling" begin - @testset "AdaptiveMaxPooling" begin - model = Flux.AdaptiveMaxPool((2, 2)) |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "AdaptiveMeanPooling" begin - model = Flux.AdaptiveMeanPool((2, 2)) |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "MaxPooling" begin - model = Flux.MaxPool((2, 2)) |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "MeanPooling" begin - model = Flux.MeanPool((2, 2)) |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "GlobalMaxPooling" begin - model = Flux.GlobalMaxPool() |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "GlobalMeanPooling" begin - model = Flux.GlobalMeanPool() |> fdevice(device) - x = rand(Float32, 6, 6, 1, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - end - - @testset "Upsampling" begin - @testset "Upsample" begin - model = Flux.Upsample(5) |> fdevice(device) - x = rand(Float32, 2, 2, 2, 1) |> aType - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1]) == (10, 10, 2, 1) - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "PixelShuffle" begin - model = Flux.PixelShuffle(2) |> fdevice(device) - x = randn(Float32, 2, 2, 4, 1) |> aType - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1]) == (4, 4, 1, 1) - @test model(x) ≈ model_lux(x, ps, st)[1] - end - end - - @testset "Recurrent" begin - # @test_throws Lux.FluxModelConversionError transform(Flux.RNN(2 => 2)) - - @testset "RNNCell" begin - model = Flux.RNNCell(2 => 3) |> fdevice(device) - x = rand(Float32, 2, 4) |> aType - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) - end - - @testset "LSTMCell" begin - model = Flux.LSTMCell(2 => 3) |> fdevice(device) - x = rand(Float32, 2, 4) |> aType - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) - end - - @testset "GRUCell" begin - model = Flux.GRUCell(2 => 3) |> fdevice(device) - x = rand(Float32, 2, 4) |> aType - - model_lux = transform(model) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test 
size(model_lux(x, ps, st)[1][1]) == (3, 4) - end - end - - @testset "Normalize" begin - @testset "BatchNorm" begin - model = Flux.BatchNorm(2) |> fdevice(device) - x = randn(Float32, 2, 4) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - st = Lux.testmode(st) - - @test model(x) ≈ model_lux(x, ps, st)[1] - - x = randn(Float32, 2, 2, 2, 1) |> aType - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model_lux = transform(model; preserve_ps_st=true, force_preserve=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - st = Lux.testmode(st) - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "GroupNorm" begin - model = Flux.GroupNorm(4, 2) |> fdevice(device) - x = randn(Float32, 2, 2, 4, 1) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - st = Lux.testmode(st) - - @test model(x) ≈ model_lux(x, ps, st)[1] - - model_lux = transform(model; preserve_ps_st=true, force_preserve=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - st = Lux.testmode(st) - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "LayerNorm" begin - model = Flux.LayerNorm(4) |> fdevice(device) - x = randn(Float32, 4, 4, 4, 1) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - st = Lux.testmode(st) - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - - @testset "InstanceNorm" begin - model = Flux.InstanceNorm(4) |> fdevice(device) - x = randn(Float32, 4, 4, 4, 1) |> aType - - model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device - - @test model(x) ≈ model_lux(x, ps, st)[1] - end - end - - @testset "Dropout" begin - @testset "Dropout" begin - model = transform(Flux.Dropout(0.5f0)) - - x = randn(Float32, 2, 4) |> aType - ps, st = Lux.setup(get_stable_rng(12345), model) .|> device - - @test size(model(x, ps, st)[1]) == size(x) - - x = randn(Float32, 2, 3, 4) |> aType - ps, st = Lux.setup(get_stable_rng(12345), model) .|> device - - @test size(model(x, ps, st)[1]) == size(x) - end - - @testset "AlphaDropout" begin - model = transform(Flux.AlphaDropout(0.5)) - - x = randn(Float32, 2, 4) |> aType - ps, st = Lux.setup(get_stable_rng(12345), model) .|> device - - @test size(model(x, ps, st)[1]) == size(x) - - x = randn(Float32, 2, 4, 3) |> aType - ps, st = Lux.setup(get_stable_rng(12345), model) .|> device - - @test size(model(x, ps, st)[1]) == size(x) - end - end - - @testset "Custom Layer" begin - struct CustomFluxLayer - weight - bias - end - - Flux.@functor CustomFluxLayer - - (c::CustomFluxLayer)(x) = c.weight .* x .+ c.bias - - c = CustomFluxLayer(randn(10), randn(10)) |> fdevice(device) - x = randn(10) |> aType - - c_lux = transform(c) - ps, st = Lux.setup(get_stable_rng(12345), c_lux) .|> device - - @test c(x) ≈ c_lux(x, ps, st)[1] - end -end diff --git a/test/ext/flux_tests.jl b/test/ext/flux_tests.jl new file mode 100644 index 0000000000..df2167b8f9 --- /dev/null +++ b/test/ext/flux_tests.jl @@ -0,0 +1,451 @@ +@testitem "Lux.transform tests" setup=[SharedTestSetup] begin + import Flux + + fdevice(::Lux.LuxCPUDevice) = Flux.cpu + fdevice(::Lux.LuxCUDADevice) = Base.Fix1(Flux.gpu, Flux.FluxCUDAAdaptor()) + fdevice(::Lux.LuxAMDGPUDevice) = Base.Fix1(Flux.gpu, Flux.FluxAMDAdaptor()) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "Containers" 
begin + @testset "Chain" begin + model = Flux.Chain(Flux.Dense(2 => 5), Flux.Dense(5 => 1)) |> + fdevice(device) + x = rand(Float32, 2, 1) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1]) == (1, 1) + end + + @testset "Maxout" begin + model = Flux.Maxout(() -> Flux.Dense(2 => 5), 4) |> fdevice(device) + x = rand(Float32, 2, 1) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1]) == (5, 1) + end + + @testset "Skip Connection" begin + model = Flux.SkipConnection(Flux.Dense(2 => 2), +) |> fdevice(device) + x = rand(Float32, 2, 1) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1]) == (2, 1) + end + + @testset "Parallel" begin + model = Flux.Parallel(+, Flux.Dense(2 => 2), Flux.Dense(2 => 2)) |> + fdevice(device) + x = rand(Float32, 2, 1) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1]) == (2, 1) + end + + @testset "Pairwise Fusion" begin + model = Flux.PairwiseFusion(+, Flux.Dense(2 => 2), Flux.Dense(2 => 2)) |> + fdevice(device) + x = (rand(Float32, 2, 1), rand(Float32, 2, 1)) .|> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test all(model(x) .≈ model_lux(x, ps, st)[1]) + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test all(size.(model_lux(x, ps, st)[1]) .== ((2, 1),)) + end + end + + @testset "Linear" begin + @testset "Dense" begin + for model in [ + Flux.Dense(2 => 4) |> fdevice(device), + Flux.Dense(2 => 4; bias=false) |> fdevice(device) + ] + x = randn(Float32, 2, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1]) == size(model(x)) + end + end + + @testset "Scale" begin + for model in [ + Flux.Scale(2) |> fdevice(device), + Flux.Scale(2; bias=false) |> fdevice(device) + ] + x = randn(Float32, 2, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1]) == size(model(x)) + end + end + + @testset "Bilinear" begin + for model in [ + Flux.Bilinear((2, 3) => 5) |> fdevice(device), + Flux.Bilinear((2, 3) => 5; 
bias=false) |> fdevice(device) + ] + x = randn(Float32, 2, 4) |> aType + y = randn(Float32, 3, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x, y) ≈ model_lux((x, y), ps, st)[1] + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux((x, y), ps, st)[1]) == size(model(x, y)) + end + end + + @testset "Embedding" begin + model = Flux.Embedding(16 => 4) |> fdevice(device) + x = rand(1:16, 2, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1]) == (4, 2, 4) + end + end + + @testset "Convolutions" begin + @testset "Conv" begin + model = Flux.Conv((3, 3), 1 => 2) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model = Flux.Conv((3, 3), 1 => 2; pad=Flux.SamePad()) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "CrossCor" begin + model = Flux.CrossCor((3, 3), 1 => 2) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model = Flux.CrossCor((3, 3), 1 => 2; pad=Flux.SamePad()) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "ConvTranspose" begin + model = Flux.ConvTranspose((3, 3), 1 => 2) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model = Flux.ConvTranspose((3, 3), 1 => 2; pad=Flux.SamePad()) |> + fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + end + + @testset "Pooling" begin + @testset "AdaptiveMaxPooling" begin + model = Flux.AdaptiveMaxPool((2, 2)) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "AdaptiveMeanPooling" begin + model = Flux.AdaptiveMeanPool((2, 2)) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "MaxPooling" begin + model = Flux.MaxPool((2, 2)) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = 
Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "MeanPooling" begin + model = Flux.MeanPool((2, 2)) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "GlobalMaxPooling" begin + model = Flux.GlobalMaxPool() |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "GlobalMeanPooling" begin + model = Flux.GlobalMeanPool() |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + end + + @testset "Upsampling" begin + @testset "Upsample" begin + model = Flux.Upsample(5) |> fdevice(device) + x = rand(Float32, 2, 2, 2, 1) |> aType + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1]) == (10, 10, 2, 1) + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "PixelShuffle" begin + model = Flux.PixelShuffle(2) |> fdevice(device) + x = randn(Float32, 2, 2, 4, 1) |> aType + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1]) == (4, 4, 1, 1) + @test model(x) ≈ model_lux(x, ps, st)[1] + end + end + + @testset "Recurrent" begin + # @test_throws Lux.FluxModelConversionError transform(Flux.RNN(2 => 2)) + + @testset "RNNCell" begin + model = Flux.RNNCell(2 => 3) |> fdevice(device) + x = rand(Float32, 2, 4) |> aType + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1][1]) == (3, 4) + end + + @testset "LSTMCell" begin + model = Flux.LSTMCell(2 => 3) |> fdevice(device) + x = rand(Float32, 2, 4) |> aType + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1][1]) == (3, 4) + end + + @testset "GRUCell" begin + model = Flux.GRUCell(2 => 3) |> fdevice(device) + x = rand(Float32, 2, 4) |> aType + + model_lux = transform(model) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test size(model_lux(x, ps, st)[1][1]) == (3, 4) + end + end + + @testset "Normalize" begin + @testset "BatchNorm" begin + model = Flux.BatchNorm(2) |> fdevice(device) + x = randn(Float32, 2, 4) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + st = Lux.testmode(st) + + @test model(x) ≈ model_lux(x, ps, st)[1] + + x = randn(Float32, 2, 2, 2, 1) |> aType + + @test model(x) ≈ model_lux(x, ps, st)[1] + + model_lux = transform(model; preserve_ps_st=true, force_preserve=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + st = Lux.testmode(st) + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "GroupNorm" begin + model = Flux.GroupNorm(4, 2) |> fdevice(device) + x = randn(Float32, 2, 2, 4, 1) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + st = Lux.testmode(st) + + @test model(x) ≈ 
model_lux(x, ps, st)[1] + + model_lux = transform(model; preserve_ps_st=true, force_preserve=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + st = Lux.testmode(st) + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "LayerNorm" begin + model = Flux.LayerNorm(4) |> fdevice(device) + x = randn(Float32, 4, 4, 4, 1) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + st = Lux.testmode(st) + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + + @testset "InstanceNorm" begin + model = Flux.InstanceNorm(4) |> fdevice(device) + x = randn(Float32, 4, 4, 4, 1) |> aType + + model_lux = transform(model; preserve_ps_st=true) + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device + + @test model(x) ≈ model_lux(x, ps, st)[1] + end + end + + @testset "Dropout" begin + @testset "Dropout" begin + model = transform(Flux.Dropout(0.5f0)) + + x = randn(Float32, 2, 4) |> aType + ps, st = Lux.setup(get_stable_rng(12345), model) .|> device + + @test size(model(x, ps, st)[1]) == size(x) + + x = randn(Float32, 2, 3, 4) |> aType + ps, st = Lux.setup(get_stable_rng(12345), model) .|> device + + @test size(model(x, ps, st)[1]) == size(x) + end + + @testset "AlphaDropout" begin + model = transform(Flux.AlphaDropout(0.5)) + + x = randn(Float32, 2, 4) |> aType + ps, st = Lux.setup(get_stable_rng(12345), model) .|> device + + @test size(model(x, ps, st)[1]) == size(x) + + x = randn(Float32, 2, 4, 3) |> aType + ps, st = Lux.setup(get_stable_rng(12345), model) .|> device + + @test size(model(x, ps, st)[1]) == size(x) + end + end + + @testset "Custom Layer" begin + struct CustomFluxLayer + weight + bias + end + + Flux.@functor CustomFluxLayer + + (c::CustomFluxLayer)(x) = c.weight .* x .+ c.bias + + c = CustomFluxLayer(randn(10), randn(10)) |> fdevice(device) + x = randn(10) |> aType + + c_lux = transform(c) + ps, st = Lux.setup(get_stable_rng(12345), c_lux) .|> device + + @test c(x) ≈ c_lux(x, ps, st)[1] + end + end +end diff --git a/test/layers/basic.jl b/test/layers/basic.jl deleted file mode 100644 index 12b5bab824..0000000000 --- a/test/layers/basic.jl +++ /dev/null @@ -1,351 +0,0 @@ -using Lux, Random, Test - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode: Miscellaneous Layers" for (mode, aType, device, ongpu) in MODES - @testset "Reshape Layer" begin - layer = ReshapeLayer((2, 3)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, 6, 3) |> aType - - @test size(layer(x, ps, st)[1]) == (2, 3, 3) - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 - end - - @testset "Flatten Layer" begin - layer = FlattenLayer() - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, 6, 3, 2) |> aType - - @test size(layer(x, ps, st)[1]) == (18, 2) - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 - end - - @testset "NoOpLayer" begin - layer = NoOpLayer() - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = (x=2, b=5) # Something totally arbitrary - - @test layer(x, ps, st)[1] == x - - @jet layer(x, ps, st) - - x = randn(rng, 6, 3) |> aType - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 - end - - @testset "SelectDim Layer" begin - layer = SelectDim(3, 1) - 
__display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, 6, 4, 3, 2) |> aType - - @test size(layer(x, ps, st)[1]) == (6, 4, 2) - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 - end - - @testset "WrappedFunction" begin - layer = WrappedFunction(x -> x .* x) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, 6, 4, 3, 2) |> aType - - @test layer(x, ps, st)[1] == x .* x - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 - end -end - -@testset "$mode: Dense" for (mode, aType, device, ongpu) in MODES - @testset "constructors" begin - layer = Dense(10, 100) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(ps.weight) == (100, 10) - @test size(ps.bias) == (100, 1) - @test layer.activation == identity - - layer = Dense(10, 100, relu; use_bias=false) - ps, st = Lux.setup(rng, layer) .|> device - - @test !haskey(ps, :bias) - @test layer.activation == relu - end - - @testset "allow fast activation" begin - layer = Dense(10, 10, tanh) - @test layer.activation == tanh_fast - layer = Dense(10, 10, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end - - @testset "dimensions" begin - layer = Dense(10, 5) - ps, st = Lux.setup(rng, layer) - - @test size(first(Lux.apply(layer, randn(10), ps, st))) == (5,) - @test size(first(Lux.apply(layer, randn(10, 2), ps, st))) == (5, 2) - end - - @testset "zeros" begin - @test begin - layer = Dense(10, - 1, - identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) - first(Lux.apply(layer, - ones(10, 1) |> aType, - device.(Lux.setup(rng, layer))...)) - end == 10 * aType(ones(1, 1)) - - @test begin - layer = Dense(10, - 1, - identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) - first(Lux.apply(layer, - ones(10, 2) |> aType, - device.(Lux.setup(rng, layer))...)) - end == 10 * aType(ones(1, 2)) - - @test begin - layer = Dense(10, - 2, - identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) - first(Lux.apply(layer, - ones(10, 1) |> aType, - device.(Lux.setup(rng, layer))...)) - end == 10 * aType(ones(2, 1)) - - @test begin - layer = Dense(10, - 2, - identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) - first(Lux.apply(layer, - aType([ones(10, 1) 2 * ones(10, 1)]), - device.(Lux.setup(rng, layer))...)) - end == aType([10 20; 10 20]) - - @test begin - layer = Dense(10, - 2, - identity; - init_weight=(rng, args...; kwargs...) 
-> ones(args...; kwargs...), - use_bias=false) - first(Lux.apply(layer, - aType([ones(10, 1) 2 * ones(10, 1)]), - device.(Lux.setup(rng, layer))...)) - end == aType([10 20; 10 20]) - end -end - -@testset "$mode: Scale" for (mode, aType, device, ongpu) in MODES - @testset "constructors" begin - layer = Scale(10, 100) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(ps.weight) == (10, 100) - @test size(ps.bias) == (10, 100) - @test layer.activation == identity - - layer = Scale(10, 100, relu; use_bias=false) - ps, st = Lux.setup(rng, layer) .|> device - - @test !haskey(ps, :bias) - @test layer.activation == relu - end - - @testset "allow fast activation" begin - layer = Scale(10, 5, tanh) - @test layer.activation == tanh_fast - layer = Scale(10, 5, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end - - @testset "dimensions" begin - layer = Scale(10, 5) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(first(Lux.apply(layer, randn(10) |> aType, ps, st))) == (10, 5) - @test size(first(Lux.apply(layer, randn(10, 5, 2) |> aType, ps, st))) == (10, 5, 2) - end - - @testset "zeros" begin - @test begin - layer = Scale(10, - 1, - identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) - first(Lux.apply(layer, - ones(10, 1) |> aType, - device.(Lux.setup(rng, layer))...)) - end == aType(ones(10, 1)) - - @test begin - layer = Scale(10, - 1, - identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) - first(Lux.apply(layer, - ones(10, 2) |> aType, - device.(Lux.setup(rng, layer))...)) - end == aType(ones(10, 2)) - - @test begin - layer = Scale(2, - identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...), - init_bias=(rng, args...; kwargs...) -> ones(args...; kwargs...)) - first(Lux.apply(layer, - [1 2; 3 4] |> aType, - device.(Lux.setup(rng, layer))...)) - end == aType([2.0 3.0; 4.0 5.0]) - - @test begin - layer = Scale(2, - tanh; - use_bias=false, - init_weight=(rng, args...; kwargs...) 
-> zeros(args...; kwargs...)) - first(Lux.apply(layer, - [1 2; 3 4] |> aType, - device.(Lux.setup(rng, layer))...)) - end == aType(zeros(2, 2)) - end -end - -@testset "$mode: Bilinear" for (mode, aType, device, ongpu) in MODES - @testset "SkipConnection recombinator" begin - d = Dense(2 => 2) - __display(d) - b = Bilinear((2, 2) => 3) - __display(b) - layer = SkipConnection(d, b) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, Float32, 2, 1) |> aType - - @test size(layer(x, ps, st)[1]) == (3, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - d = Dense(2 => 2) - __display(d) - b = Bilinear((2, 2) => 3; use_bias=false) - __display(b) - layer = SkipConnection(d, b) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, Float32, 2, 1) |> aType - - @test size(layer(x, ps, st)[1]) == (3, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - d = Dense(2 => 3) - __display(d) - b = Bilinear((3, 2) => 5) - __display(b) - layer = SkipConnection(d, b) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, Float32, 2, 7, 11) |> aType - - @test size(layer(x, ps, st)[1]) == (5, 7, 11) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "Two-streams zero sum" begin - x = zeros(Float32, 2, 1) |> aType - y = zeros(Float32, 1, 1) |> aType - layer = Bilinear((2, 1) => 3) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(layer((x, y), ps, st)[1]) == (3, 1) - @test sum(abs2, layer((x, y), ps, st)[1]) == 0.0f0 - - @jet layer((x, y), ps, st) - __f = (x, y, ps) -> sum(first(layer((x, y), ps, st))) - @eval @test_gradients $__f $x $y $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "Inner interactions" begin - x = randn(Float32, 2, 1) |> aType - layer = Bilinear((2, 2) => 3) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(layer(x, ps, st)[1]) == (3, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - x = randn(Float32, 2, 1) |> aType - layer = Bilinear(2 => 3) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(layer(x, ps, st)[1]) == (3, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end -end - -@testset "$mode: Embedding" for (mode, aType, device, ongpu) in MODES - vocab_size, embed_size = 10, 4 - layer = Embedding(vocab_size => embed_size) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(ps.weight) == (embed_size, vocab_size) - - x = rand(1:vocab_size, 1)[1] - y, st_ = layer(x, ps, st) - @test size(layer(x, ps, st)[1]) == (embed_size,) - @test y == ps.weight[:, x] - - @jet layer(x, ps, st) - - x = rand(1:vocab_size, 3) |> aType - y, st_ = layer(x, ps, st) - @test y isa aType{Float32} - @test y == ps.weight[:, x] - - @jet layer(x, ps, st) - - x = rand(1:vocab_size, 3, 4) |> aType - y, st_ = layer(x, ps, st) - @test y isa aType{Float32, 3} - @test size(y) == (embed_size, 3, 4) - - @jet layer(x, ps, st) -end diff --git a/test/layers/basic_tests.jl 
b/test/layers/basic_tests.jl new file mode 100644 index 0000000000..7b936dd48d --- /dev/null +++ b/test/layers/basic_tests.jl @@ -0,0 +1,347 @@ +@testitem "Miscellaneous Layers" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "Reshape Layer" begin + layer = ReshapeLayer((2, 3)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 6, 3) |> aType + + @test size(layer(x, ps, st)[1]) == (2, 3, 3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 + end + + @testset "Flatten Layer" begin + layer = FlattenLayer() + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 6, 3, 2) |> aType + + @test size(layer(x, ps, st)[1]) == (18, 2) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 + end + + @testset "NoOpLayer" begin + layer = NoOpLayer() + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = (x=2, b=5) # Something totally arbitrary + + @test layer(x, ps, st)[1] == x + + @jet layer(x, ps, st) + + x = randn(rng, 6, 3) |> aType + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 + end + + @testset "SelectDim Layer" begin + layer = SelectDim(3, 1) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 6, 4, 3, 2) |> aType + + @test size(layer(x, ps, st)[1]) == (6, 4, 2) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 + end + + @testset "WrappedFunction" begin + layer = WrappedFunction(x -> x .* x) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 6, 4, 3, 2) |> aType + + @test layer(x, ps, st)[1] == x .* x + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 + end + end +end + +@testitem "Dense" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "constructors" begin + layer = Dense(10, 100) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(ps.weight) == (100, 10) + @test size(ps.bias) == (100, 1) + @test layer.activation == identity + + layer = Dense(10, 100, relu; use_bias=false) + ps, st = Lux.setup(rng, layer) .|> device + + @test !haskey(ps, :bias) + @test layer.activation == relu + end + + @testset "allow fast activation" begin + layer = Dense(10, 10, tanh) + @test layer.activation == tanh_fast + layer = Dense(10, 10, tanh; allow_fast_activation=false) + @test layer.activation == tanh + end + + @testset "dimensions" begin + layer = Dense(10, 5) + ps, st = Lux.setup(rng, layer) + + @test size(first(Lux.apply(layer, randn(10), ps, st))) == (5,) + @test size(first(Lux.apply(layer, randn(10, 2), ps, st))) == (5, 2) + end + + @testset "zeros" begin + @test begin + layer = Dense(10, 1, identity; + init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + first(Lux.apply(layer, + ones(10, 1) |> aType, + device.(Lux.setup(rng, layer))...)) + end == 10 * aType(ones(1, 1)) + + @test begin + layer = Dense(10, 1, identity; + init_weight=(rng, args...; kwargs...) 
-> ones(args...; kwargs...)) + first(Lux.apply(layer, + ones(10, 2) |> aType, + device.(Lux.setup(rng, layer))...)) + end == 10 * aType(ones(1, 2)) + + @test begin + layer = Dense(10, 2, identity; + init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + first(Lux.apply(layer, + ones(10, 1) |> aType, + device.(Lux.setup(rng, layer))...)) + end == 10 * aType(ones(2, 1)) + + @test begin + layer = Dense(10, 2, identity; + init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + first(Lux.apply(layer, + aType([ones(10, 1) 2 * ones(10, 1)]), + device.(Lux.setup(rng, layer))...)) + end == aType([10 20; 10 20]) + + @test begin + layer = Dense(10, 2, identity; + init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...), + use_bias=false) + first(Lux.apply(layer, + aType([ones(10, 1) 2 * ones(10, 1)]), + device.(Lux.setup(rng, layer))...)) + end == aType([10 20; 10 20]) + end + end +end + +@testitem "Scale" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "constructors" begin + layer = Scale(10, 100) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(ps.weight) == (10, 100) + @test size(ps.bias) == (10, 100) + @test layer.activation == identity + + layer = Scale(10, 100, relu; use_bias=false) + ps, st = Lux.setup(rng, layer) .|> device + + @test !haskey(ps, :bias) + @test layer.activation == relu + end + + @testset "allow fast activation" begin + layer = Scale(10, 5, tanh) + @test layer.activation == tanh_fast + layer = Scale(10, 5, tanh; allow_fast_activation=false) + @test layer.activation == tanh + end + + @testset "dimensions" begin + layer = Scale(10, 5) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(first(Lux.apply(layer, randn(10) |> aType, ps, st))) == (10, 5) + @test size(first(Lux.apply(layer, randn(10, 5, 2) |> aType, ps, st))) == + (10, 5, 2) + end + + @testset "zeros" begin + @test begin + layer = Scale(10, 1, identity; + init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + first(Lux.apply(layer, + ones(10, 1) |> aType, + device.(Lux.setup(rng, layer))...)) + end == aType(ones(10, 1)) + + @test begin + layer = Scale(10, 1, identity; + init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + first(Lux.apply(layer, ones(10, 2) |> aType, + device.(Lux.setup(rng, layer))...)) + end == aType(ones(10, 2)) + + @test begin + layer = Scale(2, identity; + init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...), + init_bias=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + first(Lux.apply(layer, + [1 2; 3 4] |> aType, + device.(Lux.setup(rng, layer))...)) + end == aType([2.0 3.0; 4.0 5.0]) + + @test begin + layer = Scale(2, tanh; use_bias=false, + init_weight=(rng, args...; kwargs...) 
-> zeros(args...; kwargs...)) + first(Lux.apply(layer, [1 2; 3 4] |> aType, + device.(Lux.setup(rng, layer))...)) + end == aType(zeros(2, 2)) + end + end +end + +@testitem "Bilinear" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "SkipConnection recombinator" begin + d = Dense(2 => 2) + __display(d) + b = Bilinear((2, 2) => 3) + __display(b) + layer = SkipConnection(d, b) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, Float32, 2, 1) |> aType + + @test size(layer(x, ps, st)[1]) == (3, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + d = Dense(2 => 2) + __display(d) + b = Bilinear((2, 2) => 3; use_bias=false) + __display(b) + layer = SkipConnection(d, b) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, Float32, 2, 1) |> aType + + @test size(layer(x, ps, st)[1]) == (3, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + d = Dense(2 => 3) + __display(d) + b = Bilinear((3, 2) => 5) + __display(b) + layer = SkipConnection(d, b) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, Float32, 2, 7, 11) |> aType + + @test size(layer(x, ps, st)[1]) == (5, 7, 11) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + @testset "Two-streams zero sum" begin + x = zeros(Float32, 2, 1) |> aType + y = zeros(Float32, 1, 1) |> aType + layer = Bilinear((2, 1) => 3) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(layer((x, y), ps, st)[1]) == (3, 1) + @test sum(abs2, layer((x, y), ps, st)[1]) == 0.0f0 + + @jet layer((x, y), ps, st) + __f = (x, y, ps) -> sum(first(layer((x, y), ps, st))) + @eval @test_gradients $__f $x $y $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + @testset "Inner interactions" begin + x = randn(Float32, 2, 1) |> aType + layer = Bilinear((2, 2) => 3) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(layer(x, ps, st)[1]) == (3, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + x = randn(Float32, 2, 1) |> aType + layer = Bilinear(2 => 3) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(layer(x, ps, st)[1]) == (3, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + end +end + +@testitem "Embedding" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + vocab_size, embed_size = 10, 4 + layer = Embedding(vocab_size => embed_size) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(ps.weight) == (embed_size, vocab_size) + + x = rand(1:vocab_size, 1)[1] + y, st_ = layer(x, ps, st) + @test size(layer(x, ps, st)[1]) == (embed_size,) + @test y == ps.weight[:, x] + + @jet layer(x, ps, st) + + x = rand(1:vocab_size, 3) |> aType + y, st_ = layer(x, ps, st) + @test y isa aType{Float32} + @test y == ps.weight[:, x] + + @jet layer(x, ps, st) + + x = rand(1:vocab_size, 3, 4) |> aType + y, st_ = layer(x, ps, 
st) + @test y isa aType{Float32, 3} + @test size(y) == (embed_size, 3, 4) + + @jet layer(x, ps, st) + end +end diff --git a/test/layers/containers.jl b/test/layers/containers.jl deleted file mode 100644 index fa42b0e2bc..0000000000 --- a/test/layers/containers.jl +++ /dev/null @@ -1,373 +0,0 @@ -using Lux, Random, Test, Zygote - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode: SkipConnection" for (mode, aType, device, ongpu) in MODES - @testset "zero sum" begin - layer = SkipConnection(WrappedFunction(zero), (a, b) -> a .+ b) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, 10, 10, 10, 10) |> aType - - @test layer(x, ps, st)[1] == x - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 reverse_diff_broken=true gpu_testing=$ongpu - end - - @testset "concat size" begin - layer = SkipConnection(Dense(10, 10), (a, b) -> hcat(a, b)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, 10, 2) |> aType - - @test size(layer(x, ps, st)[1]) == (10, 4) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end -end - -@testset "$mode: Parallel" for (mode, aType, device, ongpu) in MODES - @testset "zero sum" begin - layer = Parallel(+, WrappedFunction(zero), NoOpLayer()) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, 10, 10, 10, 10) |> aType - - @test layer(x, ps, st)[1] == x - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 reverse_diff_broken=true gpu_testing=$ongpu - end - - @testset "concat size" begin - layer = Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), NoOpLayer()) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, 10, 2) |> aType - - @test size(layer(x, ps, st)[1]) == (10, 4) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - layer = Parallel(hcat, Dense(10, 10), NoOpLayer()) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(layer(x, ps, st)[1]) == (10, 4) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "vararg input" begin - layer = Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = (randn(rng, 10, 1), randn(rng, 5, 1), randn(rng, 4, 1)) .|> aType - - @test size(layer(x, ps, st)[1]) == (2, 1) - - @jet layer(x, ps, st) - __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) - @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "named layers" begin - layer = Parallel(+; d102=Dense(10, 2), d52=Dense(5, 2), d42=Dense(4, 2)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = (randn(rng, 10, 1), randn(rng, 5, 1), randn(rng, 4, 1)) .|> aType - - @test size(layer(x, ps, st)[1]) == (2, 1) - - @jet layer(x, ps, st) - __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) - @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "connection is called once" begin - CNT = Ref(0) - f_cnt = (x...) 
-> (CNT[] += 1; +(x...)) - layer = Parallel(f_cnt, - WrappedFunction(sin), - WrappedFunction(cos), - WrappedFunction(tan)) - ps, st = Lux.setup(rng, layer) .|> device - Lux.apply(layer, 1, ps, st) - @test CNT[] == 1 - @jet layer(1, ps, st) - Lux.apply(layer, (1, 2, 3), ps, st) - @test CNT[] == 2 - layer = Parallel(f_cnt, WrappedFunction(sin)) - Lux.apply(layer, 1, ps, st) - @test CNT[] == 3 - end - - # Ref https://github.com/FluxML/Flux.jl/issues/1673 - @testset "Input domain" begin - struct Input - x - end - - struct L1 <: Lux.AbstractExplicitLayer end - (::L1)(x, ps, st) = (ps.x * x, st) - Lux.initialparameters(rng::AbstractRNG, ::L1) = (x=randn(rng, Float32, 3, 3),) - Base.:*(a::AbstractArray, b::Input) = a * b.x - - par = Parallel(+, L1(), L1()) - ps, st = Lux.setup(rng, par) .|> device - - ip = Input(rand(Float32, 3, 3) |> aType) - ip2 = Input(rand(Float32, 3, 3) |> aType) - - @test check_approx(par(ip, ps, st)[1], - par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + - par.layers[2](ip.x, ps.layer_2, st.layer_2)[1]) - @test check_approx(par((ip, ip2), ps, st)[1], - par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + - par.layers[2](ip2.x, ps.layer_2, st.layer_2)[1]) - gs = Zygote.gradient((p, x...) -> sum(par(x, p, st)[1]), ps, ip, ip2) - gs_reg = Zygote.gradient(ps, ip, ip2) do p, x, y - return sum(par.layers[1](x.x, p.layer_1, st.layer_1)[1] + - par.layers[2](y.x, p.layer_2, st.layer_2)[1]) - end - - @test check_approx(gs[1], gs_reg[1]) - @test check_approx(gs[2].x, gs_reg[2].x) - @test check_approx(gs[3].x, gs_reg[3].x) - end -end - -@testset "$mode: PairwiseFusion" for (mode, aType, device, ongpu) in MODES - x = (rand(Float32, 1, 10), rand(Float32, 30, 10), rand(Float32, 10, 10)) .|> aType - layer = PairwiseFusion(+, Dense(1, 30), Dense(30, 10)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - y, _ = layer(x, ps, st) - @test size(y) == (10, 10) - - @jet layer(x, ps, st) - __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) - @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - layer = PairwiseFusion(+; d1=Dense(1, 30), d2=Dense(30, 10)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - y, _ = layer(x, ps, st) - @test size(y) == (10, 10) - @jet layer(x, ps, st) - __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) - @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - x = rand(1, 10) - layer = PairwiseFusion(.+, Dense(1, 10), Dense(10, 1)) - __display(layer) - ps, st = Lux.setup(rng, layer) - y, _ = layer(x, ps, st) - @test size(y) == (1, 10) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - layer = PairwiseFusion(vcat, - WrappedFunction(x -> x .+ 1), - WrappedFunction(x -> x .+ 2), - WrappedFunction(x -> x .^ 3)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - @test layer((2, 10, 20, 40), ps, st)[1] == [125, 1728, 8000, 40] - - layer = PairwiseFusion(vcat, - WrappedFunction(x -> x .+ 1), - WrappedFunction(x -> x .+ 2), - WrappedFunction(x -> x .^ 3)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - @test layer(7, ps, st)[1] == [1000, 729, 343, 7] -end - -@testset "$mode: BranchLayer" for (mode, aType, device, ongpu) in MODES - layer = BranchLayer(Dense(10, 10), Dense(10, 10)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = rand(Float32, 10, 1) |> aType - (y1, y2), _ = 
layer(x, ps, st) - @test size(y1) == (10, 1) - @test size(y2) == (10, 1) - @test y1 == layer.layers.layer_1(x, ps.layer_1, st.layer_1)[1] - @test y2 == layer.layers.layer_2(x, ps.layer_2, st.layer_2)[1] - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(sum, first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - layer = BranchLayer(; d1=Dense(10, 10), d2=Dense(10, 10)) - __display(layer) - ps, st = Lux.setup(rng, layer) - x = rand(Float32, 10, 1) - (y1, y2), _ = layer(x, ps, st) - @test size(y1) == (10, 1) - @test size(y2) == (10, 1) - @test y1 == layer.layers.d1(x, ps.d1, st.d1)[1] - @test y2 == layer.layers.d2(x, ps.d2, st.d2)[1] - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(sum, first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu -end - -@testset "$mode: Chain" for (mode, aType, device, ongpu) in MODES - layer = Chain(Dense(10 => 5, sigmoid), Dense(5 => 2, tanh), Dense(2 => 1)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = rand(Float32, 10, 1) |> aType - y, _ = layer(x, ps, st) - @test size(y) == (1, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - layer = Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = rand(Float32, 10, 1) |> aType - y, _ = layer(x, ps, st) - @test size(y) == (1, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - layer = Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) - __display(layer) - layer = layer[1:2] - ps, st = Lux.setup(rng, layer) .|> device - x = rand(Float32, 10, 1) |> aType - y, _ = layer(x, ps, st) - @test size(y) == (2, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - layer = Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) - __display(layer) - layer = layer[begin:(end - 1)] - ps, st = Lux.setup(rng, layer) .|> device - x = rand(Float32, 10, 1) |> aType - y, _ = layer(x, ps, st) - @test size(y) == (2, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - layer = Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) - __display(layer) - layer = layer[1] - ps, st = Lux.setup(rng, layer) .|> device - x = rand(Float32, 10, 1) |> aType - y, _ = layer(x, ps, st) - @test size(y) == (5, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - @test_throws ArgumentError Chain(; - l1=Dense(10 => 5, sigmoid), - d52=Dense(5 => 2, tanh), - d21=Dense(2 => 1), - d2=Dense(2 => 1), - disable_optimizations=false) -end - -@testset "$mode: Maxout" for (mode, aType, device, ongpu) in MODES - @testset "constructor" begin - layer = Maxout(() -> NoOpLayer(), 4) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = rand(rng, Float32, 10, 1) |> aType - - @test layer(x, ps, st)[1] == x - - @jet layer(x, ps, st) - end - - @testset "simple alternatives" begin - layer = Maxout(NoOpLayer(), WrappedFunction(x -> 2x), 
WrappedFunction(x -> 0.5x)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = Float32.(collect(1:40)) |> aType - - @test layer(x, ps, st)[1] == 2 .* x - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "complex alternatives" begin - A = aType([0.5 0.1]') - B = aType([0.2 0.7]') - layer = Maxout(WrappedFunction(x -> A * x), WrappedFunction(x -> B * x)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = [3.0 2.0] |> aType - y = aType([0.5, 0.7]) .* x - - @test layer(x, ps, st)[1] == y - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "params" begin - layer = Maxout(() -> Dense(2, 4), 4) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = [10.0f0 3.0f0]' |> aType - - @test Lux.parameterlength(layer) == sum(Lux.parameterlength.(values(layer.layers))) - @test size(layer(x, ps, st)[1]) == (4, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-1 rtol=1.0f-1 gpu_testing=$ongpu - end -end - -@testset "$mode: Repeated" for (mode, aType, device, ongpu) in MODES - LAYERS = [Dense(2 => 2), Parallel(+, Dense(2 => 2), Dense(2 => 2)), Dense(2 => 2), - Parallel(+, Dense(2 => 2), Dense(2 => 2))] - REPEATS = [Val(4), Val(4), Val(4), Val(4)] - INJECTION = [Val(false), Val(true), Val(false), Val(true)] - - @testset "repeats = $(repeats); input_injection = $(input_injection)" for (layer, repeats, input_injection) in zip( - LAYERS, - REPEATS, INJECTION) - layer = RepeatedLayer(layer; repeats, input_injection) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = rand(rng, Float32, 2, 12) |> aType - - @test size(layer(x, ps, st)[1]) == (2, 12) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end -end diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl new file mode 100644 index 0000000000..a2cdd3cd8e --- /dev/null +++ b/test/layers/containers_tests.jl @@ -0,0 +1,402 @@ + +@testitem "SkipConnection" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "zero sum" begin + layer = SkipConnection(WrappedFunction(zero), (a, b) -> a .+ b) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 10, 10, 10, 10) |> aType + + @test layer(x, ps, st)[1] == x + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 reverse_diff_broken=true gpu_testing=$ongpu + end + + @testset "concat size" begin + layer = SkipConnection(Dense(10, 10), (a, b) -> hcat(a, b)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 10, 2) |> aType + + @test size(layer(x, ps, st)[1]) == (10, 4) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + end +end + +@testitem "Parallel" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "zero sum" begin + layer = Parallel(+, WrappedFunction(zero), NoOpLayer()) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 
10, 10, 10, 10) |> aType + + @test layer(x, ps, st)[1] == x + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 reverse_diff_broken=true gpu_testing=$ongpu + end + + @testset "concat size" begin + layer = Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), NoOpLayer()) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 10, 2) |> aType + + @test size(layer(x, ps, st)[1]) == (10, 4) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + layer = Parallel(hcat, Dense(10, 10), NoOpLayer()) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(layer(x, ps, st)[1]) == (10, 4) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + @testset "vararg input" begin + layer = Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = (randn(rng, 10, 1), randn(rng, 5, 1), randn(rng, 4, 1)) .|> aType + + @test size(layer(x, ps, st)[1]) == (2, 1) + + @jet layer(x, ps, st) + __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) + @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + @testset "named layers" begin + layer = Parallel(+; d102=Dense(10, 2), d52=Dense(5, 2), d42=Dense(4, 2)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = (randn(rng, 10, 1), randn(rng, 5, 1), randn(rng, 4, 1)) .|> aType + + @test size(layer(x, ps, st)[1]) == (2, 1) + + @jet layer(x, ps, st) + __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) + @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + @testset "connection is called once" begin + CNT = Ref(0) + f_cnt = (x...) -> (CNT[] += 1; +(x...)) + layer = Parallel(f_cnt, + WrappedFunction(sin), + WrappedFunction(cos), + WrappedFunction(tan)) + ps, st = Lux.setup(rng, layer) .|> device + Lux.apply(layer, 1, ps, st) + @test CNT[] == 1 + @jet layer(1, ps, st) + Lux.apply(layer, (1, 2, 3), ps, st) + @test CNT[] == 2 + layer = Parallel(f_cnt, WrappedFunction(sin)) + Lux.apply(layer, 1, ps, st) + @test CNT[] == 3 + end + + # Ref https://github.com/FluxML/Flux.jl/issues/1673 + @testset "Input domain" begin + struct Input + x + end + + struct L1 <: Lux.AbstractExplicitLayer end + (::L1)(x, ps, st) = (ps.x * x, st) + Lux.initialparameters(rng::AbstractRNG, ::L1) = (x=randn(rng, Float32, 3, 3),) + Base.:*(a::AbstractArray, b::Input) = a * b.x + + par = Parallel(+, L1(), L1()) + ps, st = Lux.setup(rng, par) .|> device + + ip = Input(rand(Float32, 3, 3) |> aType) + ip2 = Input(rand(Float32, 3, 3) |> aType) + + @test check_approx(par(ip, ps, st)[1], + par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + + par.layers[2](ip.x, ps.layer_2, st.layer_2)[1]) + @test check_approx(par((ip, ip2), ps, st)[1], + par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + + par.layers[2](ip2.x, ps.layer_2, st.layer_2)[1]) + gs = Zygote.gradient((p, x...) 
-> sum(par(x, p, st)[1]), ps, ip, ip2) + gs_reg = Zygote.gradient(ps, ip, ip2) do p, x, y + return sum(par.layers[1](x.x, p.layer_1, st.layer_1)[1] + + par.layers[2](y.x, p.layer_2, st.layer_2)[1]) + end + + @test check_approx(gs[1], gs_reg[1]) + @test check_approx(gs[2].x, gs_reg[2].x) + @test check_approx(gs[3].x, gs_reg[3].x) + end + end +end + +@testitem "PairwiseFusion" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + x = (rand(Float32, 1, 10), rand(Float32, 30, 10), rand(Float32, 10, 10)) .|> aType + layer = PairwiseFusion(+, Dense(1, 30), Dense(30, 10)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + y, _ = layer(x, ps, st) + @test size(y) == (10, 10) + + @jet layer(x, ps, st) + __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) + @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + layer = PairwiseFusion(+; d1=Dense(1, 30), d2=Dense(30, 10)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + y, _ = layer(x, ps, st) + @test size(y) == (10, 10) + @jet layer(x, ps, st) + __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) + @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + x = rand(1, 10) + layer = PairwiseFusion(.+, Dense(1, 10), Dense(10, 1)) + __display(layer) + ps, st = Lux.setup(rng, layer) + y, _ = layer(x, ps, st) + @test size(y) == (1, 10) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + layer = PairwiseFusion(vcat, + WrappedFunction(x -> x .+ 1), + WrappedFunction(x -> x .+ 2), + WrappedFunction(x -> x .^ 3)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + @test layer((2, 10, 20, 40), ps, st)[1] == [125, 1728, 8000, 40] + + layer = PairwiseFusion(vcat, + WrappedFunction(x -> x .+ 1), + WrappedFunction(x -> x .+ 2), + WrappedFunction(x -> x .^ 3)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + @test layer(7, ps, st)[1] == [1000, 729, 343, 7] + end +end + +@testitem "BranchLayer" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + layer = BranchLayer(Dense(10, 10), Dense(10, 10)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType + (y1, y2), _ = layer(x, ps, st) + @test size(y1) == (10, 1) + @test size(y2) == (10, 1) + @test y1 == layer.layers.layer_1(x, ps.layer_1, st.layer_1)[1] + @test y2 == layer.layers.layer_2(x, ps.layer_2, st.layer_2)[1] + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(sum, first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + layer = BranchLayer(; d1=Dense(10, 10), d2=Dense(10, 10)) + __display(layer) + ps, st = Lux.setup(rng, layer) + x = rand(Float32, 10, 1) + (y1, y2), _ = layer(x, ps, st) + @test size(y1) == (10, 1) + @test size(y2) == (10, 1) + @test y1 == layer.layers.d1(x, ps.d1, st.d1)[1] + @test y2 == layer.layers.d2(x, ps.d2, st.d2)[1] + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(sum, first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end +end + +@testitem "Chain" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + layer = Chain(Dense(10 => 5, sigmoid), Dense(5 => 2, 
tanh), Dense(2 => 1)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType + y, _ = layer(x, ps, st) + @test size(y) == (1, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + layer = Chain(; + l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType + y, _ = layer(x, ps, st) + @test size(y) == (1, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + layer = Chain(; + l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) + __display(layer) + layer = layer[1:2] + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType + y, _ = layer(x, ps, st) + @test size(y) == (2, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + layer = Chain(; + l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) + __display(layer) + layer = layer[begin:(end - 1)] + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType + y, _ = layer(x, ps, st) + @test size(y) == (2, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + layer = Chain(; + l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) + __display(layer) + layer = layer[1] + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType + y, _ = layer(x, ps, st) + @test size(y) == (5, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + @test_throws ArgumentError Chain(; + l1=Dense(10 => 5, sigmoid), + d52=Dense(5 => 2, tanh), + d21=Dense(2 => 1), + d2=Dense(2 => 1), + disable_optimizations=false) + end +end + +@testitem "Maxout" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "constructor" begin + layer = Maxout(() -> NoOpLayer(), 4) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(rng, Float32, 10, 1) |> aType + + @test layer(x, ps, st)[1] == x + + @jet layer(x, ps, st) + end + + @testset "simple alternatives" begin + layer = Maxout(NoOpLayer(), WrappedFunction(x -> 2x), + WrappedFunction(x -> 0.5x)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = Float32.(collect(1:40)) |> aType + + @test layer(x, ps, st)[1] == 2 .* x + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + @testset "complex alternatives" begin + A = aType([0.5 0.1]') + B = aType([0.2 0.7]') + layer = Maxout(WrappedFunction(x -> A * x), WrappedFunction(x -> B * x)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = [3.0 2.0] |> aType + y = aType([0.5, 0.7]) .* x + + @test layer(x, ps, st)[1] == y + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + @testset "params" begin + layer = Maxout(() -> Dense(2, 4), 4) + __display(layer) + ps, st = Lux.setup(rng, 
layer) .|> device + x = [10.0f0 3.0f0]' |> aType + + @test Lux.parameterlength(layer) == + sum(Lux.parameterlength.(values(layer.layers))) + @test size(layer(x, ps, st)[1]) == (4, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-1 rtol=1.0f-1 gpu_testing=$ongpu + end + end +end + +@testitem "Repeated" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + LAYERS = [Dense(2 => 2), Parallel(+, Dense(2 => 2), Dense(2 => 2)), Dense(2 => 2), + Parallel(+, Dense(2 => 2), Dense(2 => 2))] + REPEATS = [Val(4), Val(4), Val(4), Val(4)] + INJECTION = [Val(false), Val(true), Val(false), Val(true)] + + @testset "repeats = $(repeats); input_injection = $(input_injection)" for (layer, repeats, input_injection) in zip( + LAYERS, + REPEATS, INJECTION) + layer = RepeatedLayer(layer; repeats, input_injection) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(rng, Float32, 2, 12) |> aType + + @test size(layer(x, ps, st)[1]) == (2, 12) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + end +end diff --git a/test/layers/conv.jl b/test/layers/conv.jl deleted file mode 100644 index 21a8a1159d..0000000000 --- a/test/layers/conv.jl +++ /dev/null @@ -1,651 +0,0 @@ -using Lux, Test, Zygote - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode: Pooling" for (mode, aType, device, ongpu) in MODES - x = randn(rng, Float32, 10, 10, 3, 2) |> aType - y = randn(rng, Float32, 20, 20, 3, 2) |> aType - - layer = AdaptiveMaxPool((5, 5)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - - layer = AdaptiveMeanPool((5, 5)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - - layer = AdaptiveMaxPool((10, 5)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test layer(y, ps, st)[1] == maxpool(y, PoolDims(y, (2, 4))) - @jet layer(y, ps, st) - - layer = AdaptiveMeanPool((10, 5)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test layer(y, ps, st)[1] == meanpool(y, PoolDims(y, (2, 4))) - @jet layer(y, ps, st) - - layer = GlobalMaxPool() - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) - @jet layer(x, ps, st) - - layer = GlobalMeanPool() - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) - @jet layer(x, ps, st) - - layer = MaxPool((2, 2)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - - layer = MeanPool((2, 2)) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - - @testset "$ltype SamePad windowsize $k" for ltype in (MeanPool, MaxPool), - k in ((1,), (2,), (3,), (4, 5), (6, 7, 8)) - - x = ones(Float32, (k .+ 3)..., 1, 1) |> aType - - layer = ltype(k; pad=Lux.SamePad()) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], k) - @jet layer(x, ps, st) - end -end - -@testset "$mode: CNN" for (mode, 
aType, device, ongpu) in MODES - @testset "Grouped Conv" begin - x = rand(rng, Float32, 4, 6, 1) |> aType - layer = Conv((3,), 6 => 2; groups=2) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - broken = false - try - layer(x, ps, st) - broken = false - catch - @warn "$mode Test broken for $layer" - broken = true - end - - if !broken - @test size(ps.weight) == (3, 3, 2) - @test size(layer(x, ps, st)[1]) == (2, 2, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - else - @test_broken !broken - end - - x = rand(rng, Float32, 4, 4, 6, 1) |> aType - layer = Conv((3, 3), 6 => 2; groups=2) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - broken = false - try - layer(x, ps, st) - broken = false - catch - @warn "$mode Test broken for $layer" - broken = true - end - - if !broken - @test size(ps.weight) == (3, 3, 3, 2) - @test size(layer(x, ps, st)[1]) == (2, 2, 2, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - else - @test_broken !broken - end - - x = rand(rng, Float32, 4, 4, 4, 6, 1) |> aType - layer = Conv((3, 3, 3), 6 => 2; groups=2) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - broken = false - try - layer(x, ps, st) - broken = false - catch - @warn "$mode Test broken for $layer" - broken = true - end - - if !broken - @test size(ps.weight) == (3, 3, 3, 3, 2) - @test size(layer(x, ps, st)[1]) == (2, 2, 2, 2, 1) - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - else - @test_broken !broken - end - - # Test that we cannot ask for non-integer multiplication factors - layer = Conv((2, 2), 3 => 10; groups=2) - __display(layer) - @test_throws AssertionError Lux.setup(rng, layer) - layer = Conv((2, 2), 2 => 9; groups=2) - __display(layer) - @test_throws AssertionError Lux.setup(rng, layer) - - @testset "Segfault Test LuxDL/Lux.jl#386" begin - layer = Conv((5,), 32 => 32, tanh; groups=32) - __display(layer) - x = rand(rng, Float32, 16, 32, 1) |> aType - ps, st = Lux.setup(rng, layer) .|> device - - broken = false - try - layer(x, ps, st) - broken = false - catch - @warn "$mode Test broken for $layer" - broken = true - end - - if !broken - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - else - @test_broken !broken - end - end - end - - @testset "Asymmetric Padding" begin - layer = Conv((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2)) - __display(layer) - x = ones(Float32, 28, 28, 1, 1) |> aType - ps, st = Lux.setup(rng, layer) .|> device - - ps.weight .= 1.0 - ps.bias .= 0.0 - - y_hat = layer(x, ps, st)[1][:, :, 1, 1] |> Array - @test size(y_hat) == (27, 29) - @test check_approx(y_hat[1, 1], 6.0) - @test check_approx(y_hat[2, 2], 9.0) - @test check_approx(y_hat[end, 1], 4.0) - @test check_approx(y_hat[1, end], 3.0) - @test check_approx(y_hat[1, end - 1], 6.0) - @test check_approx(y_hat[end, end], 2.0) - - @jet layer(x, ps, st) - end - - @testset "Variable BitWidth Parameters FluxML/Flux.jl#1421" begin - layer = Conv((5, 5), - 10 => 20, - identity; - init_weight=(rng, dims...) -> aType(randn(rng, Float64, dims...)), - init_bias=(rng, dims...) 
-> aType(randn(rng, Float16, dims...))) - __display(layer) - ps, st = Lux.setup(rng, layer) - @test ps.weight isa aType{Float64, 4} - @test ps.bias isa aType{Float16, 4} - end - - @testset "Depthwise Conv" begin - x = randn(rng, Float32, 4, 4, 3, 2) |> aType - - layer = Conv((2, 2), 3 => 15; groups=3) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - @test Lux.parameterlength(layer) == Lux.parameterlength(ps) - - @test size(layer(x, ps, st)[1], 3) == 15 - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - layer = Conv((2, 2), 3 => 9; groups=3) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @test size(layer(x, ps, st)[1], 3) == 9 - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - layer = Conv((2, 2), 3 => 9; groups=3, use_bias=false) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - @test Lux.parameterlength(layer) == Lux.parameterlength(ps) - - @test size(layer(x, ps, st)[1], 3) == 9 - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - # Test that we cannot ask for non-integer multiplication factors - layer = Conv((2, 2), 3 => 10; groups=3) - __display(layer) - @test_throws AssertionError Lux.setup(rng, layer) - end - - @testset "Conv SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) - x = ones(Float32, (k .+ 3)..., 1, 1) |> aType - - @testset "Kwargs: $kwarg" for kwarg in ((; stride=1), - (; dilation=max.(k .÷ 2, 1), stride=1), - (; stride=3)) - layer = Conv(k, 1 => 1; pad=Lux.SamePad(), kwarg...) 
- __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - broken = false - try - layer(x, ps, st) - broken = false - catch - @warn "$mode Test broken for $layer" - broken = true - end - - if !broken - if kwarg.stride == 1 - @test size(layer(x, ps, st)[1]) == size(x) - else - @test size(layer(x, ps, st)[1])[1:(end - 2)] == - cld.(size(x)[1:(end - 2)], kwarg.stride) - end - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - else - @test_broken !broken - end - end - end - - @testset "Conv with non quadratic window FluxML/Flux.jl#700" begin - x = zeros(Float32, 7, 7, 1, 1) - x[4, 4, 1, 1] = 1 - x = x |> aType - - layer = Conv((3, 3), 1 => 1) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - y = zeros(eltype(ps.weight), 5, 5, 1, 1) |> aType - y[2:(end - 1), 2:(end - 1), 1, 1] = ps.weight - @test check_approx(y, layer(x, ps, st)[1]) - - @jet layer(x, ps, st) - - layer = Conv((3, 1), 1 => 1) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - y = zeros(eltype(ps.weight), 5, 7, 1, 1) |> aType - y[2:(end - 1), 4, 1, 1] = ps.weight - @test check_approx(y, layer(x, ps, st)[1]) - - @jet layer(x, ps, st) - - layer = Conv((1, 3), 1 => 1) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - y = zeros(eltype(ps.weight), 7, 5, 1, 1) |> aType - y[4, 2:(end - 1), 1, 1] = ps.weight - @test check_approx(y, layer(x, ps, st)[1]) - - @jet layer(x, ps, st) - - layer = Conv((1, 3), 1 => 1; init_weight=Lux.glorot_normal) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - y = zeros(eltype(ps.weight), 7, 5, 1, 1) |> aType - y[4, 2:(end - 1), 1, 1] = ps.weight - @test check_approx(y, layer(x, ps, st)[1]) - - @jet layer(x, ps, st) - end - - @testset "allow fast activation" begin - layer = Conv((3, 3), 1 => 1, tanh) - @test layer.activation == tanh_fast - layer = Conv((3, 3), 1 => 1, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end -end - -@testset "$mode: Upsample" for (mode, aType, device, ongpu) in MODES - @testset "Construction" begin - @test_nowarn Upsample(:nearest; scale=2) - @test_nowarn Upsample(:nearest; size=(64, 64)) - @test_nowarn Upsample(:bilinear; scale=2) - @test_nowarn Upsample(:bilinear; size=(64, 64)) - @test_nowarn Upsample(:trilinear; scale=2) - @test_nowarn Upsample(:trilinear; size=(64, 64)) - - @test_throws ArgumentError Upsample(:linear; scale=2) - @test_throws ArgumentError Upsample(:nearest; scale=2, size=(64, 64)) - @test_throws ArgumentError Upsample(:nearest) - - @test_nowarn Upsample(2) - @test_nowarn Upsample(2, :nearest) - end - - @testset "Size Correctness" begin - # NNlib is checking algorithmic correctness. 
So we should just verify correct - # function call - modes = (:nearest, :bilinear, :trilinear) - sizes = (nothing, (64, 64), (64, 32)) - scales = (nothing, 2, (2, 1)) - - for umode in modes, xsize in sizes, scale in scales - if !xor(isnothing(xsize), isnothing(scale)) - continue - end - layer = Upsample(umode; size=xsize, scale=scale) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = zeros((32, 32, 3, 4)) |> aType - - @jet layer(x, ps, st) - - y, _ = layer(x, ps, st) - if isnothing(scale) - @test size(y)[1:2] == xsize - else - @test size(y)[1:2] == size(x)[1:2] .* scale - end - @test size(y)[3:4] == size(x)[3:4] - end - - sizes = (nothing, (64, 64, 64), (64, 32, 128)) - scales = (nothing, 2, (2, 1, 1), (2, 2, 1)) - - for umode in modes, xsize in sizes, scale in scales - if !xor(isnothing(xsize), isnothing(scale)) - continue - end - layer = Upsample(umode; size=xsize, scale=scale) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = zeros((32, 32, 32, 3, 4)) |> aType - - @jet layer(x, ps, st) - - y, _ = layer(x, ps, st) - - if isnothing(scale) - @test size(y)[1:3] == xsize - else - @test size(y)[1:3] == size(x)[1:3] .* scale - end - @test size(y)[4:5] == size(x)[4:5] - end - end -end - -@testset "$mode: PixelShuffle" for (mode, aType, device, ongpu) in MODES - layer = PixelShuffle(2) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = rand(rng, Float32, 3, 6, 3) |> aType - - y, st_ = layer(x, ps, st) - @test y isa aType{Float32, 3} - @test size(y) == (6, 3, 3) - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - - layer = PixelShuffle(3) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = rand(Float32, 3, 4, 9, 3) |> aType - - y, st_ = layer(x, ps, st) - @test y isa aType{Float32, 4} - @test size(y) == (9, 12, 1, 3) - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1e-3 rtol=1e-3 -end - -@testset "$mode: CrossCor" for (mode, aType, device, ongpu) in MODES - @testset "Asymmetric Padding" begin - layer = CrossCor((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2)) - __display(layer) - x = ones(Float32, 28, 28, 1, 1) |> aType - ps, st = Lux.setup(rng, layer) .|> device - - ps.weight .= 1.0 - ps.bias .= 0.0 - - y_hat = layer(x, ps, st)[1][:, :, 1, 1] |> Array - @test size(y_hat) == (27, 29) - @test check_approx(y_hat[1, 1], 6.0) - @test check_approx(y_hat[2, 2], 9.0) - @test check_approx(y_hat[end, 1], 4.0) - @test check_approx(y_hat[1, end], 3.0) - @test check_approx(y_hat[1, end - 1], 6.0) - @test check_approx(y_hat[end, end], 2.0) - - @jet layer(x, ps, st) - end - - @testset "Variable BitWidth Parameters FluxML/Flux.jl#1421" begin - layer = CrossCor((5, 5), - 10 => 20, - identity; - init_weight=(rng, dims...) -> aType(randn(rng, Float64, dims...)), - init_bias=(rng, dims...) -> aType(randn(rng, Float16, dims...))) - __display(layer) - ps, st = Lux.setup(rng, layer) - @test ps.weight isa aType{Float64, 4} - @test ps.bias isa aType{Float16, 4} - end - - @testset "CrossCor SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) - x = ones(Float32, (k .+ 3)..., 1, 1) |> aType - - @testset "Kwargs: $kwarg" for kwarg in ((; stride=1), - (; dilation=max.(k .÷ 2, 1), stride=1), - (; stride=3)) - layer = CrossCor(k, 1 => 1; pad=Lux.SamePad(), kwarg...) 
- __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - broken = false - try - layer(x, ps, st) - broken = false - catch - @warn "$mode Test broken for $layer" - broken = true - end - - if !broken - if kwarg.stride == 1 - @test size(layer(x, ps, st)[1]) == size(x) - else - @test size(layer(x, ps, st)[1])[1:(end - 2)] == - cld.(size(x)[1:(end - 2)], kwarg.stride) - end - - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - else - @test_broken !broken - end - end - end - - @testset "allow fast activation" begin - layer = CrossCor((3, 3), 1 => 1, tanh) - @test layer.activation == tanh_fast - layer = CrossCor((3, 3), 1 => 1, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end -end - -@testset "$mode: ConvTranspose" for (mode, aType, device, ongpu) in MODES - x = randn(Float32, 5, 5, 1, 1) |> aType - layer = Conv((3, 3), 1 => 1) - ps, st = Lux.setup(rng, layer) .|> device - y = layer(x, ps, st)[1] - - layer = ConvTranspose((3, 3), 1 => 1) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @jet layer(y, ps, st) opt_broken=true - - x_hat1 = layer(y, ps, st)[1] - - layer = ConvTranspose((3, 3), 1 => 1; use_bias=false) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @jet layer(y, ps, st) opt_broken=true - - x_hat2 = layer(y, ps, st)[1] - - @test size(x_hat1) == size(x_hat2) == size(x) - - layer = ConvTranspose((3, 3), 1 => 1) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = rand(Float32, 5, 5, 1, 1) |> aType - - @jet layer(x, ps, st) opt_broken=true - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - - x = rand(Float32, 5, 5, 2, 4) |> aType - layer = ConvTranspose((3, 3), 2 => 3) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @jet layer(x, ps, st) opt_broken=true - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - - # test ConvTranspose supports groups argument - x = randn(Float32, 10, 10, 2, 3) |> aType - layer1 = ConvTranspose((3, 3), 2 => 4; pad=SamePad()) - __display(layer1) - ps1, st1 = Lux.setup(rng, layer1) .|> device - @test size(ps1.weight) == (3, 3, 4, 2) - @test size(layer1(x, ps1, st1)[1]) == (10, 10, 4, 3) - - layer2 = ConvTranspose((3, 3), 2 => 4; groups=2, pad=SamePad()) - __display(layer2) - ps2, st2 = Lux.setup(rng, layer2) .|> device - @test size(ps2.weight) == (3, 3, 2, 2) - @test size(layer1(x, ps1, st1)[1]) == size(layer2(x, ps2, st2)[1]) - - __f = (x, ps) -> sum(first(layer1(x, ps, st1))) - @eval @test_gradients $__f $x $ps1 gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - - __f = (x, ps) -> sum(first(layer2(x, ps, st2))) - @eval @test_gradients $__f $x $ps2 gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - - x = randn(Float32, 10, 2, 1) |> aType - layer = ConvTranspose((3,), 2 => 4; pad=SamePad(), groups=2) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @jet layer(x, ps, st) opt_broken=true - - @test size(layer(x, ps, st)[1]) == (10, 4, 1) - @test length(ps.weight) == 3 * (2 * 4) / 2 - - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - - x = randn(Float32, 10, 11, 4, 2) |> aType - layer = ConvTranspose((3, 5), 4 => 4; pad=SamePad(), groups=4) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @jet layer(x, ps, st) opt_broken=true - - @test 
size(layer(x, ps, st)[1]) == (10, 11, 4, 2) - @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 - - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - - x = randn(Float32, 10, 11, 4, 2) |> aType - layer = ConvTranspose((3, 5), 4 => 4, tanh; pad=SamePad(), groups=4) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @jet layer(x, ps, st) opt_broken=true - @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) - @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 - - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - - x = randn(Float32, 10, 11, 12, 3, 2) |> aType - layer = ConvTranspose((3, 5, 3), 3 => 6; pad=SamePad(), groups=3) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @jet layer(x, ps, st) opt_broken=true - @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) - @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 - - x = randn(Float32, 10, 11, 12, 3, 2) |> aType - layer = ConvTranspose((3, 5, 3), 3 => 6, tanh; pad=SamePad(), groups=3) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - @jet layer(x, ps, st) opt_broken=true - @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) - @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 - - @test occursin("groups=2", sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2))) - @test occursin("2 => 4", sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2))) -end diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl new file mode 100644 index 0000000000..3383241655 --- /dev/null +++ b/test/layers/conv_tests.jl @@ -0,0 +1,670 @@ +@testitem "Pooling" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + x = randn(rng, Float32, 10, 10, 3, 2) |> aType + y = randn(rng, Float32, 20, 20, 3, 2) |> aType + + layer = AdaptiveMaxPool((5, 5)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) + @jet layer(x, ps, st) + + layer = AdaptiveMeanPool((5, 5)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) + @jet layer(x, ps, st) + + layer = AdaptiveMaxPool((10, 5)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test layer(y, ps, st)[1] == maxpool(y, PoolDims(y, (2, 4))) + @jet layer(y, ps, st) + + layer = AdaptiveMeanPool((10, 5)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test layer(y, ps, st)[1] == meanpool(y, PoolDims(y, (2, 4))) + @jet layer(y, ps, st) + + layer = GlobalMaxPool() + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) + @jet layer(x, ps, st) + + layer = GlobalMeanPool() + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) + @jet layer(x, ps, st) + + layer = MaxPool((2, 2)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) + @jet layer(x, ps, st) + + layer = MeanPool((2, 2)) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) + @jet layer(x, ps, st) + + @testset "$ltype SamePad windowsize $k" for ltype in (MeanPool, MaxPool), + k in ((1,), (2,), (3,), (4, 5), (6, 7, 8)) + + x = ones(Float32, (k .+ 3)..., 1, 1) |> aType + + 
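+            # Descriptive note (not from the original test file): with `pad=Lux.SamePad()` the
+            # pooled output's spatial dims are expected to be cld(input_size, window) — e.g. for
+            # k = (4, 5) the input is 7×8, so the size assertion below expects cld.((7, 8), (4, 5)) = (2, 2).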
layer = ltype(k; pad=Lux.SamePad()) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], k) + @jet layer(x, ps, st) + end + end +end + +@testitem "CNN" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "Grouped Conv" begin + x = rand(rng, Float32, 4, 6, 1) |> aType + layer = Conv((3,), 6 => 2; groups=2) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + broken = false + try + layer(x, ps, st) + broken = false + catch + @warn "$mode Test broken for $layer" + broken = true + end + + if !broken + @test size(ps.weight) == (3, 3, 2) + @test size(layer(x, ps, st)[1]) == (2, 2, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + else + @test_broken !broken + end + + x = rand(rng, Float32, 4, 4, 6, 1) |> aType + layer = Conv((3, 3), 6 => 2; groups=2) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + broken = false + try + layer(x, ps, st) + broken = false + catch + @warn "$mode Test broken for $layer" + broken = true + end + + if !broken + @test size(ps.weight) == (3, 3, 3, 2) + @test size(layer(x, ps, st)[1]) == (2, 2, 2, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + else + @test_broken !broken + end + + x = rand(rng, Float32, 4, 4, 4, 6, 1) |> aType + layer = Conv((3, 3, 3), 6 => 2; groups=2) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + broken = false + try + layer(x, ps, st) + broken = false + catch + @warn "$mode Test broken for $layer" + broken = true + end + + if !broken + @test size(ps.weight) == (3, 3, 3, 3, 2) + @test size(layer(x, ps, st)[1]) == (2, 2, 2, 2, 1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + else + @test_broken !broken + end + + # Test that we cannot ask for non-integer multiplication factors + layer = Conv((2, 2), 3 => 10; groups=2) + __display(layer) + @test_throws AssertionError Lux.setup(rng, layer) + layer = Conv((2, 2), 2 => 9; groups=2) + __display(layer) + @test_throws AssertionError Lux.setup(rng, layer) + + @testset "Segfault Test LuxDL/Lux.jl#386" begin + layer = Conv((5,), 32 => 32, tanh; groups=32) + __display(layer) + x = rand(rng, Float32, 16, 32, 1) |> aType + ps, st = Lux.setup(rng, layer) .|> device + + broken = false + try + layer(x, ps, st) + broken = false + catch + @warn "$mode Test broken for $layer" + broken = true + end + + if !broken + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + else + @test_broken !broken + end + end + end + + @testset "Asymmetric Padding" begin + layer = Conv((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2)) + __display(layer) + x = ones(Float32, 28, 28, 1, 1) |> aType + ps, st = Lux.setup(rng, layer) .|> device + + ps.weight .= 1.0 + ps.bias .= 0.0 + + y_hat = layer(x, ps, st)[1][:, :, 1, 1] |> Array + @test size(y_hat) == (27, 29) + @test check_approx(y_hat[1, 1], 6.0) + @test check_approx(y_hat[2, 2], 9.0) + @test check_approx(y_hat[end, 1], 4.0) + @test check_approx(y_hat[1, end], 3.0) + @test check_approx(y_hat[1, end - 1], 6.0) + @test check_approx(y_hat[end, end], 2.0) + + @jet 
layer(x, ps, st) + end + + @testset "Variable BitWidth Parameters FluxML/Flux.jl#1421" begin + layer = Conv((5, 5), + 10 => 20, + identity; + init_weight=(rng, dims...) -> aType(randn(rng, Float64, dims...)), + init_bias=(rng, dims...) -> aType(randn(rng, Float16, dims...))) + __display(layer) + ps, st = Lux.setup(rng, layer) + @test ps.weight isa aType{Float64, 4} + @test ps.bias isa aType{Float16, 4} + end + + @testset "Depthwise Conv" begin + x = randn(rng, Float32, 4, 4, 3, 2) |> aType + + layer = Conv((2, 2), 3 => 15; groups=3) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + @test Lux.parameterlength(layer) == Lux.parameterlength(ps) + + @test size(layer(x, ps, st)[1], 3) == 15 + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + layer = Conv((2, 2), 3 => 9; groups=3) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @test size(layer(x, ps, st)[1], 3) == 9 + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + layer = Conv((2, 2), 3 => 9; groups=3, use_bias=false) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + @test Lux.parameterlength(layer) == Lux.parameterlength(ps) + + @test size(layer(x, ps, st)[1], 3) == 9 + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + # Test that we cannot ask for non-integer multiplication factors + layer = Conv((2, 2), 3 => 10; groups=3) + __display(layer) + @test_throws AssertionError Lux.setup(rng, layer) + end + + @testset "Conv SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) + x = ones(Float32, (k .+ 3)..., 1, 1) |> aType + + @testset "Kwargs: $kwarg" for kwarg in ((; stride=1), + (; dilation=max.(k .÷ 2, 1), stride=1), + (; stride=3)) + layer = Conv(k, 1 => 1; pad=Lux.SamePad(), kwarg...) 
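+                # Descriptive note (not from the original test file): with `pad=Lux.SamePad()` a
+                # stride-1 conv (even with dilation) should preserve the spatial size, while a strided
+                # conv should give cld(input_size, stride) per spatial dim — e.g. for k = (2, 3) and
+                # stride=3 the input is 5×6, so the checks below expect cld.((5, 6), 3) = (2, 2).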
+ __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + broken = false + try + layer(x, ps, st) + broken = false + catch + @warn "$mode Test broken for $layer" + broken = true + end + + if !broken + if kwarg.stride == 1 + @test size(layer(x, ps, st)[1]) == size(x) + else + @test size(layer(x, ps, st)[1])[1:(end - 2)] == + cld.(size(x)[1:(end - 2)], kwarg.stride) + end + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + else + @test_broken !broken + end + end + end + + @testset "Conv with non quadratic window FluxML/Flux.jl#700" begin + x = zeros(Float32, 7, 7, 1, 1) + x[4, 4, 1, 1] = 1 + x = x |> aType + + layer = Conv((3, 3), 1 => 1) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + y = zeros(eltype(ps.weight), 5, 5, 1, 1) |> aType + y[2:(end - 1), 2:(end - 1), 1, 1] = ps.weight + @test check_approx(y, layer(x, ps, st)[1]) + + @jet layer(x, ps, st) + + layer = Conv((3, 1), 1 => 1) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + y = zeros(eltype(ps.weight), 5, 7, 1, 1) |> aType + y[2:(end - 1), 4, 1, 1] = ps.weight + @test check_approx(y, layer(x, ps, st)[1]) + + @jet layer(x, ps, st) + + layer = Conv((1, 3), 1 => 1) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + y = zeros(eltype(ps.weight), 7, 5, 1, 1) |> aType + y[4, 2:(end - 1), 1, 1] = ps.weight + @test check_approx(y, layer(x, ps, st)[1]) + + @jet layer(x, ps, st) + + layer = Conv((1, 3), 1 => 1; init_weight=Lux.glorot_normal) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + y = zeros(eltype(ps.weight), 7, 5, 1, 1) |> aType + y[4, 2:(end - 1), 1, 1] = ps.weight + @test check_approx(y, layer(x, ps, st)[1]) + + @jet layer(x, ps, st) + end + + @testset "allow fast activation" begin + layer = Conv((3, 3), 1 => 1, tanh) + @test layer.activation == tanh_fast + layer = Conv((3, 3), 1 => 1, tanh; allow_fast_activation=false) + @test layer.activation == tanh + end + end +end + +@testitem "Upsample" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "Construction" begin + @test_nowarn Upsample(:nearest; scale=2) + @test_nowarn Upsample(:nearest; size=(64, 64)) + @test_nowarn Upsample(:bilinear; scale=2) + @test_nowarn Upsample(:bilinear; size=(64, 64)) + @test_nowarn Upsample(:trilinear; scale=2) + @test_nowarn Upsample(:trilinear; size=(64, 64)) + + @test_throws ArgumentError Upsample(:linear; scale=2) + @test_throws ArgumentError Upsample(:nearest; scale=2, size=(64, 64)) + @test_throws ArgumentError Upsample(:nearest) + + @test_nowarn Upsample(2) + @test_nowarn Upsample(2, :nearest) + end + + @testset "Size Correctness" begin + # NNlib is checking algorithmic correctness. 
So we should just verify correct + # function call + modes = (:nearest, :bilinear, :trilinear) + sizes = (nothing, (64, 64), (64, 32)) + scales = (nothing, 2, (2, 1)) + + for umode in modes, xsize in sizes, scale in scales + if !xor(isnothing(xsize), isnothing(scale)) + continue + end + layer = Upsample(umode; size=xsize, scale=scale) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = zeros((32, 32, 3, 4)) |> aType + + @jet layer(x, ps, st) + + y, _ = layer(x, ps, st) + if isnothing(scale) + @test size(y)[1:2] == xsize + else + @test size(y)[1:2] == size(x)[1:2] .* scale + end + @test size(y)[3:4] == size(x)[3:4] + end + + sizes = (nothing, (64, 64, 64), (64, 32, 128)) + scales = (nothing, 2, (2, 1, 1), (2, 2, 1)) + + for umode in modes, xsize in sizes, scale in scales + if !xor(isnothing(xsize), isnothing(scale)) + continue + end + layer = Upsample(umode; size=xsize, scale=scale) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = zeros((32, 32, 32, 3, 4)) |> aType + + @jet layer(x, ps, st) + + y, _ = layer(x, ps, st) + + if isnothing(scale) + @test size(y)[1:3] == xsize + else + @test size(y)[1:3] == size(x)[1:3] .* scale + end + @test size(y)[4:5] == size(x)[4:5] + end + end + end +end + +@testitem "PixelShuffle" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + layer = PixelShuffle(2) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(rng, Float32, 3, 6, 3) |> aType + + y, st_ = layer(x, ps, st) + @test y isa aType{Float32, 3} + @test size(y) == (6, 3, 3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + layer = PixelShuffle(3) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 3, 4, 9, 3) |> aType + + y, st_ = layer(x, ps, st) + @test y isa aType{Float32, 4} + @test size(y) == (9, 12, 1, 3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + end +end + +@testitem "CrossCor" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "Asymmetric Padding" begin + layer = CrossCor((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2)) + __display(layer) + x = ones(Float32, 28, 28, 1, 1) |> aType + ps, st = Lux.setup(rng, layer) .|> device + + ps.weight .= 1.0 + ps.bias .= 0.0 + + y_hat = layer(x, ps, st)[1][:, :, 1, 1] |> Array + @test size(y_hat) == (27, 29) + @test check_approx(y_hat[1, 1], 6.0) + @test check_approx(y_hat[2, 2], 9.0) + @test check_approx(y_hat[end, 1], 4.0) + @test check_approx(y_hat[1, end], 3.0) + @test check_approx(y_hat[1, end - 1], 6.0) + @test check_approx(y_hat[end, end], 2.0) + + @jet layer(x, ps, st) + end + + @testset "Variable BitWidth Parameters FluxML/Flux.jl#1421" begin + layer = CrossCor((5, 5), + 10 => 20, + identity; + init_weight=(rng, dims...) -> aType(randn(rng, Float64, dims...)), + init_bias=(rng, dims...) 
-> aType(randn(rng, Float16, dims...))) + __display(layer) + ps, st = Lux.setup(rng, layer) + @test ps.weight isa aType{Float64, 4} + @test ps.bias isa aType{Float16, 4} + end + + @testset "CrossCor SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), + (1, 2, 3)) + x = ones(Float32, (k .+ 3)..., 1, 1) |> aType + + @testset "Kwargs: $kwarg" for kwarg in ((; stride=1), + (; dilation=max.(k .÷ 2, 1), stride=1), + (; stride=3)) + layer = CrossCor(k, 1 => 1; pad=Lux.SamePad(), kwarg...) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + broken = false + try + layer(x, ps, st) + broken = false + catch + @warn "$mode Test broken for $layer" + broken = true + end + + if !broken + if kwarg.stride == 1 + @test size(layer(x, ps, st)[1]) == size(x) + else + @test size(layer(x, ps, st)[1])[1:(end - 2)] == + cld.(size(x)[1:(end - 2)], kwarg.stride) + end + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + else + @test_broken !broken + end + end + end + + @testset "allow fast activation" begin + layer = CrossCor((3, 3), 1 => 1, tanh) + @test layer.activation == tanh_fast + layer = CrossCor((3, 3), 1 => 1, tanh; allow_fast_activation=false) + @test layer.activation == tanh + end + end +end + +@testitem "ConvTranspose" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + x = randn(Float32, 5, 5, 1, 1) |> aType + layer = Conv((3, 3), 1 => 1) + ps, st = Lux.setup(rng, layer) .|> device + y = layer(x, ps, st)[1] + + layer = ConvTranspose((3, 3), 1 => 1) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(y, ps, st) opt_broken=true + + x_hat1 = layer(y, ps, st)[1] + + layer = ConvTranspose((3, 3), 1 => 1; use_bias=false) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(y, ps, st) opt_broken=true + + x_hat2 = layer(y, ps, st)[1] + + @test size(x_hat1) == size(x_hat2) == size(x) + + layer = ConvTranspose((3, 3), 1 => 1) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 5, 5, 1, 1) |> aType + + @jet layer(x, ps, st) opt_broken=true + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + x = rand(Float32, 5, 5, 2, 4) |> aType + layer = ConvTranspose((3, 3), 2 => 3) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + # test ConvTranspose supports groups argument + x = randn(Float32, 10, 10, 2, 3) |> aType + layer1 = ConvTranspose((3, 3), 2 => 4; pad=SamePad()) + __display(layer1) + ps1, st1 = Lux.setup(rng, layer1) .|> device + @test size(ps1.weight) == (3, 3, 4, 2) + @test size(layer1(x, ps1, st1)[1]) == (10, 10, 4, 3) + + layer2 = ConvTranspose((3, 3), 2 => 4; groups=2, pad=SamePad()) + __display(layer2) + ps2, st2 = Lux.setup(rng, layer2) .|> device + @test size(ps2.weight) == (3, 3, 2, 2) + @test size(layer1(x, ps1, st1)[1]) == size(layer2(x, ps2, st2)[1]) + + __f = (x, ps) -> sum(first(layer1(x, ps, st1))) + @eval @test_gradients $__f $x $ps1 gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + __f = (x, ps) -> sum(first(layer2(x, ps, st2))) + @eval @test_gradients $__f $x $ps2 gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + x = randn(Float32, 10, 2, 1) |> aType + layer = ConvTranspose((3,), 2 => 4; 
pad=SamePad(), groups=2) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true + + @test size(layer(x, ps, st)[1]) == (10, 4, 1) + @test length(ps.weight) == 3 * (2 * 4) / 2 + + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + x = randn(Float32, 10, 11, 4, 2) |> aType + layer = ConvTranspose((3, 5), 4 => 4; pad=SamePad(), groups=4) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true + + @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) + @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 + + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + x = randn(Float32, 10, 11, 4, 2) |> aType + layer = ConvTranspose((3, 5), 4 => 4, tanh; pad=SamePad(), groups=4) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true + @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) + @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 + + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + x = randn(Float32, 10, 11, 12, 3, 2) |> aType + layer = ConvTranspose((3, 5, 3), 3 => 6; pad=SamePad(), groups=3) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true + @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) + @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 + + x = randn(Float32, 10, 11, 12, 3, 2) |> aType + layer = ConvTranspose((3, 5, 3), 3 => 6, tanh; pad=SamePad(), groups=3) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true + @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) + @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 + + @test occursin("groups=2", sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2))) + @test occursin("2 => 4", sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2))) + end +end diff --git a/test/layers/dropout.jl b/test/layers/dropout.jl deleted file mode 100644 index 9665b75988..0000000000 --- a/test/layers/dropout.jl +++ /dev/null @@ -1,96 +0,0 @@ -using Lux, Random, Test - -include("../test_utils.jl") - -rng = Random.default_rng() -Random.seed!(rng, 0) - -@testset "$mode: Dropout" for (mode, aType, device, ongpu) in MODES - for p in (0.5f0, 0.5) - layer = Dropout(p) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(Float32, 5, 2) |> aType - - x_, st_ = layer(x, ps, st) - x__, st__ = layer(x, ps, st) - x___, st___ = layer(x_, ps, st_) - - @test st_.rng != st.rng - @test st_.rng == st__.rng - @test x_ == x__ - @test x_ != x___ - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - st = Lux.testmode(st) - - @test first(layer(x, ps, st)) == x - end -end - -@testset "$mode: AlphaDropout" for (mode, aType, device, ongpu) in MODES - for p in (0.5f0, 0.5) - layer = AlphaDropout(p) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - # GPU compilation for mixed types fail atm - x = randn(typeof(p), 5, 2) |> aType - - x_, st_ = layer(x, ps, st) - x__, st__ = layer(x, ps, st) - x___, st___ = layer(x_, ps, st_) - - @test st_.rng != st.rng - @test st_.rng == st__.rng - @test x_ == x__ - @test x_ != x___ - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval 
@test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - st = Lux.testmode(st) - - @test first(layer(x, ps, st)) == x - end -end - -@testset "$mode: VariationalHiddenDropout" for (mode, aType, device, ongpu) in MODES - for p in (0.5f0, 0.5) - layer = VariationalHiddenDropout(p) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(Float32, 5, 2) |> aType - - x_, st_ = layer(x, ps, st) - x__, st__ = layer(x, ps, st) - x___, st___ = layer(x_, ps, st_) - - @test st_.rng != st.rng - @test st_.rng == st__.rng - @test st_.mask == st__.mask - @test x_ == x__ - @test x_ != x___ - - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - @jet layer(x, ps, st_) - __f = x -> sum(first(layer(x, ps, st_))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - st__ = Lux.update_state(st_, :update_mask, Val(true)) - x___, st___ = layer(x, ps, st__) - - @test st___.mask != st__.mask - @test x___ != x_ - - @jet layer(x, ps, st__) - __f = x -> sum(first(layer(x, ps, st__))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end -end diff --git a/test/layers/dropout_tests.jl b/test/layers/dropout_tests.jl new file mode 100644 index 0000000000..b45cb1ad5c --- /dev/null +++ b/test/layers/dropout_tests.jl @@ -0,0 +1,101 @@ +@testitem "Dropout" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + for p in (0.5f0, 0.5) + layer = Dropout(p) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(Float32, 5, 2) |> aType + + x_, st_ = layer(x, ps, st) + x__, st__ = layer(x, ps, st) + x___, st___ = layer(x_, ps, st_) + + @test st_.rng != st.rng + @test st_.rng == st__.rng + @test x_ == x__ + @test x_ != x___ + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + st = Lux.testmode(st) + + @test first(layer(x, ps, st)) == x + end + end +end + +@testitem "AlphaDropout" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + for p in (0.5f0, 0.5) + layer = AlphaDropout(p) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + # GPU compilation for mixed types fail atm + x = randn(typeof(p), 5, 2) |> aType + + x_, st_ = layer(x, ps, st) + x__, st__ = layer(x, ps, st) + x___, st___ = layer(x_, ps, st_) + + @test st_.rng != st.rng + @test st_.rng == st__.rng + @test x_ == x__ + @test x_ != x___ + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + st = Lux.testmode(st) + + @test first(layer(x, ps, st)) == x + end + end +end + +@testitem "VariationalHiddenDropout" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + for p in (0.5f0, 0.5) + layer = VariationalHiddenDropout(p) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(Float32, 5, 2) |> aType + + x_, st_ = layer(x, ps, st) + x__, st__ = layer(x, ps, st) + x___, st___ = layer(x_, ps, st_) + + @test st_.rng != st.rng + @test st_.rng == st__.rng + @test st_.mask == st__.mask + @test x_ == x__ + @test x_ != x___ + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + @jet 
layer(x, ps, st_) + __f = x -> sum(first(layer(x, ps, st_))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + st__ = Lux.update_state(st_, :update_mask, Val(true)) + x___, st___ = layer(x, ps, st__) + + @test st___.mask != st__.mask + @test x___ != x_ + + @jet layer(x, ps, st__) + __f = x -> sum(first(layer(x, ps, st__))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + end +end diff --git a/test/layers/normalize.jl b/test/layers/normalize.jl deleted file mode 100644 index bd61c34a5b..0000000000 --- a/test/layers/normalize.jl +++ /dev/null @@ -1,405 +0,0 @@ -using Lux, Statistics, Zygote - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode: BatchNorm" for (mode, aType, device, ongpu) in MODES - m = BatchNorm(2) - x = [1.0f0 3.0f0 5.0f0 - 2.0f0 4.0f0 6.0f0] |> aType - __display(m) - ps, st = Lux.setup(rng, m) .|> device - - @test Lux.parameterlength(m) == Lux.parameterlength(ps) - @test Lux.statelength(m) == Lux.statelength(st) - - @test ps.bias == [0, 0] |> aType # init_bias(2) - @test ps.scale == [1, 1] |> aType # init_scale(2) - - y, st_ = pullback(m, x, ps, st)[1] - st_ = st_ |> LuxCPUDevice() - @test check_approx(Array(y), [-1.22474 0 1.22474; -1.22474 0 1.22474]; atol=1.0e-5) - # julia> x - # 2×3 Array{Float64,2}: - # 1.0 3.0 5.0 - # 2.0 4.0 6.0 - - # mean of batch will be - # (1. + 3. + 5.) / 3 = 3 - # (2. + 4. + 6.) / 3 = 4 - - # ∴ update rule with momentum: - # .1 * 3 + 0 = .3 - # .1 * 4 + 0 = .4 - @test check_approx(st_.running_mean, reshape([0.3, 0.4], 2, 1)) - - # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] - # 2×1 Array{Float64,2}: - # 1.3 - # 1.3 - @test check_approx(st_.running_var, - 0.1 .* var(Array(x); dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0]) - - st_ = Lux.testmode(st_) |> device - x_ = m(x, ps, st_)[1] |> LuxCPUDevice() - @test check_approx(x_[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5) - - @jet m(x, ps, st) - __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - - for affine in (true, false) - m = BatchNorm(2; affine, track_stats=false) - x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] |> aType - __display(m) - ps, st = Lux.setup(rng, m) .|> device - - @jet m(x, ps, st) - - if affine - __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - else - __f = x -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - end - - # with activation function - m = BatchNorm(2, sigmoid; affine) - x = [1.0f0 3.0f0 5.0f0 - 2.0f0 4.0f0 6.0f0] |> aType - __display(m) - ps, st = Lux.setup(rng, m) .|> device - st = Lux.testmode(st) - y, st_ = m(x, ps, st) - @test check_approx(y, - sigmoid.((x .- st_.running_mean) ./ sqrt.(st_.running_var .+ m.epsilon)), - atol=1.0e-7) - - @jet m(x, ps, st) - - if affine - __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - else - __f = x -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - end - - m = BatchNorm(32; affine) - x = randn(Float32, 416, 416, 32, 1) |> aType - __display(m) - ps, st = Lux.setup(rng, m) .|> device - st = Lux.testmode(st) - m(x, ps, st) - @test (@allocated 
m(x, ps, st)) < 100_000_000 - - @jet m(x, ps, st) - end - - @testset "allow fast activation" begin - layer = BatchNorm(10, tanh) - @test layer.activation == tanh_fast - layer = BatchNorm(10, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end -end - -@testset "$mode: GroupNorm" for (mode, aType, device, ongpu) in MODES - squeeze(x) = dropdims(x; dims=tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions - - m = GroupNorm(4, 2) - sizes = (3, 4, 2) - x = reshape(collect(1:prod(sizes)), sizes) |> aType - - __display(m) - x = Float32.(x) - ps, st = Lux.setup(rng, m) .|> device - @test Lux.parameterlength(m) == Lux.parameterlength(ps) - @test Lux.statelength(m) == Lux.statelength(st) - @test ps.bias == [0, 0, 0, 0] |> aType # init_bias(32) - @test ps.scale == [1, 1, 1, 1] |> aType # init_scale(32) - - y, st_ = pullback(m, x, ps, st)[1] - - @jet m(x, ps, st) - __f = ps -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - @testset "affine: $affine" for affine in (true, false) - m = GroupNorm(2, 2; affine) - x = rand(rng, Float32, 3, 2, 1) |> aType - __display(m) - ps, st = Lux.setup(rng, m) .|> device - - @jet m(x, ps, st) - - if affine - __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-2 rtol=1.0f-2 gpu_testing=$ongpu skip_finite_differences=true - else - __f = x -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$ongpu skip_finite_differences=true - end - - # with activation function - m = GroupNorm(2, 2, sigmoid; affine) - x = randn(rng, Float32, 3, 2, 1) |> aType - __display(m) - ps, st = Lux.setup(rng, m) .|> device - st = Lux.testmode(st) - y, st_ = m(x, ps, st) - - @jet m(x, ps, st) - - if affine - __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - else - __f = x -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - end - - m = GroupNorm(32, 16; affine) - x = randn(rng, Float32, 416, 416, 32, 1) |> aType - __display(m) - ps, st = Lux.setup(rng, m) .|> device - st = Lux.testmode(st) - m(x, ps, st) - - @test (@allocated m(x, ps, st)) < 100_000_000 - - if affine - LuxTestUtils.JET.@test_opt target_modules=(LuxLib,) m(x, ps, st) - else - @jet m(x, ps, st) - end - end - - @test_throws AssertionError GroupNorm(5, 2) - - @testset "allow fast activation" begin - layer = GroupNorm(10, 2, tanh) - @test layer.activation == tanh_fast - layer = GroupNorm(10, 2, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end -end - -@testset "$mode: WeightNorm" for (mode, aType, device, ongpu) in MODES - @testset "_norm_except" begin - z = randn(rng, Float32, 3, 3, 4, 2) |> aType - - @test size(Lux._norm(z; dims=(1, 2))) == (1, 1, 4, 2) - @test size(Lux._norm_except(z; dims=1)) == (3, 1, 1, 1) - @test Lux._norm_except(z; dims=2) == Lux._norm(z; dims=(1, 3, 4)) - @test size(Lux._norm_except(z; dims=(1, 2))) == (3, 3, 1, 1) - @test Lux._norm_except(z; dims=(1, 2)) == Lux._norm(z; dims=(3, 4)) - - @jet Lux._norm_except(z) - __f = z -> sum(Lux._norm_except(z; dims=(3, 4))) - @jet __f(z) - end - - @testset "Conv" begin - c = Conv((3, 3), 3 => 3; init_bias=Lux.ones32) - - wn = WeightNorm(c, (:weight, :bias)) - __display(wn) - ps, st = Lux.setup(rng, wn) .|> device - x = randn(rng, Float32, 3, 3, 3, 1) |> aType - - @jet wn(x, ps, st) - __f = ps -> 
sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true - - wn = WeightNorm(c, (:weight,)) - __display(wn) - ps, st = Lux.setup(rng, wn) .|> device - x = randn(rng, Float32, 3, 3, 3, 1) |> aType - - @jet wn(x, ps, st) - __f = (x, ps) -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true - - wn = WeightNorm(c, (:weight, :bias), (2, 2)) - __display(wn) - ps, st = Lux.setup(rng, wn) .|> device - x = randn(rng, Float32, 3, 3, 3, 1) |> aType - - @jet wn(x, ps, st) - __f = (x, ps) -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true - - wn = WeightNorm(c, (:weight,), (2,)) - __display(wn) - ps, st = Lux.setup(rng, wn) .|> device - x = randn(rng, Float32, 3, 3, 3, 1) |> aType - - @jet wn(x, ps, st) - __f = (x, ps) -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true - end - - @testset "Dense" begin - d = Dense(3 => 3; init_bias=Lux.ones32) - - wn = WeightNorm(d, (:weight, :bias)) - __display(wn) - ps, st = Lux.setup(rng, wn) .|> device - x = randn(rng, Float32, 3, 1) |> aType - - @jet wn(x, ps, st) - __f = ps -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - wn = WeightNorm(d, (:weight,)) - __display(wn) - ps, st = Lux.setup(rng, wn) .|> device - x = randn(rng, Float32, 3, 1) |> aType - - @jet wn(x, ps, st) - __f = (x, ps) -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - wn = WeightNorm(d, (:weight, :bias), (2, 2)) - __display(wn) - ps, st = Lux.setup(rng, wn) .|> device - x = randn(rng, Float32, 3, 1) |> aType - - @jet wn(x, ps, st) - __f = (x, ps) -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - wn = WeightNorm(d, (:weight,), (2,)) - __display(wn) - ps, st = Lux.setup(rng, wn) .|> device - x = randn(rng, Float32, 3, 1) |> aType - - @jet wn(x, ps, st) - __f = (x, ps) -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - # See https://github.com/avik-pal/Lux.jl/issues/95 - @testset "Normalizing Zero Parameters" begin - c = Conv((3, 3), 3 => 3) - - wn = WeightNorm(c, (:weight, :bias)) - @test_throws ArgumentError Lux.setup(rng, wn) - - wn = WeightNorm(c, (:weight,)) - @test_nowarn Lux.setup(rng, wn) - - c = Conv((3, 3), 3 => 3; init_bias=Lux.ones32) - - wn = WeightNorm(c, (:weight, :bias)) - @test_nowarn Lux.setup(rng, wn) - - wn = WeightNorm(c, (:weight,)) - @test_nowarn Lux.setup(rng, wn) - end -end - -@testset "$mode: LayerNorm" for (mode, aType, device, ongpu) in MODES - x = randn(rng, Float32, 3, 3, 3, 2) |> aType - - for bshape in ((3, 3, 3), (1, 3, 1), (3, 1, 3)) - for affine in (true, false) - ln = LayerNorm(bshape; affine) - __display(ln) - ps, st = Lux.setup(rng, ln) .|> device - - y, st_ = ln(x, ps, st) - - @test check_approx(mean(y), 0; atol=1.0f-3, rtol=1.0f-3) - @test check_approx(std(y), 1; atol=1.0f-2, rtol=1.0f-2) - - @jet ln(x, ps, st) - - if affine - __f = (x, ps) -> sum(first(ln(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - else - __f = x -> sum(first(ln(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu 
skip_finite_differences=true - end - - for act in (sigmoid, tanh) - ln = LayerNorm(bshape, act; affine) - __display(ln) - ps, st = Lux.setup(rng, ln) .|> device - - y, st_ = ln(x, ps, st) - - @jet ln(x, ps, st) - - if affine - __f = (x, ps) -> sum(first(ln(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - else - __f = x -> sum(first(ln(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - end - end - end - end - - @testset "allow fast activation" begin - layer = LayerNorm((3, 1), tanh) - @test layer.activation == tanh_fast - layer = LayerNorm((3, 1), tanh; allow_fast_activation=false) - @test layer.activation == tanh - end -end - -@testset "$mode: InstanceNorm" for (mode, aType, device, ongpu) in MODES - for x in (randn(rng, Float32, 3, 3, 3, 2), - randn(rng, Float32, 3, 3, 2), - randn(rng, Float32, 3, 3, 3, 3, 2)) - x = x |> aType - for affine in (true, false) - layer = InstanceNorm(3; affine) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - y, st_ = layer(x, ps, st) - - @jet layer(x, ps, st) - - if affine - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - else - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - end - - for act in (sigmoid, tanh) - layer = InstanceNorm(3, act; affine) - __display(layer) - ps, st = Lux.setup(rng, layer) .|> device - - y, st_ = layer(x, ps, st) - - @jet layer(x, ps, st) - - if affine - __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - else - __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true - end - end - end - end - - @testset "allow fast activation" begin - layer = InstanceNorm(3, tanh) - @test layer.activation == tanh_fast - layer = InstanceNorm(3, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end -end diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl new file mode 100644 index 0000000000..216297c292 --- /dev/null +++ b/test/layers/normalize_tests.jl @@ -0,0 +1,419 @@ +@testitem "BatchNorm" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + m = BatchNorm(2) + x = [1.0f0 3.0f0 5.0f0 + 2.0f0 4.0f0 6.0f0] |> aType + __display(m) + ps, st = Lux.setup(rng, m) .|> device + + @test Lux.parameterlength(m) == Lux.parameterlength(ps) + @test Lux.statelength(m) == Lux.statelength(st) + + @test ps.bias == [0, 0] |> aType # init_bias(2) + @test ps.scale == [1, 1] |> aType # init_scale(2) + + y, st_ = pullback(m, x, ps, st)[1] + st_ = st_ |> LuxCPUDevice() + @test check_approx(Array(y), [-1.22474 0 1.22474; -1.22474 0 1.22474]; atol=1.0e-5) + # julia> x + # 2×3 Array{Float64,2}: + # 1.0 3.0 5.0 + # 2.0 4.0 6.0 + + # mean of batch will be + # (1. + 3. + 5.) / 3 = 3 + # (2. + 4. + 6.) / 3 = 4 + + # ∴ update rule with momentum: + # .1 * 3 + 0 = .3 + # .1 * 4 + 0 = .4 + @test check_approx(st_.running_mean, reshape([0.3, 0.4], 2, 1)) + + # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] 
+ # 2×1 Array{Float64,2}: + # 1.3 + # 1.3 + @test check_approx(st_.running_var, + 0.1 .* var(Array(x); dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0]) + + st_ = Lux.testmode(st_) |> device + x_ = m(x, ps, st_)[1] |> LuxCPUDevice() + @test check_approx(x_[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5) + + @jet m(x, ps, st) + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + + for affine in (true, false) + m = BatchNorm(2; affine, track_stats=false) + x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] |> aType + __display(m) + ps, st = Lux.setup(rng, m) .|> device + + @jet m(x, ps, st) + + if affine + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + else + __f = x -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + end + + # with activation function + m = BatchNorm(2, sigmoid; affine) + x = [1.0f0 3.0f0 5.0f0 + 2.0f0 4.0f0 6.0f0] |> aType + __display(m) + ps, st = Lux.setup(rng, m) .|> device + st = Lux.testmode(st) + y, st_ = m(x, ps, st) + @test check_approx(y, + sigmoid.((x .- st_.running_mean) ./ sqrt.(st_.running_var .+ m.epsilon)), + atol=1.0e-7) + + @jet m(x, ps, st) + + if affine + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + else + __f = x -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + end + + m = BatchNorm(32; affine) + x = randn(Float32, 416, 416, 32, 1) |> aType + __display(m) + ps, st = Lux.setup(rng, m) .|> device + st = Lux.testmode(st) + m(x, ps, st) + @test (@allocated m(x, ps, st)) < 100_000_000 + + @jet m(x, ps, st) + end + + @testset "allow fast activation" begin + layer = BatchNorm(10, tanh) + @test layer.activation == tanh_fast + layer = BatchNorm(10, tanh; allow_fast_activation=false) + @test layer.activation == tanh + end + end +end + +@testitem "GroupNorm" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + squeeze(x) = dropdims(x; dims=tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions + + m = GroupNorm(4, 2) + sizes = (3, 4, 2) + x = reshape(collect(1:prod(sizes)), sizes) |> aType + + __display(m) + x = Float32.(x) + ps, st = Lux.setup(rng, m) .|> device + @test Lux.parameterlength(m) == Lux.parameterlength(ps) + @test Lux.statelength(m) == Lux.statelength(st) + @test ps.bias == [0, 0, 0, 0] |> aType # init_bias(32) + @test ps.scale == [1, 1, 1, 1] |> aType # init_scale(32) + + y, st_ = pullback(m, x, ps, st)[1] + + @jet m(x, ps, st) + __f = ps -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + @testset "affine: $affine" for affine in (true, false) + m = GroupNorm(2, 2; affine) + x = rand(rng, Float32, 3, 2, 1) |> aType + __display(m) + ps, st = Lux.setup(rng, m) .|> device + + @jet m(x, ps, st) + + if affine + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-2 rtol=1.0f-2 gpu_testing=$ongpu skip_finite_differences=true + else + __f = x -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$ongpu skip_finite_differences=true + end + + # with activation function + m = 
GroupNorm(2, 2, sigmoid; affine) + x = randn(rng, Float32, 3, 2, 1) |> aType + __display(m) + ps, st = Lux.setup(rng, m) .|> device + st = Lux.testmode(st) + y, st_ = m(x, ps, st) + + @jet m(x, ps, st) + + if affine + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + else + __f = x -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + end + + m = GroupNorm(32, 16; affine) + x = randn(rng, Float32, 416, 416, 32, 1) |> aType + __display(m) + ps, st = Lux.setup(rng, m) .|> device + st = Lux.testmode(st) + m(x, ps, st) + + @test (@allocated m(x, ps, st)) < 100_000_000 + + if affine + LuxTestUtils.JET.@test_opt target_modules=(LuxLib,) m(x, ps, st) + else + @jet m(x, ps, st) + end + end + + @test_throws AssertionError GroupNorm(5, 2) + + @testset "allow fast activation" begin + layer = GroupNorm(10, 2, tanh) + @test layer.activation == tanh_fast + layer = GroupNorm(10, 2, tanh; allow_fast_activation=false) + @test layer.activation == tanh + end + end +end + +@testitem "WeightNorm" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "_norm_except" begin + z = randn(rng, Float32, 3, 3, 4, 2) |> aType + + @test size(Lux._norm(z; dims=(1, 2))) == (1, 1, 4, 2) + @test size(Lux._norm_except(z; dims=1)) == (3, 1, 1, 1) + @test Lux._norm_except(z; dims=2) == Lux._norm(z; dims=(1, 3, 4)) + @test size(Lux._norm_except(z; dims=(1, 2))) == (3, 3, 1, 1) + @test Lux._norm_except(z; dims=(1, 2)) == Lux._norm(z; dims=(3, 4)) + + @jet Lux._norm_except(z) + __f = z -> sum(Lux._norm_except(z; dims=(3, 4))) + @jet __f(z) + end + + @testset "Conv" begin + c = Conv((3, 3), 3 => 3; init_bias=Lux.ones32) + + wn = WeightNorm(c, (:weight, :bias)) + __display(wn) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 3, 3, 1) |> aType + + @jet wn(x, ps, st) + __f = ps -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true + + wn = WeightNorm(c, (:weight,)) + __display(wn) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 3, 3, 1) |> aType + + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true + + wn = WeightNorm(c, (:weight, :bias), (2, 2)) + __display(wn) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 3, 3, 1) |> aType + + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true + + wn = WeightNorm(c, (:weight,), (2,)) + __display(wn) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 3, 3, 1) |> aType + + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true + end + + @testset "Dense" begin + d = Dense(3 => 3; init_bias=Lux.ones32) + + wn = WeightNorm(d, (:weight, :bias)) + __display(wn) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 1) |> aType + + @jet wn(x, ps, st) + __f = ps -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + wn = WeightNorm(d, (:weight,)) + __display(wn) + ps, st = Lux.setup(rng, wn) 
.|> device + x = randn(rng, Float32, 3, 1) |> aType + + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + wn = WeightNorm(d, (:weight, :bias), (2, 2)) + __display(wn) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 1) |> aType + + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + wn = WeightNorm(d, (:weight,), (2,)) + __display(wn) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 1) |> aType + + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + # See https://github.com/avik-pal/Lux.jl/issues/95 + @testset "Normalizing Zero Parameters" begin + c = Conv((3, 3), 3 => 3) + + wn = WeightNorm(c, (:weight, :bias)) + @test_throws ArgumentError Lux.setup(rng, wn) + + wn = WeightNorm(c, (:weight,)) + @test_nowarn Lux.setup(rng, wn) + + c = Conv((3, 3), 3 => 3; init_bias=Lux.ones32) + + wn = WeightNorm(c, (:weight, :bias)) + @test_nowarn Lux.setup(rng, wn) + + wn = WeightNorm(c, (:weight,)) + @test_nowarn Lux.setup(rng, wn) + end + end +end + +@testitem "LayerNorm" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + x = randn(rng, Float32, 3, 3, 3, 2) |> aType + + for bshape in ((3, 3, 3), (1, 3, 1), (3, 1, 3)) + for affine in (true, false) + ln = LayerNorm(bshape; affine) + __display(ln) + ps, st = Lux.setup(rng, ln) .|> device + + y, st_ = ln(x, ps, st) + + @test check_approx(mean(y), 0; atol=1.0f-3, rtol=1.0f-3) + @test check_approx(std(y), 1; atol=1.0f-2, rtol=1.0f-2) + + @jet ln(x, ps, st) + + if affine + __f = (x, ps) -> sum(first(ln(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + else + __f = x -> sum(first(ln(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + end + + for act in (sigmoid, tanh) + ln = LayerNorm(bshape, act; affine) + __display(ln) + ps, st = Lux.setup(rng, ln) .|> device + + y, st_ = ln(x, ps, st) + + @jet ln(x, ps, st) + + if affine + __f = (x, ps) -> sum(first(ln(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + else + __f = x -> sum(first(ln(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + end + end + end + end + + @testset "allow fast activation" begin + layer = LayerNorm((3, 1), tanh) + @test layer.activation == tanh_fast + layer = LayerNorm((3, 1), tanh; allow_fast_activation=false) + @test layer.activation == tanh + end + end +end + +@testitem "InstanceNorm" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + for x in (randn(rng, Float32, 3, 3, 3, 2), + randn(rng, Float32, 3, 3, 2), + randn(rng, Float32, 3, 3, 3, 3, 2)) + x = x |> aType + for affine in (true, false) + layer = InstanceNorm(3; affine) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + y, st_ = layer(x, ps, st) + + @jet layer(x, ps, st) + + if affine + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + else + __f = x -> 
sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + end + + for act in (sigmoid, tanh) + layer = InstanceNorm(3, act; affine) + __display(layer) + ps, st = Lux.setup(rng, layer) .|> device + + y, st_ = layer(x, ps, st) + + @jet layer(x, ps, st) + + if affine + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + else + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true + end + end + end + end + + @testset "allow fast activation" begin + layer = InstanceNorm(3, tanh) + @test layer.activation == tanh_fast + layer = InstanceNorm(3, tanh; allow_fast_activation=false) + @test layer.activation == tanh + end + end +end diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl deleted file mode 100644 index fac044be31..0000000000 --- a/test/layers/recurrent.jl +++ /dev/null @@ -1,306 +0,0 @@ -using Lux, Test - -include("../test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode: RNNCell" for (mode, aType, device, ongpu) in MODES - for rnncell in (RNNCell(3 => 5, identity), - RNNCell(3 => 5, tanh), - RNNCell(3 => 5, tanh; use_bias=false), - RNNCell(3 => 5, identity; use_bias=false), - RNNCell(3 => 5, identity; use_bias=false, train_state=false)) - __display(rnncell) - ps, st = Lux.setup(rng, rnncell) .|> device - x = randn(rng, Float32, 3, 2) |> aType - (y, carry), st_ = Lux.apply(rnncell, x, ps, st) - - @jet rnncell(x, ps, st) - @jet rnncell((x, carry), ps, st) - - function loss_loop_rnncell(p) - (y, carry), st_ = rnncell(x, p, st) - for i in 1:10 - (y, carry), st_ = rnncell((x, carry), p, st_) - end - return sum(abs2, y) - end - - @test_throws ErrorException ps.train_state - - @eval @test_gradients $loss_loop_rnncell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "Trainable hidden states" begin - for rnncell in (RNNCell(3 => 5, identity; use_bias=false, train_state=true), - RNNCell(3 => 5, identity; use_bias=true, train_state=true)) - rnn_no_trainable_state = RNNCell(3 => 5, - identity; - use_bias=false, - train_state=false) - x = randn(rng, Float32, 3, 2) |> aType - _ps, _st = Lux.setup(rng, rnn_no_trainable_state) .|> device - (_y, _carry), _ = Lux.apply(rnn_no_trainable_state, x, _ps, _st) - - rnncell = RNNCell(3 => 5, identity; use_bias=false, train_state=true) - ps, st = Lux.setup(rng, rnncell) .|> device - ps = merge(_ps, (hidden_state=ps.hidden_state,)) - (y, carry), _ = Lux.apply(rnncell, x, ps, st) - @test carry == _carry - - l, back = Zygote.pullback(p -> sum(abs2, 0 .- rnncell(x, p, st)[1][1]), ps) - gs = back(one(l))[1] - @test !isnothing(gs.hidden_state) - end - end -end - -@testset "$mode: LSTMCell" for (mode, aType, device, ongpu) in MODES - for lstmcell in (LSTMCell(3 => 5), - LSTMCell(3 => 5; use_bias=true), - LSTMCell(3 => 5; use_bias=false)) - __display(lstmcell) - ps, st = Lux.setup(rng, lstmcell) .|> device - x = randn(rng, Float32, 3, 2) |> aType - (y, carry), st_ = Lux.apply(lstmcell, x, ps, st) - - @jet lstmcell(x, ps, st) - @jet lstmcell((x, carry), ps, st) - - function loss_loop_lstmcell(p) - (y, carry), st_ = lstmcell(x, p, st) - for i in 1:10 - (y, carry), st_ = lstmcell((x, carry), p, st_) - end - return sum(abs2, y) - end - - @eval @test_gradients $loss_loop_lstmcell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - @test_throws ErrorException 
ps.train_state - @test_throws ErrorException ps.train_memory - end - - @testset "Trainable hidden states" begin - x = randn(rng, Float32, 3, 2) |> aType - _lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) - _ps, _st = Lux.setup(rng, _lstm) .|> device - (_y, _carry), _ = Lux.apply(_lstm, x, _ps, _st) - - lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) - ps, st = Lux.setup(rng, lstm) .|> device - ps = _ps - (y, carry), _ = Lux.apply(lstm, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test_throws ErrorException gs.bias - @test_throws ErrorException gs.hidden_state - @test_throws ErrorException gs.memory - - lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=false) - ps, st = Lux.setup(rng, lstm) .|> device - ps = merge(_ps, (hidden_state=ps.hidden_state,)) - (y, carry), _ = Lux.apply(lstm, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test_throws ErrorException gs.bias - @test !isnothing(gs.hidden_state) - @test_throws ErrorException gs.memory - - lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> device - ps = merge(_ps, (memory=ps.memory,)) - (y, carry), _ = Lux.apply(lstm, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test_throws ErrorException gs.bias - @test_throws ErrorException gs.hidden_state - @test !isnothing(gs.memory) - - lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> device - ps = merge(_ps, (hidden_state=ps.hidden_state, memory=ps.memory)) - (y, carry), _ = Lux.apply(lstm, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test_throws ErrorException gs.bias - @test !isnothing(gs.hidden_state) - @test !isnothing(gs.memory) - - lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> device - ps = merge(_ps, (bias=ps.bias, hidden_state=ps.hidden_state, memory=ps.memory)) - (y, carry), _ = Lux.apply(lstm, x, ps, st) - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test !isnothing(gs.bias) - @test !isnothing(gs.hidden_state) - @test !isnothing(gs.memory) - end -end - -@testset "$mode: GRUCell" for (mode, aType, device, ongpu) in MODES - for grucell in (GRUCell(3 => 5), - GRUCell(3 => 5; use_bias=true), - GRUCell(3 => 5; use_bias=false)) - __display(grucell) - ps, st = Lux.setup(rng, grucell) .|> device - x = randn(rng, Float32, 3, 2) |> aType - (y, carry), st_ = Lux.apply(grucell, x, ps, st) - - @jet grucell(x, ps, st) - @jet grucell((x, carry), ps, st) - - function loss_loop_grucell(p) - (y, carry), st_ = grucell(x, p, st) - for i in 1:10 - (y, carry), st_ = grucell((x, carry), p, st_) - end - return sum(abs2, y) - end - - @eval @test_gradients $loss_loop_grucell $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - - @test_throws ErrorException ps.train_state - end - - @testset "Trainable hidden states" begin - x = randn(rng, Float32, 3, 2) |> aType - _gru = GRUCell(3 => 5; use_bias=false, train_state=false) - _ps, _st = Lux.setup(rng, _gru) .|> device - (_y, _carry), _ = Lux.apply(_gru, x, _ps, _st) - - 
gru = GRUCell(3 => 5; use_bias=false, train_state=false) - ps, st = Lux.setup(rng, gru) .|> device - ps = _ps - (y, carry), _ = Lux.apply(gru, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test_throws ErrorException gs.bias - @test_throws ErrorException gs.hidden_state - - gru = GRUCell(3 => 5; use_bias=false, train_state=true) - ps, st = Lux.setup(rng, gru) .|> device - ps = merge(_ps, (hidden_state=ps.hidden_state,)) - (y, carry), _ = Lux.apply(gru, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test !isnothing(gs.hidden_state) - - gru = GRUCell(3 => 5; use_bias=true, train_state=true) - ps, st = Lux.setup(rng, gru) .|> device - ps = merge(_ps, (bias_h=ps.bias_h, bias_i=ps.bias_i, hidden_state=ps.hidden_state)) - (y, carry), _ = Lux.apply(gru, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test !isnothing(gs.hidden_state) - end -end - -@testset "$mode: StatefulRecurrentCell" for (mode, aType, device, ongpu) in MODES - for _cell in (RNNCell, LSTMCell, GRUCell), - use_bias in (true, false), - train_state in (true, false) - - cell = _cell(3 => 5; use_bias, train_state) - rnn = StatefulRecurrentCell(cell) - __display(rnn) - x = randn(rng, Float32, 3, 2) |> aType - ps, st = Lux.setup(rng, rnn) .|> device - - y, st_ = rnn(x, ps, st) - - @jet rnn(x, ps, st) - @jet rnn(x, ps, st_) - - @test size(y) == (5, 2) - @test st.carry === nothing - @test st_.carry !== nothing - - st__ = Lux.update_state(st, :carry, nothing) - @test st__.carry === nothing - - function loss_loop_rnn(p) - y, st_ = rnn(x, p, st) - for i in 1:10 - y, st_ = rnn(x, p, st_) - end - return sum(abs2, y) - end - - @eval @test_gradients $loss_loop_rnn $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - end -end - -@testset "$mode: Recurrence" for (mode, aType, device, ongpu) in MODES - @testset "ordering: $ordering" for ordering in (BatchLastIndex(), TimeLastIndex()) - @testset "cell: $_cell" for _cell in (RNNCell, LSTMCell, GRUCell) - @testset "use_bias: $use_bias, train_state: $train_state" for use_bias in ( - true, - false), - train_state in (true, false) - - cell = _cell(3 => 5; use_bias, train_state) - rnn = Recurrence(cell; ordering) - rnn_seq = Recurrence(cell; ordering, return_sequence=true) - __display(rnn) - - # Batched Time Series - @testset "typeof(x): $(typeof(x))" for x in ( - randn(rng, Float32, 3, 4, 2) |> - aType, - Tuple(randn(rng, Float32, 3, 2) for _ in 1:4) .|> aType, - [randn(rng, Float32, 3, 2) for _ in 1:4] .|> aType) - # Fix data ordering for testing - if ordering isa TimeLastIndex && x isa AbstractArray && ndims(x) ≥ 2 - x = permutedims(x, - (ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1)) - end - - ps, st = Lux.setup(rng, rnn) .|> device - y, st_ = rnn(x, ps, st) - y_, st__ = rnn_seq(x, ps, st) - - @jet rnn(x, ps, st) - @jet rnn_seq(x, ps, st) - - @test size(y) == (5, 2) - @test length(y_) == 4 - @test all(x -> size(x) == (5, 2), y_) - - if mode != "AMDGPU" && !(VERSION < v"1.9" && x isa AbstractVector) - __f = p -> sum(first(rnn(x, p, st))) - @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - - __f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st))) - @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - else - # This is just added as a stub to remember about this broken test - 
@test_broken 1 + 1 == 1 - end - end - end - end - end - - # Ordering Check: https://github.com/LuxDL/Lux.jl/issues/302 - encoder = Recurrence( - RNNCell(1 => 1, - identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...), - init_state=(rng, args...; kwargs...) -> zeros(args...; kwargs...), - init_bias=(rng, args...; kwargs...) -> zeros(args...; kwargs...)); - return_sequence=true) - ps, st = Lux.setup(rng, encoder) .|> device - m2 = reshape([0.5, 0.0, 0.7, 0.8], 1, :, 1) |> aType - res, _ = encoder(m2, ps, st) - - @test Array(vec(reduce(vcat, res))) ≈ [0.5, 0.5, 1.2, 2.0] -end diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl new file mode 100644 index 0000000000..4a941b79d7 --- /dev/null +++ b/test/layers/recurrent_tests.jl @@ -0,0 +1,326 @@ + +@testitem "RNNCell" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + for rnncell in (RNNCell(3 => 5, identity), + RNNCell(3 => 5, tanh), + RNNCell(3 => 5, tanh; use_bias=false), + RNNCell(3 => 5, identity; use_bias=false), + RNNCell(3 => 5, identity; use_bias=false, train_state=false)) + __display(rnncell) + ps, st = Lux.setup(rng, rnncell) .|> device + x = randn(rng, Float32, 3, 2) |> aType + (y, carry), st_ = Lux.apply(rnncell, x, ps, st) + + @jet rnncell(x, ps, st) + @jet rnncell((x, carry), ps, st) + + function loss_loop_rnncell(p) + (y, carry), st_ = rnncell(x, p, st) + for i in 1:10 + (y, carry), st_ = rnncell((x, carry), p, st_) + end + return sum(abs2, y) + end + + @test_throws ErrorException ps.train_state + + @eval @test_gradients $loss_loop_rnncell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end + + @testset "Trainable hidden states" begin + for rnncell in (RNNCell(3 => 5, identity; use_bias=false, train_state=true), + RNNCell(3 => 5, identity; use_bias=true, train_state=true)) + rnn_no_trainable_state = RNNCell(3 => 5, + identity; + use_bias=false, + train_state=false) + x = randn(rng, Float32, 3, 2) |> aType + _ps, _st = Lux.setup(rng, rnn_no_trainable_state) .|> device + (_y, _carry), _ = Lux.apply(rnn_no_trainable_state, x, _ps, _st) + + rnncell = RNNCell(3 => 5, identity; use_bias=false, train_state=true) + ps, st = Lux.setup(rng, rnncell) .|> device + ps = merge(_ps, (hidden_state=ps.hidden_state,)) + (y, carry), _ = Lux.apply(rnncell, x, ps, st) + @test carry == _carry + + l, back = Zygote.pullback(p -> sum(abs2, 0 .- rnncell(x, p, st)[1][1]), ps) + gs = back(one(l))[1] + @test !isnothing(gs.hidden_state) + end + end + end +end + +@testitem "LSTMCell" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + for lstmcell in (LSTMCell(3 => 5), + LSTMCell(3 => 5; use_bias=true), + LSTMCell(3 => 5; use_bias=false)) + __display(lstmcell) + ps, st = Lux.setup(rng, lstmcell) .|> device + x = randn(rng, Float32, 3, 2) |> aType + (y, carry), st_ = Lux.apply(lstmcell, x, ps, st) + + @jet lstmcell(x, ps, st) + @jet lstmcell((x, carry), ps, st) + + function loss_loop_lstmcell(p) + (y, carry), st_ = lstmcell(x, p, st) + for i in 1:10 + (y, carry), st_ = lstmcell((x, carry), p, st_) + end + return sum(abs2, y) + end + + @eval @test_gradients $loss_loop_lstmcell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + @test_throws ErrorException ps.train_state + @test_throws ErrorException ps.train_memory + end + + @testset "Trainable hidden states" begin + x = randn(rng, Float32, 3, 2) |> aType + _lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, 
train_memory=false) + _ps, _st = Lux.setup(rng, _lstm) .|> device + (_y, _carry), _ = Lux.apply(_lstm, x, _ps, _st) + + lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) + ps, st = Lux.setup(rng, lstm) .|> device + ps = _ps + (y, carry), _ = Lux.apply(lstm, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test_throws ErrorException gs.bias + @test_throws ErrorException gs.hidden_state + @test_throws ErrorException gs.memory + + lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=false) + ps, st = Lux.setup(rng, lstm) .|> device + ps = merge(_ps, (hidden_state=ps.hidden_state,)) + (y, carry), _ = Lux.apply(lstm, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test_throws ErrorException gs.bias + @test !isnothing(gs.hidden_state) + @test_throws ErrorException gs.memory + + lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=true) + ps, st = Lux.setup(rng, lstm) .|> device + ps = merge(_ps, (memory=ps.memory,)) + (y, carry), _ = Lux.apply(lstm, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test_throws ErrorException gs.bias + @test_throws ErrorException gs.hidden_state + @test !isnothing(gs.memory) + + lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true) + ps, st = Lux.setup(rng, lstm) .|> device + ps = merge(_ps, (hidden_state=ps.hidden_state, memory=ps.memory)) + (y, carry), _ = Lux.apply(lstm, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test_throws ErrorException gs.bias + @test !isnothing(gs.hidden_state) + @test !isnothing(gs.memory) + + lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true) + ps, st = Lux.setup(rng, lstm) .|> device + ps = merge(_ps, (bias=ps.bias, hidden_state=ps.hidden_state, memory=ps.memory)) + (y, carry), _ = Lux.apply(lstm, x, ps, st) + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test !isnothing(gs.bias) + @test !isnothing(gs.hidden_state) + @test !isnothing(gs.memory) + end + end +end + +@testitem "GRUCell" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + for grucell in (GRUCell(3 => 5), + GRUCell(3 => 5; use_bias=true), + GRUCell(3 => 5; use_bias=false)) + __display(grucell) + ps, st = Lux.setup(rng, grucell) .|> device + x = randn(rng, Float32, 3, 2) |> aType + (y, carry), st_ = Lux.apply(grucell, x, ps, st) + + @jet grucell(x, ps, st) + @jet grucell((x, carry), ps, st) + + function loss_loop_grucell(p) + (y, carry), st_ = grucell(x, p, st) + for i in 1:10 + (y, carry), st_ = grucell((x, carry), p, st_) + end + return sum(abs2, y) + end + + @eval @test_gradients $loss_loop_grucell $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + + @test_throws ErrorException ps.train_state + end + + @testset "Trainable hidden states" begin + x = randn(rng, Float32, 3, 2) |> aType + _gru = GRUCell(3 => 5; use_bias=false, train_state=false) + _ps, _st = Lux.setup(rng, _gru) .|> device + (_y, _carry), _ = Lux.apply(_gru, x, _ps, _st) + + gru = GRUCell(3 => 5; use_bias=false, train_state=false) + ps, st = Lux.setup(rng, gru) .|> device + ps = _ps + (y, carry), _ = 
Lux.apply(gru, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test_throws ErrorException gs.bias + @test_throws ErrorException gs.hidden_state + + gru = GRUCell(3 => 5; use_bias=false, train_state=true) + ps, st = Lux.setup(rng, gru) .|> device + ps = merge(_ps, (hidden_state=ps.hidden_state,)) + (y, carry), _ = Lux.apply(gru, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test !isnothing(gs.hidden_state) + + gru = GRUCell(3 => 5; use_bias=true, train_state=true) + ps, st = Lux.setup(rng, gru) .|> device + ps = merge(_ps, + (bias_h=ps.bias_h, bias_i=ps.bias_i, hidden_state=ps.hidden_state)) + (y, carry), _ = Lux.apply(gru, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test !isnothing(gs.hidden_state) + end + end +end + +@testitem "StatefulRecurrentCell" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + for _cell in (RNNCell, LSTMCell, GRUCell), + use_bias in (true, false), + train_state in (true, false) + + cell = _cell(3 => 5; use_bias, train_state) + rnn = StatefulRecurrentCell(cell) + __display(rnn) + x = randn(rng, Float32, 3, 2) |> aType + ps, st = Lux.setup(rng, rnn) .|> device + + y, st_ = rnn(x, ps, st) + + @jet rnn(x, ps, st) + @jet rnn(x, ps, st_) + + @test size(y) == (5, 2) + @test st.carry === nothing + @test st_.carry !== nothing + + st__ = Lux.update_state(st, :carry, nothing) + @test st__.carry === nothing + + function loss_loop_rnn(p) + y, st_ = rnn(x, p, st) + for i in 1:10 + y, st_ = rnn(x, p, st_) + end + return sum(abs2, y) + end + + @eval @test_gradients $loss_loop_rnn $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + end + end +end + +@testitem "Recurrence" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "ordering: $ordering" for ordering in (BatchLastIndex(), TimeLastIndex()) + @testset "cell: $_cell" for _cell in (RNNCell, LSTMCell, GRUCell) + @testset "use_bias: $use_bias, train_state: $train_state" for use_bias in ( + true, + false), + train_state in (true, false) + + cell = _cell(3 => 5; use_bias, train_state) + rnn = Recurrence(cell; ordering) + rnn_seq = Recurrence(cell; ordering, return_sequence=true) + __display(rnn) + + # Batched Time Series + @testset "typeof(x): $(typeof(x))" for x in ( + randn(rng, + Float32, + 3, + 4, + 2) |> + aType, + Tuple(randn(rng, Float32, 3, 2) for _ in 1:4) .|> aType, + [randn(rng, Float32, 3, 2) for _ in 1:4] .|> aType) + # Fix data ordering for testing + if ordering isa TimeLastIndex && x isa AbstractArray && ndims(x) ≥ 2 + x = permutedims(x, + (ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1)) + end + + ps, st = Lux.setup(rng, rnn) .|> device + y, st_ = rnn(x, ps, st) + y_, st__ = rnn_seq(x, ps, st) + + @jet rnn(x, ps, st) + @jet rnn_seq(x, ps, st) + + @test size(y) == (5, 2) + @test length(y_) == 4 + @test all(x -> size(x) == (5, 2), y_) + + if mode != "AMDGPU" && !(VERSION < v"1.9" && x isa AbstractVector) + __f = p -> sum(first(rnn(x, p, st))) + @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + + __f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st))) + @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + else + # This is just added as a 
stub to remember about this broken test + @test_broken 1 + 1 == 1 + end + end + end + end + end + + # Ordering Check: https://github.com/LuxDL/Lux.jl/issues/302 + encoder = Recurrence( + RNNCell(1 => 1, + identity; + init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...), + init_state=(rng, args...; kwargs...) -> zeros(args...; kwargs...), + init_bias=(rng, args...; kwargs...) -> zeros(args...; kwargs...)); + return_sequence=true) + ps, st = Lux.setup(rng, encoder) .|> device + m2 = reshape([0.5, 0.0, 0.7, 0.8], 1, :, 1) |> aType + res, _ = encoder(m2, ps, st) + + @test Array(vec(reduce(vcat, res))) ≈ [0.5, 0.5, 1.2, 2.0] + end +end diff --git a/test/misc.jl b/test/misc.jl deleted file mode 100644 index 36e83d790d..0000000000 --- a/test/misc.jl +++ /dev/null @@ -1,50 +0,0 @@ -using ComponentArrays, Lux, Zygote - -include("test_utils.jl") - -rng = get_stable_rng() - -@testset "$mode: Simple Zygote Second Order Derivative" for (mode, aType, dev, ongpu) in MODES - x = rand(rng, Float32, 1, 128) |> aType - nn = Lux.Dense(1 => 1) - ps, st = Lux.setup(rng, nn) |> dev - ps_ca = ComponentArray(ps) - - function test_f(x, ps) - mz, back = Zygote.pullback(x -> first(nn(x, ps, st)), x) - ep = only(back(one.(mz))) - return sum(mz) + sum(ep) - end - - @testset "Named Tuple Parameters" begin - @test_nowarn test_f(x, ps) - - @test begin - y, back = Zygote.pullback(test_f, x, ps) - ∂x, ∂ps = back(one(y)) - ∂x !== nothing && ∂ps !== nothing - end - - # Weird Zygote Quirks - @test_broken begin - ∂x, ∂ps = Zygote.jacobian(test_f, x, ps) - ∂x !== nothing && ∂ps !== nothing - end - end - - @testset "Component Array Parameters" begin - @test_nowarn test_f(x, ps_ca) - - @test begin - y, back = Zygote.pullback(test_f, x, ps_ca) - ∂x, ∂ps = back(one(y)) - ∂x !== nothing && ∂ps !== nothing - end - - # Weird Zygote Quirks - @test begin - ∂x, ∂ps = Zygote.jacobian(test_f, x, ps_ca) - ∂x !== nothing && ∂ps !== nothing - end - end -end diff --git a/test/misc_tests.jl b/test/misc_tests.jl new file mode 100644 index 0000000000..e643f89a67 --- /dev/null +++ b/test/misc_tests.jl @@ -0,0 +1,48 @@ +@testitem "Simple Zygote Second Order Derivative" setup=[SharedTestSetup] begin + rng = get_stable_rng() + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + x = rand(rng, Float32, 1, 128) |> aType + nn = Lux.Dense(1 => 1) + ps, st = Lux.setup(rng, nn) |> dev + ps_ca = ComponentArray(ps) + + function test_f(x, ps) + mz, back = Zygote.pullback(x -> first(nn(x, ps, st)), x) + ep = only(back(one.(mz))) + return sum(mz) + sum(ep) + end + + @testset "Named Tuple Parameters" begin + @test_nowarn test_f(x, ps) + + @test begin + y, back = Zygote.pullback(test_f, x, ps) + ∂x, ∂ps = back(one(y)) + ∂x !== nothing && ∂ps !== nothing + end + + # Weird Zygote Quirks + @test_broken begin + ∂x, ∂ps = Zygote.jacobian(test_f, x, ps) + ∂x !== nothing && ∂ps !== nothing + end + end + + @testset "Component Array Parameters" begin + @test_nowarn test_f(x, ps_ca) + + @test begin + y, back = Zygote.pullback(test_f, x, ps_ca) + ∂x, ∂ps = back(one(y)) + ∂x !== nothing && ∂ps !== nothing + end + + # Weird Zygote Quirks + @test begin + ∂x, ∂ps = Zygote.jacobian(test_f, x, ps_ca) + ∂x !== nothing && ∂ps !== nothing + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 8fac735b2d..8ba7978a23 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,38 +1,3 @@ -using SafeTestsets, Test, TestSetExtensions +using ReTestItems -const GROUP = get(ENV, "GROUP", "All") - -@testset ExtendedTestSet "Lux.jl" begin - 
@safetestset "Utils" include("utils.jl") - - @safetestset "Core" include("core.jl") - - @testset "Layers" begin - @safetestset "Basic" include("layers/basic.jl") - @safetestset "Containers" include("layers/containers.jl") - @safetestset "Convolution" include("layers/conv.jl") - @safetestset "Normalization" include("layers/normalize.jl") - @safetestset "Recurrent" include("layers/recurrent.jl") - @safetestset "Dropout" include("layers/dropout.jl") - end - - @testset "Experimental" begin - @safetestset "Map" include("contrib/map.jl") - @safetestset "Training" include("contrib/training.jl") - @safetestset "Freeze" include("contrib/freeze.jl") - @safetestset "Shared Parameters" include("contrib/share_parameters.jl") - @safetestset "Debugging Tools" include("contrib/debug.jl") - # Tests for StatefulLuxLayer is embedded into @compact tests - @safetestset "Stateful & Compact Layers" include("contrib/compact.jl") - end - - @safetestset "Aqua Tests" include("aqua.jl") - - @safetestset "Miscellaneous Tests" include("misc.jl") - - @testset "Extensions" begin - # Most CA tests are already included in the other tests - @safetestset "ComponentArrays" include("ext/LuxComponentArraysExt.jl") - @safetestset "Flux" include("ext/LuxFluxTransformExt.jl") - end -end +ReTestItems.runtests(@__DIR__) diff --git a/test/test_utils.jl b/test/shared_testsetup.jl similarity index 56% rename from test/test_utils.jl rename to test/shared_testsetup.jl index 164d0eb240..8c08d56fd9 100644 --- a/test/test_utils.jl +++ b/test/shared_testsetup.jl @@ -1,6 +1,10 @@ -using Lux, LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test, Zygote -using LuxCUDA, LuxAMDGPU -using LuxTestUtils: @jet, @test_gradients, check_approx +@testsetup module SharedTestSetup +import Reexport: @reexport + +using Lux, LuxCUDA, LuxAMDGPU +@reexport using ComponentArrays, + LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test, Zygote, Statistics +import LuxTestUtils: @jet, @test_gradients, check_approx const GROUP = get(ENV, "GROUP", "All") @@ -26,30 +30,27 @@ end # Some Helper Functions function get_default_rng(mode::String) - if mode == "CPU" - return Random.default_rng() - elseif mode == "CUDA" - return CUDA.RNG() - elseif mode == "AMDGPU" - return AMDGPU.rocRAND.RNG() - else - error("Unknown mode: $mode") - end + dev = mode == "CPU" ? LuxCPUDevice() : + mode == "CUDA" ? LuxCUDADevice() : mode == "AMDGPU" ? LuxAMDGPUDevice() : nothing + return default_device_rng(dev) end get_stable_rng(seed=12345) = StableRNG(seed) +__display(args...) = (println(); display(args...)) + # AMDGPU Specifics function _rocRAND_functional() try - AMDGPU.rocRAND.RNG() + get_default_rng("AMDGPU") return true catch return false end end -function __display(args...) - println() - return display(args...) 
+export @jet, @test_gradients, check_approx +export GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, get_default_rng, + get_stable_rng, __display, _rocRAND_functional + end diff --git a/test/utils.jl b/test/utils.jl deleted file mode 100644 index e6b6acd200..0000000000 --- a/test/utils.jl +++ /dev/null @@ -1,121 +0,0 @@ -using Lux, ComponentArrays, Functors, Optimisers, Zygote, Test -using Statistics: std - -include("test_utils.jl") - -rng = get_stable_rng(12345) - -@testset "$mode: replicate" for (mode, aType, device, ongpu) in MODES - _rng = get_default_rng(mode) - if mode == "AMDGPU" - @test randn(_rng, 10, 2) != randn(_rng, 10, 2) - @test_broken randn(Lux.replicate(_rng), 10, 2) == randn(Lux.replicate(_rng), 10, 2) - else - @test randn(_rng, 10, 2) != randn(_rng, 10, 2) - @test randn(Lux.replicate(_rng), 10, 2) == randn(Lux.replicate(_rng), 10, 2) - end -end - -@testset "istraining" begin - @test Lux.istraining(Val(true)) - @test !Lux.istraining(Val(false)) - @test !Lux.istraining((training=Val(false),)) - @test Lux.istraining((training=Val(true),)) - @test !Lux.istraining((no_training=1,)) - @test_throws MethodError Lux.istraining((training=true,)) -end - -@testset "$mode: multigate" for (mode, aType, device, ongpu) in MODES - x = randn(rng, 10, 1) |> aType - x1, x2 = Lux.multigate(x, Val(2)) - - @test x1 == x[1:5, :] - @test x2 == x[6:10, :] - - @jet Lux.multigate(x, Val(2)) - - x = randn(rng, 10) |> aType - x1, x2 = Lux.multigate(x, Val(2)) - - @test x1 == x[1:5] - @test x2 == x[6:10] - - @jet Lux.multigate(x, Val(2)) - - x = rand(6, 5) |> aType - __f = x -> begin - x1, x2, x3 = Lux.multigate(x, Val(3)) - return sum(x1) + sum(x3 .+ x2 .^ 2) - end - res, (dx,) = Zygote.withgradient(__f, x) - - @jet Lux.multigate(x, Val(3)) - - @test res ≈ sum(x[1:2, :]) + sum(x[5:6, :]) + sum(abs2, x[3:4, :]) - @test dx ≈ aType([ones(2, 5); Array(x[3:4, :] .* 2); ones(2, 5)]) - - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu -end - -@testset "$mode: ComponentArrays" for (mode, aType, device, ongpu) in MODES - ps = (weight=randn(rng, 3, 4), bias=randn(rng, 4)) - p_flat, re = Optimisers.destructure(ps) - ps_c = ComponentArray(ps) - - @test ps_c.weight == ps.weight - @test ps_c.bias == ps.bias - - @test p_flat == getdata(ps_c) - @test -p_flat == getdata(-ps_c) - @test zero(p_flat) == getdata(zero(ps_c)) - - @test_nowarn similar(ps_c, 10) - @test_nowarn similar(ps_c) - - ps_c_f, ps_c_re = Functors.functor(ps_c) - @test ps_c_f == ps - @test ps_c_re(ps_c_f) == ps_c - - # Empty ComponentArray test - @test_nowarn __display(ComponentArray(NamedTuple())) - println() - - # Optimisers - opt = Adam(0.001f0) - ps_c = ps_c |> device - st_opt = Optimisers.setup(opt, ps_c) - - @test_nowarn Optimisers.update(st_opt, ps_c, ps_c) - @test_nowarn Optimisers.update!(st_opt, ps_c, ps_c) -end - -@testset "$mode: _init_hidden_state" for (mode, aType, device, ongpu) in MODES - rnn = RNNCell(3 => 5; init_state=Lux.zeros32) - x = randn(rng, Float32, 3, 2, 2) - @test Lux._init_hidden_state(rng, rnn, view(device(x), :, 1, :)) == - aType(zeros(Float32, 5, 2)) -end - -@testset "$mode: FP Conversions" for (mode, aType, device, ongpu) in MODES - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) - - for (f, ftype) in zip((f16, f32, f64), (Float16, Float32, Float64)) - ps, st = Lux.setup(rng, model) |> device |> f - - @test eltype(ps.layer_1.weight) == ftype - @test eltype(ps.layer_1.bias) == ftype - @test 
eltype(ps.layer_2.layer_1.weight) == ftype - @test eltype(ps.layer_2.layer_1.bias) == ftype - @test eltype(ps.layer_2.layer_2.weight) == ftype - @test eltype(ps.layer_2.layer_2.bias) == ftype - @test eltype(ps.layer_3.scale) == ftype - @test eltype(ps.layer_3.bias) == ftype - @test st.layer_1 == NamedTuple() - @test st.layer_2.layer_1 == NamedTuple() - @test st.layer_2.layer_2 == NamedTuple() - @test eltype(st.layer_3.running_mean) == ftype - @test eltype(st.layer_3.running_var) == ftype - @test typeof(st.layer_3.training) == Val{true} - end -end diff --git a/test/utils_tests.jl b/test/utils_tests.jl new file mode 100644 index 0000000000..51569edc40 --- /dev/null +++ b/test/utils_tests.jl @@ -0,0 +1,143 @@ +@testitem "replicate" setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, device, ongpu) in MODES + _rng = get_default_rng(mode) + + if mode == "AMDGPU" + @test_skip randn(_rng, 10, 2) != randn(_rng, 10, 2) + @test_skip randn(Lux.replicate(_rng), 10, 2) == + randn(Lux.replicate(_rng), 10, 2) + else + @test randn(_rng, 10, 2) != randn(_rng, 10, 2) + @test randn(Lux.replicate(_rng), 10, 2) == randn(Lux.replicate(_rng), 10, 2) + end + end +end + +@testitem "istraining" begin + @test Lux.istraining(Val(true)) + @test !Lux.istraining(Val(false)) + @test !Lux.istraining((training=Val(false),)) + @test Lux.istraining((training=Val(true),)) + @test !Lux.istraining((no_training=1,)) + @test_throws MethodError Lux.istraining((training=true,)) +end + +@testitem "multigate" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + x = randn(rng, 10, 1) |> aType + x1, x2 = Lux.multigate(x, Val(2)) + + @test x1 == x[1:5, :] + @test x2 == x[6:10, :] + + @jet Lux.multigate(x, Val(2)) + + x = randn(rng, 10) |> aType + x1, x2 = Lux.multigate(x, Val(2)) + + @test x1 == x[1:5] + @test x2 == x[6:10] + + @jet Lux.multigate(x, Val(2)) + + x = rand(6, 5) |> aType + __f = x -> begin + x1, x2, x3 = Lux.multigate(x, Val(3)) + return sum(x1) + sum(x3 .+ x2 .^ 2) + end + res, (dx,) = Zygote.withgradient(__f, x) + + @jet Lux.multigate(x, Val(3)) + + @test res ≈ sum(x[1:2, :]) + sum(x[5:6, :]) + sum(abs2, x[3:4, :]) + @test dx ≈ aType([ones(2, 5); Array(x[3:4, :] .* 2); ones(2, 5)]) + + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end +end + +@testitem "ComponentArrays" setup=[SharedTestSetup] begin + using Optimisers, Functors + + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + ps = (weight=randn(rng, 3, 4), bias=randn(rng, 4)) + p_flat, re = Optimisers.destructure(ps) + ps_c = ComponentArray(ps) + + @test ps_c.weight == ps.weight + @test ps_c.bias == ps.bias + + @test p_flat == getdata(ps_c) + @test -p_flat == getdata(-ps_c) + @test zero(p_flat) == getdata(zero(ps_c)) + + @test_nowarn similar(ps_c, 10) + @test_nowarn similar(ps_c) + + ps_c_f, ps_c_re = Functors.functor(ps_c) + @test ps_c_f == ps + @test ps_c_re(ps_c_f) == ps_c + + # Empty ComponentArray test + @test_nowarn __display(ComponentArray(NamedTuple())) + println() + + # Optimisers + opt = Adam(0.001f0) + ps_c = ps_c |> device + st_opt = Optimisers.setup(opt, ps_c) + + @test_nowarn Optimisers.update(st_opt, ps_c, ps_c) + @test_nowarn Optimisers.update!(st_opt, ps_c, ps_c) + end + + # Ref: https://github.com/LuxDL/Lux.jl/issues/243 + nn = Chain(Dense(4, 3), Dense(3, 2)) + ps, st = Lux.setup(rng, nn) + + l2reg(p) = sum(abs2, ComponentArray(p)) + @test_nowarn Zygote.gradient(l2reg, ps) +end + +@testitem 
"_init_hidden_state" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + rnn = RNNCell(3 => 5; init_state=Lux.zeros32) + x = randn(rng, Float32, 3, 2, 2) + @test Lux._init_hidden_state(rng, rnn, view(device(x), :, 1, :)) == + aType(zeros(Float32, 5, 2)) + end +end + +@testitem "FP Conversions" setup=[SharedTestSetup] begin + rng = get_stable_rng(12345) + + @testset "$mode" for (mode, aType, device, ongpu) in MODES + model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), + BatchNorm(1); disable_optimizations=true) + + for (f, ftype) in zip((f16, f32, f64), (Float16, Float32, Float64)) + ps, st = Lux.setup(rng, model) |> device |> f + + @test eltype(ps.layer_1.weight) == ftype + @test eltype(ps.layer_1.bias) == ftype + @test eltype(ps.layer_2.layer_1.weight) == ftype + @test eltype(ps.layer_2.layer_1.bias) == ftype + @test eltype(ps.layer_2.layer_2.weight) == ftype + @test eltype(ps.layer_2.layer_2.bias) == ftype + @test eltype(ps.layer_3.scale) == ftype + @test eltype(ps.layer_3.bias) == ftype + @test st.layer_1 == NamedTuple() + @test st.layer_2.layer_1 == NamedTuple() + @test st.layer_2.layer_2 == NamedTuple() + @test eltype(st.layer_3.running_mean) == ftype + @test eltype(st.layer_3.running_var) == ftype + @test typeof(st.layer_3.training) == Val{true} + end + end +end