diff --git a/.gitignore b/.gitignore
index 1e8a6b3b8a..21bd9e6e68 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,4 @@ Manifest.toml
 LocalPreferences.toml
 .DS_Store
 docs/mymodel.bson
+prova.jl
diff --git a/src/distributed/public_api.jl b/src/distributed/public_api.jl
index 26d321814d..d5d10e42a4 100644
--- a/src/distributed/public_api.jl
+++ b/src/distributed/public_api.jl
@@ -132,7 +132,7 @@ Backend Agnostic API to perform an allreduce operation on the given buffer `send
 workers.
 """
 function allreduce!(backend::AbstractFluxDistributedBackend, sendrecvbuf, op::F) where {F}
-    return __allreduce!(backend, sendrecvbuf, op, get_device())
+    return __allreduce!(backend, sendrecvbuf, op, gpu_device())
 end
 
 function allreduce!(
diff --git a/test/Project.toml b/test/Project.toml
index b5fcda422b..99f1d7175a 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -21,4 +21,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 FiniteDifferences = "0.12"
 Tracker = "0.2.33"
-Enzyme = "0.12.4"
+Enzyme = "0.13"
diff --git a/test/ext_amdgpu/basic.jl b/test/ext_amdgpu/basic.jl
index 831b577d48..163064c072 100644
--- a/test/ext_amdgpu/basic.jl
+++ b/test/ext_amdgpu/basic.jl
@@ -19,26 +19,27 @@ end
 end
 
 @testset "Chain of Dense layers" begin
-    m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32
+    m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
     x = rand(Float32, 10, 10)
-    gpu_autodiff_test(m, x)
+    test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
 end
 
 @testset "Convolution" begin
     for conv_type in (Conv, ConvTranspose), nd in 1:3
-        m = conv_type(tuple(fill(2, nd)...), 3 => 4) |> f32
+        m = conv_type(tuple(fill(2, nd)...), 3 => 4)
         x = rand(Float32, fill(10, nd)..., 3, 5)
+        md, xd = Flux.gpu.((m, x))
+
+        y = m(x)
         # Ensure outputs are the same.
-        gpu_autodiff_test(m, x; atol=1f-3, checkgrad=false)
+        @test collect(md(xd)) ≈ y atol=1f-3
 
         # Gradients are flipped as well.
-        md, xd = Flux.gpu.((m, x))
-        gs = gradient(m -> sum(m(x)), m)
-        gsd = gradient(m -> sum(m(xd)), md)
+        gs = gradient(m -> sum(m(x)), m)[1]
+        gsd = gradient(m -> sum(m(xd)), md)[1]
 
         dims = ntuple(i -> i, ndims(m.weight) - 2)
-        @test reverse(gs[1].weight; dims) ≈ Array(gsd[1].weight) atol=1f-2
+        @test reverse(gs.weight; dims) ≈ Array(gsd.weight) atol=1f-2
 
         # Movement back to CPU flips weights back.
         mh = Flux.cpu(md)
@@ -52,10 +53,10 @@ end
         x = rand(Float32, fill(10, nd)..., 3, 5) |> gpu
         pad = ntuple(i -> i, nd)
 
-        m = conv_type(kernel, 3 => 4, pad=pad) |> f32 |> gpu
+        m = conv_type(kernel, 3 => 4, pad=pad) |> gpu
         expanded_pad = ntuple(i -> pad[(i - 1) ÷ 2 + 1], 2 * nd)
-        m_expanded = conv_type(kernel, 3 => 4, pad=expanded_pad) |> f32 |> gpu
+        m_expanded = conv_type(kernel, 3 => 4, pad=expanded_pad) |> gpu
 
         @test size(m(x)) == size(m_expanded(x))
     end
@@ -74,25 +75,25 @@ end
 end
 
 @testset "Chain(Conv)" begin
-    m = Chain(Conv((3, 3), 3 => 3)) |> f32
-    x = rand(Float32, 10, 10, 3, 2)
-    gpu_autodiff_test(m, x; atol=1f-3, checkgrad=false)
+    m = Chain(Conv((3, 3), 3 => 3))
+    x = rand(Float32, 5, 5, 3, 2)
+    test_gradients(m, x, test_gpu=true, compare_finite_diff=false, test_grad_f=false)
 
     md = m |> gpu |> cpu
     @test md[1].weight ≈ m[1].weight atol=1f-3
 
-    m = Chain(ConvTranspose((3, 3), 3 => 3)) |> f32
-    x = rand(Float32, 10, 10, 3, 2)
-    gpu_autodiff_test(m, x; atol=1f-3, checkgrad=false)
+    m = Chain(ConvTranspose((3, 3), 3 => 3))
+    x = rand(Float32, 5, 5, 3, 2)
+    test_gradients(m, x, test_gpu=true, compare_finite_diff=false, test_grad_f=false)
 
     md = m |> gpu |> cpu
     @test md[1].weight ≈ m[1].weight atol=1f-3
 end
 
 @testset "Cross-correlation" begin
-    m = CrossCor((2, 2), 3 => 4) |> f32
-    x = rand(Float32, 10, 10, 3, 2)
-    gpu_autodiff_test(m, x; atol=1f-3)
+    m = CrossCor((2, 2), 3 => 4)
+    x = rand(Float32, 5, 5, 3, 2)
+    test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
 end
 
 @testset "Restructure" begin
@@ -132,7 +133,7 @@ end
     bn = BatchNorm(3, σ)
     for nd in 1:3
         x = rand(Float32, fill(2, nd - 1)..., 3, 4)
-        gpu_autodiff_test(bn, x; atol=1f-3, allow_nothing=true)
+        test_gradients(bn, x; test_gpu=true, compare_finite_diff=false)
     end
 end
diff --git a/test/ext_amdgpu/get_devices.jl b/test/ext_amdgpu/get_devices.jl
index 7f4d8ccd7a..24b1d71a38 100644
--- a/test/ext_amdgpu/get_devices.jl
+++ b/test/ext_amdgpu/get_devices.jl
@@ -17,9 +17,9 @@ x = randn(Float32, 5, 5)
 cx = x |> amdgpu_device
 @test cx isa AMDGPU.ROCArray
 
-# moving models to specific NVIDIA devices
+# moving models to specific AMDGPU devices
 for id in 0:(length(AMDGPU.devices()) - 1)
-    current_amdgpu_device = Flux.get_device("AMDGPU", id)
+    current_amdgpu_device = gpu_device(id+1)
 
     global dense_model = dense_model |> current_amdgpu_device
     @test dense_model.weight isa AMDGPU.ROCArray
diff --git a/test/ext_amdgpu/runtests.jl b/test/ext_amdgpu/runtests.jl
index 9027a31f76..ec779dedea 100644
--- a/test/ext_amdgpu/runtests.jl
+++ b/test/ext_amdgpu/runtests.jl
@@ -2,9 +2,6 @@
 @assert AMDGPU.functional()
 AMDGPU.allowscalar(false)
 
-include("../test_utils.jl")
-include("test_utils.jl")
-
 @testset "get_devices" begin
     include("get_devices.jl")
 end
diff --git a/test/ext_amdgpu/test_utils.jl b/test/ext_amdgpu/test_utils.jl
deleted file mode 100644
index 3c84f01048..0000000000
--- a/test/ext_amdgpu/test_utils.jl
+++ /dev/null
@@ -1,15 +0,0 @@
-function check_grad(
-    g_gpu::ROCArray{Float32}, g_cpu::Array{Float32};
-    atol, rtol, allow_nothing::Bool,
-)
-    @test g_cpu ≈ collect(g_gpu) atol=atol rtol=rtol
-end
-
-function check_grad(
-    g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill;
-    atol, rtol, allow_nothing::Bool,
-)
-    @test g_cpu ≈ collect(g_gpu) atol=atol rtol=rtol
-end
-
-check_type(x::ROCArray{Float32}) = true
diff --git a/test/ext_cuda/get_devices.jl b/test/ext_cuda/get_devices.jl
index 2f4ea3bd98..ae722319a5 100644
--- a/test/ext_cuda/get_devices.jl
+++ b/test/ext_cuda/get_devices.jl
@@ -8,9 +8,6 @@ dense_model = Dense(2 => 3)
 # initially lives on CPU
 weight = copy(dense_model.weight) # store the weight
 bias = copy(dense_model.bias) # store the bias
-cuda_device = Flux.get_device()
-
-@test typeof(cuda_device) <: Flux.CUDADevice
 
 # correctness of data transfer
 x = randn(5, 5)
@@ -30,6 +27,12 @@ for id in 0:(length(CUDA.devices()) - 1)
   @test isequal(Flux.cpu(dense_model.weight), weight)
   @test isequal(Flux.cpu(dense_model.bias), bias)
 end
+
+# gpu_device remembers the last device selected
+# Therefore, we need to reset it to the current cuda device
+@test gpu_device().device.handle == length(CUDA.devices()) - 1
+gpu_device(CUDA.device().handle + 1)
+
 # finally move to CPU, and see if things work
 cdev = cpu_device()
 dense_model = cdev(dense_model)
diff --git a/test/ext_cuda/layers.jl b/test/ext_cuda/layers.jl
index 63bcc8b526..cba95cee75 100644
--- a/test/ext_cuda/layers.jl
+++ b/test/ext_cuda/layers.jl
@@ -10,73 +10,23 @@
   @test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
 end
 
-# TODO: These layers get into scalar indexing issues.
-const BROKEN_LAYERS = Union{}
-
-const ACTIVATIONS = [identity, relu, tanh,
-                     sigmoid, exp, softplus,
-                     elu, selu]
+const ACTIVATIONS = [identity, tanh]
 
-function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; test_cpu = true, test_mode = false)
-  isnothing(x_cpu) && error("Missing input to test the layers against.")
+function gpu_gradtest(name::String, layers::Vector, x_cpu, args...;
+                      test_mode=false, test_grad_x=true,
+                      atol=1e-4, rtol=1e-4)
   @testset "$name GPU grad tests" begin
     for layer in layers
       @testset "$layer Layer GPU grad test" begin
 
         # compute output and grad of parameters
         l_cpu = layer(args...)
-        l_gpu = l_cpu |> gpu
         if test_mode
           testmode!(l_cpu)
-          testmode!(l_gpu)
         end
 
-        ps_cpu = Flux.params(l_cpu)
-        y_cpu, back_cpu = pullback(() -> sum(l_cpu(x_cpu)), ps_cpu)
-        gs_cpu = back_cpu(1f0)
-
-        x_gpu = gpu(x_cpu)
-        ps_gpu = Flux.params(l_gpu)
-
-        if typeof(l_gpu) <: BROKEN_LAYERS
-          @test_broken gradient(() -> sum(l_gpu(x_gpu)), ps_gpu) isa Flux.Zygote.Grads
-        else
-          y_gpu, back_gpu = pullback(() -> sum(l_gpu(x_gpu)), ps_gpu)
-          gs_gpu = back_gpu(1f0) # TODO many layers error out when backprop int 1, should fix
-
-          # compute grad of input
-          xg_cpu = gradient(x -> sum(l_cpu(x)), x_cpu)[1]
-          xg_gpu = gradient(x -> sum(l_gpu(x)), x_gpu)[1]
-
-          # test
-          if test_cpu
-            if layer === GroupedConvTranspose
-              @test y_gpu ≈ y_cpu rtol=1f-2 atol=1f-3
-            else
-              @test y_gpu ≈ y_cpu rtol=1f-3 atol=1f-3
-            end
-            if isnothing(xg_cpu)
-              @test isnothing(xg_gpu)
-            else
-              if layer === GroupedConvTranspose
-                @test Array(xg_gpu) ≈ xg_cpu rtol = 2f-2 atol = 1f-3
-              else
-                @test Array(xg_gpu) ≈ xg_cpu rtol = 1f-3 atol = 1f-3
-              end
-            end
-          end
-          @test gs_gpu isa Flux.Zygote.Grads
-          for (p_cpu, p_gpu) in zip(ps_cpu, ps_gpu)
-            if isnothing(gs_cpu[p_cpu])
-              @test isnothing(gs_gpu[p_gpu])
-            else
-              @test gs_gpu[p_gpu] isa CuArray
-              if test_cpu
-                @test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
-              end
-            end
-          end
-        end
+        test_gradients(l_cpu, x_cpu; test_gpu=true, compare_finite_diff=false, test_grad_x, atol, rtol)
       end
     end
   end
@@ -97,23 +47,24 @@ for act in ACTIVATIONS
                  ConvTranspose, ConvTransposeNoBias,
                  CrossCor, CrossCorNoBias,
                  DepthwiseConv, DepthwiseConvNoBias]
-  gpu_gradtest("Convolution with $act", conv_layers, r, (2,2), 1=>3, act, test_cpu = false)
+  gpu_gradtest("Convolution with $act", conv_layers, r, (2,2), 1=>3, act)
 
   groupedconv = [GroupedConv, GroupedConvTranspose]
-  gpu_gradtest("GroupedConvolution with $act", groupedconv, rand(Float32, 28, 28, 100, 2), (3,3), 100 => 25, act, test_cpu = true)
+  gpu_gradtest("GroupedConvolution with $act", groupedconv, rand(Float32, 28, 28, 100, 2), (3,3), 100 => 25, act)
 
   batch_norm = [BatchNorm, BatchNormNoTrackStats]
-  gpu_gradtest("BatchNorm 1 with $act", batch_norm, rand(Float32, 28,28,3,4), 3, act, test_cpu = false) #TODO fix errors
-  gpu_gradtest("BatchNorm 2 with $act", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = true)
+  gpu_gradtest("BatchNorm 1 with $act", batch_norm, rand(Float32, 28,28,3,4), 3, act, atol=1e-3)
+  gpu_gradtest("BatchNorm 2 with $act", batch_norm, rand(Float32, 5,4), 5, act, atol=1e-3)
 
   batch_norm = [BatchNormNoTrackStats]
-  gpu_gradtest("BatchNorm 3 with $act (test mode)", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = true, test_mode = true)
+  gpu_gradtest("BatchNorm 3 with $act (test mode)", batch_norm, rand(Float32, 5,4), 5, act,
+               test_mode=true, atol=1e-3)
 
   instancenorm = [InstanceNorm]
-  gpu_gradtest("InstanceNorm with $act", instancenorm, r, 1, act, test_cpu = false)
+  gpu_gradtest("InstanceNorm with $act", instancenorm, r, 1, act)
 
   groupnorm = [GroupNorm]
-  gpu_gradtest("GroupNorm with $act", groupnorm, rand(Float32, 28,28,3,1), 3, 1, act, test_cpu = false)
+  gpu_gradtest("GroupNorm with $act", groupnorm, rand(Float32, 28,28,3,1), 3, 1, act)
 end
 
 r = rand(Float32, 28, 28, 1, 1)
@@ -122,13 +73,13 @@ pooling_layers = [MaxPool, MeanPool]
 gpu_gradtest("Pooling", pooling_layers, r, (2,2))
 
 adaptive_pooling_layers = [AdaptiveMaxPool, AdaptiveMeanPool]
-gpu_gradtest("AdaptivePooling", adaptive_pooling_layers, r, (7,7), test_cpu = false)
+gpu_gradtest("AdaptivePooling", adaptive_pooling_layers, r, (7,7))
 
 dropout_layers = [Dropout, AlphaDropout]
-gpu_gradtest("Dropout", dropout_layers, r, 0.5f0; test_cpu = false) # dropout is not deterministic
+gpu_gradtest("Dropout", dropout_layers, r, 1e-6) # dropout is not deterministic
 
 layer_norm = [LayerNorm]
-gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 28,28,3,4), 28, test_cpu = false) #TODO fix errors
+gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 28,28,3,4), 28)
 gpu_gradtest("LayerNorm 2", layer_norm, rand(Float32, 5,4), 5)
 
 upsample = [x -> Upsample(scale=x)]
@@ -140,32 +91,27 @@ gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
 gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
 
 embedding = [Flux.Embedding]
-gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
-gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
-gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
-gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
-gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
-gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
-gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
+gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding integer index", embedding, 1, 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2, test_grad_x=false)
 
 @testset "function layers" begin
-  x = rand(Float32, 3,3)
-  gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
-  gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=2)), x)
-  gpu_autodiff_test(x -> sum(Flux.normalise(x)), x)
+  x = rand(Float32, 3, 3)
+  test_gradients(x -> sum(Flux.normalise(x; dims=1)), x, test_gpu=true, compare_finite_diff=false)
+  test_gradients(x -> sum(Flux.normalise(x; dims=2)), x, test_gpu=true, compare_finite_diff=false)
+  test_gradients(x -> sum(Flux.normalise(x)), x, test_gpu=true, compare_finite_diff=false)
 end
 
 @testset "Zeros mapped for $cl" for cl in (Conv, ConvTranspose, CrossCor, DepthwiseConv)
   l = cl((2,2), 1=>3, bias = false) |> gpu
   ip = zeros(Float32, 28,28,1,1) |> gpu
-  if typeof(l) <: BROKEN_LAYERS
-    @test_broken sum(l(ip)) ≈ 0.f0
-    @test_broken gradient(() -> sum(l(ip)), Flux.params(l)) isa Flux.Zygote.Grads
-  else
-    @test sum(l(ip)) ≈ 0.f0
-    gs = gradient(() -> sum(l(ip)), Flux.params(l))
-    @test l.bias ∉ gs.params
-  end
+  @test sum(l(ip)) ≈ 0.f0
+  gs = gradient(() -> sum(l(ip)), Flux.params(l))
+  @test l.bias ∉ gs.params
 end
 
 @testset "Dense without bias" begin
@@ -366,14 +312,6 @@ end
   @test Array(y_gpu) ≈ y_cpu atol=1e-4
   @test Array(α_gpu) ≈ α_cpu atol=1e-4
 
-  gm_cpu, gx_cpu = gradient(mha_cpu, x_cpu) do mha, x
-    y, α = mha(x)
-    return sum(y.^2) + sum(α.^2)
-  end
-  gm_gpu, gx_gpu = gradient(mha_gpu, x_gpu) do mha, x
-    y, α = mha(x)
-    return sum(y.^2) + sum(α.^2)
-  end
-  check_grad(gm_gpu, gm_cpu)
-  check_grad(gx_gpu, gx_cpu)
+  test_gradients(mha_cpu, x_cpu, loss = o -> sum(o[1].^2) + sum(o[2].^2),
+                 test_gpu=true, compare_finite_diff=false)
 end
diff --git a/test/ext_cuda/losses.jl b/test/ext_cuda/losses.jl
index b339b352bb..cf56c7119d 100644
--- a/test/ext_cuda/losses.jl
+++ b/test/ext_cuda/losses.jl
@@ -27,11 +27,12 @@ y = [1 0 0 0 1
 @test focal_loss(x, y) ≈ focal_loss(gpu(x), gpu(y))
 
 @testset "GPU: $loss" for loss in ALL_LOSSES
-  x = rand(Float32, 3,4)
-  y = rand(Float32, 3,4)
+  # let's stay far from the boundaries to avoid problems with finite differences gradients
+  x = 0.1f0 .+ 0.8f0 .* rand(Float32, 3, 4)
+  y = 0.1f0 .+ 0.8f0 .* rand(Float32, 3, 4)
   @test loss(x, y) ≈ loss(gpu(x), gpu(y))
 
-  gpu_autodiff_test(loss, x, y)
+  test_gradients(loss, x, y, test_gpu=true, test_grad_f=false, compare_finite_diff=false)
 
   # Float16 tests
   @test loss(f16(x), f16(y)) ≈ loss(gpu(f16(x)), gpu(f16(y)))
diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl
index aa1f431fe7..012a62d41a 100644
--- a/test/ext_cuda/runtests.jl
+++ b/test/ext_cuda/runtests.jl
@@ -7,13 +7,9 @@ using Random, LinearAlgebra, Statistics
 @assert CUDA.functional()
 CUDA.allowscalar(false)
 
-# include("../test_utils.jl")
-include("test_utils.jl")
-
 @testset "get_devices" begin
   include("get_devices.jl")
 end
-
 @testset "cuda" begin
   include("cuda.jl")
 end
diff --git a/test/ext_cuda/test_utils.jl b/test/ext_cuda/test_utils.jl
deleted file mode 100644
index 10a8d0dfdf..0000000000
--- a/test/ext_cuda/test_utils.jl
+++ /dev/null
@@ -1,4 +0,0 @@
-check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
-  @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol
-
-check_type(x::CuArray{Float32}) = true
diff --git a/test/ext_metal/basic.jl b/test/ext_metal/basic.jl
index 9e4a9ef9cb..97ba8066a3 100644
--- a/test/ext_metal/basic.jl
+++ b/test/ext_metal/basic.jl
@@ -23,5 +23,5 @@ end
     m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
     x = rand(Float32, 10, 10)
     @test (m|>gpu)(x|>gpu) isa MtlArray{Float32, 2}
-    gpu_autodiff_test(m, x)
+    test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
 end
diff --git a/test/ext_metal/runtests.jl b/test/ext_metal/runtests.jl
index 8c8af7d896..cb9532390e 100644
--- a/test/ext_metal/runtests.jl
+++ b/test/ext_metal/runtests.jl
@@ -5,8 +5,6 @@ using Random, Statistics
 using Zygote
 Flux.gpu_backend!("Metal") # needs a restart
 
-include("test_utils.jl")
-
 @testset "data movement" begin
     metal_device = Flux.gpu_device()
     cdev = cpu_device()
diff --git a/test/ext_metal/test_utils.jl b/test/ext_metal/test_utils.jl
deleted file mode 100644
index f6ed32a8f4..0000000000
--- a/test/ext_metal/test_utils.jl
+++ /dev/null
@@ -1,16 +0,0 @@
-
-function check_grad(
-    g_gpu::MtlArray{Float32}, g_cpu::Array{Float32};
-    atol, rtol, allow_nothing::Bool,
-)
-    @test g_cpu ≈ collect(g_gpu) atol=atol rtol=rtol
-end
-
-function check_grad(
-    g_gpu::MtlArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill;
-    atol, rtol, allow_nothing::Bool,
-)
-    @test g_cpu ≈ collect(g_gpu) atol=atol rtol=rtol
-end
-
-check_type(x::MtlArray{Float32}) = true
diff --git a/test/functors.jl b/test/functors.jl
index 280b76d6f0..734eadc574 100644
--- a/test/functors.jl
+++ b/test/functors.jl
@@ -3,10 +3,7 @@ if !(Flux.CUDA_LOADED[] || Flux.AMDGPU_LOADED[] || Flux.METAL_LOADED[])
   @test x === gpu(x)
 end
 
-dev = Flux.get_device()
+dev = Flux.cpu_device()
 @test typeof(dev) <: Flux.CPUDevice
 @test dev(x) == x
 
-# specifically getting CPU device
-dev = Flux.get_device("CPU")
-@test typeof(dev) <: Flux.CPUDevice
diff --git a/test/layers/attention.jl b/test/layers/attention.jl
index a4c90b36ed..2c6fd7d514 100644
--- a/test/layers/attention.jl
+++ b/test/layers/attention.jl
@@ -54,12 +54,7 @@ end
 
     @testset "gradient" begin
-        gm, gq = gradient(mha, q) do mha, q
-            y, α = mha(q)
-            return sum(y.^2) + sum(α.^2)
-        end
-        check_grad_type(gm, mha)
-        check_grad_type(gq, q)
+        test_gradients(mha, q, loss = o -> sum(o[1].^2) + sum(o[2].^2))
     end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index ef3d67f4d7..f44c4b7758 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -6,17 +6,19 @@ using Random, Statistics, LinearAlgebra
 using IterTools: ncycle
 using Zygote
 using Pkg
+using FiniteDifferences: FiniteDifferences
+using Functors: fmapstructure_with_path
 
 ## Uncomment below to change the default test settings
 # ENV["FLUX_TEST_AMDGPU"] = "true"
-# ENV["FLUX_TEST_CUDA"] = "true"
+ENV["FLUX_TEST_CUDA"] = "true"
 # ENV["FLUX_TEST_METAL"] = "true"
-# ENV["FLUX_TEST_CPU"] = "false"
+ENV["FLUX_TEST_CPU"] = "false"
 # ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
 # ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
 ENV["FLUX_TEST_ENZYME"] = "false" # We temporarily disable Enzyme tests since they are failing
 
-include("test_utils.jl")
+include("test_utils.jl") # for test_gradients
 
 Random.seed!(0)
diff --git a/test/test_utils.jl b/test/test_utils.jl
index 004d3035ad..f9a6b6655f 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -8,113 +8,99 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
                     Flux.Losses.dice_coeff_loss,
                     Flux.Losses.poisson_loss,
                     Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss,
-                    Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss]
+                    Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss,
+                    Flux.Losses.siamese_contrastive_loss]
 
-
-function check_grad(g_gpu, g_cpu;
-                    rtol=1e-4, atol=1e-4,
-                    allow_nothing::Bool=false)
-  allow_nothing && return
-  @warn "Unsupported types in `check_grad`: $(typeof(g_gpu)), $(typeof(g_cpu))"
-  @show g_gpu g_cpu
-  @test false
+function finitediff_withgradient(f, x...)
+    y = f(x...)
+    # We set a range to avoid domain errors
+    fdm = FiniteDifferences.central_fdm(5, 1, max_range=1e-2)
+    return y, FiniteDifferences.grad(fdm, f, x...)
 end
 
-check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
-  check_grad(g_gpu[], g_cpu[]; rtol, atol, allow_nothing)
-
-check_grad(g_gpu::Nothing, g_cpu::Nothing; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
-  @test true
-
-check_grad(g_gpu::Float32, g_cpu::Float32; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) =
-  @test g_cpu ≈ g_gpu rtol=rtol atol=atol
-
-function check_grad(g_gpu::Tuple, g_cpu::Tuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false)
-  for (v1, v2) in zip(g_gpu, g_cpu)
-    check_grad(v1, v2; rtol, atol, allow_nothing)
-  end
-end
-function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false)
-  for ((k1,v1), (k2,v2)) in zip(pairs(g_gpu), pairs(g_cpu))
-    @test k1 == k2
-    check_grad(v1, v2; rtol, atol, allow_nothing)
+function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
+    fmapstructure_with_path(a, b) do kp, x, y
+        if x isa AbstractArray
+            @test x ≈ y rtol=rtol atol=atol
+        elseif x isa Number
+            @test x ≈ y rtol=rtol atol=atol
+        end
     end
 end
 
-check_type(x) = false
-check_type(x::Float32) = true
-check_type(x::Array{Float32}) = true
 
-function gpu_autodiff_test(
-    f_cpu,
-    xs_cpu::Array{Float32}...;
-    test_equal=true,
+function test_gradients(
+    f,
+    xs...;
     rtol=1e-4, atol=1e-4,
-    checkgrad::Bool = true,
-    allow_nothing::Bool = false,
-  )
-
-  # Compare CPU & GPU function outputs.
-  f_gpu = f_cpu |> gpu
-  xs_gpu = gpu.(xs_cpu)
-
-  y_cpu = f_cpu(xs_cpu...)
-  y_gpu = f_gpu(xs_gpu...)
-  @test collect(y_cpu) ≈ collect(y_gpu) atol=atol rtol=rtol
-
-  checkgrad || return
-
-  ### GRADIENT WITH RESPECT TO INPUT ###
-
-  y_cpu, back_cpu = pullback((x...) -> f_cpu(x...), xs_cpu...)
-  @test check_type(y_cpu)
-  Δ_cpu = size(y_cpu) == () ? randn(Float32) : randn(Float32, size(y_cpu))
-  gs_cpu = back_cpu(Δ_cpu)
-
-  Δ_gpu = Δ_cpu |> gpu
-  y_gpu, back_gpu = pullback((x...) -> f_gpu(x...), xs_gpu...)
-  @test check_type(y_gpu)
-  gs_gpu = back_gpu(Δ_gpu)
-
-  if test_equal
-    @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol
-    for (g_gpu, g_cpu) in zip(gs_gpu, gs_cpu)
-      check_grad(g_gpu, g_cpu; atol, rtol, allow_nothing)
-    end
+    test_gpu = false,
+    test_grad_f = true,
+    test_grad_x = true,
+    compare_finite_diff = true,
+    loss = mean,
+    )
+
+    if !test_gpu && !compare_finite_diff
+        error("You should either compare finite diff vs CPU AD \
+               or CPU AD vs GPU AD.")
     end
 
-  ### GRADIENT WITH RESPECT TO f ###
-
-  ps_cpu = Flux.params(f_cpu)
-  y_cpu, back_cpu = pullback(() -> f_cpu(xs_cpu...), ps_cpu)
-  gs_cpu = back_cpu(Δ_cpu)
-
-  ps_gpu = Flux.params(f_gpu)
-  y_gpu, back_gpu = pullback(() -> f_gpu(xs_gpu...), ps_gpu)
-  gs_gpu = back_gpu(Δ_gpu)
+    if test_grad_x
+        # Zygote gradient with respect to input.
+        y, g = Zygote.withgradient((xs...) -> loss(f(xs...)), xs...)
+
+        if compare_finite_diff
+            # Cast to Float64 to avoid precision issues.
+            f64 = f |> Flux.f64
+            xs64 = xs .|> Flux.f64
+            y_fd, g_fd = finitediff_withgradient((xs...) -> loss(f64(xs...)), xs64...)
+            @test y ≈ y_fd rtol=rtol atol=atol
+            check_equal_leaves(g, g_fd; rtol, atol)
+        end
 
-  if test_equal
-    @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol
-    @assert length(ps_gpu) == length(ps_cpu)
-    for (p_gpu, p_cpu) in zip(ps_gpu, ps_cpu)
-      check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu]; atol, rtol, allow_nothing)
+        if test_gpu
+            gpu_dev = gpu_device(force=true)
+            cpu_dev = cpu_device()
+            xs_gpu = xs |> gpu_dev
+            f_gpu = f |> gpu_dev
+
+            # Zygote gradient with respect to input on GPU.
+            y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu(xs...)), xs_gpu...)
+            @test get_device(g_gpu) == get_device(xs_gpu)
+            @test y_gpu ≈ y rtol=rtol atol=atol
+            check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
         end
     end
-end
-
-# check_grad_type checks that the gradient type matches the primal type.
-
-check_grad_type(g::Nothing, x) = nothing
-function check_grad_type(g::AbstractArray{T1}, x::AbstractArray{T2}) where {T1, T2}
-  @test T1 == T2
-  @test size(g) == size(x)
-end
 
+    if test_grad_f
+        # Zygote gradient with respect to f.
+        y, g = Zygote.withgradient(f -> loss(f(xs...)), f)
+
+        if compare_finite_diff
+            # Use finite differences gradient as a reference.
+            # y_fd, g_fd = finitediff_withgradient(f -> loss(f(x)), f)
+            # Cast to Float64 to avoid precision issues.
+            f64 = f |> Flux.f64
+            ps, re = Flux.destructure(f64)
+            y_fd, g_fd = finitediff_withgradient(ps -> loss(re(ps)(xs...)), ps)
+            g_fd = (re(g_fd[1]),)
+            @test y ≈ y_fd rtol=rtol atol=atol
+            check_equal_leaves(g, g_fd; rtol, atol)
+        end
 
-function check_grad_type(g::NamedTuple, x::T) where T
-  for f in fieldnames(T)
-    check_grad_type(g[f], getfield(x, f))
+        if test_gpu
+            gpu_dev = gpu_device(force=true)
+            cpu_dev = cpu_device()
+            xs_gpu = xs |> gpu_dev
+            f_gpu = f |> gpu_dev
+
+            # Zygote gradient with respect to f on GPU.
+            y_gpu, g_gpu = Zygote.withgradient(f -> loss(f(xs_gpu...)), f_gpu)
+            # @test get_device(g_gpu) == get_device(xs_gpu)
+            @test y_gpu ≈ y rtol=rtol atol=atol
+            check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
+        end
     end
 end
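For reference, a minimal sketch of how the new `test_gradients` helper from `test/test_utils.jl` is meant to be called from a test file. The layer, input sizes, and file path below are illustrative assumptions, not part of the diff:

using Flux, Zygote, Test
using Statistics: mean                      # `loss` defaults to `mean`
using FiniteDifferences: FiniteDifferences  # needed by the finite-difference path
using Functors: fmapstructure_with_path     # needed by check_equal_leaves
include("test_utils.jl")                    # defines finitediff_withgradient, check_equal_leaves, test_gradients

m = Dense(3 => 4, tanh)      # hypothetical layer under test
x = rand(Float32, 3, 5)      # hypothetical input batch

# Default path: Zygote gradients on the CPU checked against finite differences.
test_gradients(m, x)

# GPU path: CPU Zygote gradients checked against GPU Zygote gradients.
# gpu_device(force=true) is called internally, so this errors without a working GPU backend.
test_gradients(m, x; test_gpu=true, compare_finite_diff=false)

Per the error branch added at the top of `test_gradients`, at least one of the two comparisons (`compare_finite_diff` or `test_gpu`) must remain enabled.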