diff --git a/test/ext_amdgpu/basic.jl b/test/ext_amdgpu/basic.jl index 163064c072..8962c7bedb 100644 --- a/test/ext_amdgpu/basic.jl +++ b/test/ext_amdgpu/basic.jl @@ -1,5 +1,3 @@ -@test Flux.AMDGPU_LOADED[] - @testset "Basic GPU movement" begin @test Flux.gpu(rand(Float64, 16)) isa ROCArray{Float32, 1} @test Flux.gpu(rand(Float64, 16, 16)) isa ROCArray{Float32, 2} diff --git a/test/ext_metal/basic.jl b/test/ext_metal/basic.jl index 97ba8066a3..9febd8e455 100644 --- a/test/ext_metal/basic.jl +++ b/test/ext_metal/basic.jl @@ -1,5 +1,3 @@ -@test Flux.METAL_LOADED[] - @testset "Basic GPU movement" begin @test Flux.gpu(rand(Float64, 16)) isa MtlArray{Float32, 1} @test Flux.gpu(rand(Float64, 16, 16)) isa MtlArray{Float32, 2} diff --git a/test/functors.jl b/test/functors.jl index 734eadc574..111da50ea8 100644 --- a/test/functors.jl +++ b/test/functors.jl @@ -1,5 +1,5 @@ x = rand(Float32, 10, 10) -if !(Flux.CUDA_LOADED[] || Flux.AMDGPU_LOADED[] || Flux.METAL_LOADED[]) +if gpu_device() isa CPUDevice @test x === gpu(x) end