Commit 764fa14: fix gpu tests

CarloLucibello committed Nov 17, 2024
1 parent: c08d866
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
Project.toml (4 changes: 3 additions & 1 deletion)
@@ -1,11 +1,12 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.15-DEV"
+version = "0.15.0-DEV"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
+ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
@@ -46,6 +47,7 @@ Adapt = "4"
 CUDA = "5"
 ChainRulesCore = "1.12"
 Compat = "4.10.0"
+ConstructionBase = "1.5.8"
 Enzyme = "0.13"
 Functors = "0.5"
 MLDataDevices = "1.4.2"
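For reference, the same dependency and compat addition could be made through Pkg's API instead of editing Project.toml by hand; a minimal sketch (not part of this commit), run from the Flux project environment:

using Pkg
Pkg.add("ConstructionBase")              # records the UUID under [deps]
Pkg.compat("ConstructionBase", "1.5.8")  # records the bound under [compat]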
src/functor.jl (8 changes: 8 additions & 0 deletions)
@@ -108,6 +108,10 @@ julia> m.bias
 """
 cpu(x) = cpu_device()(x)

+# TODO remove after https://github.com/LuxDL/Lux.jl/pull/1089
+ChainRulesCore.@non_differentiable cpu_device()
+
+
 # Remove when
 # https://github.com/JuliaPackaging/Preferences.jl/issues/39
 # is resolved
@@ -149,6 +153,10 @@ CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
 """
 gpu(x) = gpu_device()(x)

+# TODO remove after https://github.com/LuxDL/Lux.jl/pull/1089
+ChainRulesCore.@non_differentiable gpu_device()
+ChainRulesCore.@non_differentiable gpu_device(::Any)
+
 # Precision

 struct FluxEltypeAdaptor{T} end
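The additions above mark the MLDataDevices getters cpu_device/gpu_device as non-differentiable, so Zygote treats the returned device object as a constant instead of trying to trace into it when cpu(x) or gpu(x) appears inside a loss. A minimal sketch of the mechanism, using a hypothetical stand-in getter rather than the real ones:

# Sketch only: `fake_device` stands in for cpu_device()/gpu_device().
# @non_differentiable gives it a trivial rrule, so AD ignores how it was obtained.
using ChainRulesCore, Zygote

fake_device() = identity                     # hypothetical device getter
ChainRulesCore.@non_differentiable fake_device()

move(x) = fake_device()(x)                   # mirrors cpu(x) = cpu_device()(x)
Zygote.gradient(x -> sum(move(x)), rand(Float32, 3))   # (Float32[1.0, 1.0, 1.0],)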
src/layers/upsample.jl (3 changes: 3 additions & 0 deletions)
@@ -35,6 +35,9 @@ struct Upsample{mode, S, T}
   size::T
 end

+Functors.@leaf Upsample # mark leaf since the constructor is not compatible with Functors
+# by default but we don't need to recurse into it
+
 function Upsample(mode::Symbol = :nearest; scale = nothing, size = nothing)
   mode in [:nearest, :bilinear, :trilinear] ||
     throw(ArgumentError("mode=:$mode is not supported."))
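Functors.@leaf makes fmap treat Upsample as an opaque leaf: the struct is handed to the mapped function whole instead of being taken apart and reconstructed, which (per the comment above) its constructor would not support. A hedged sketch with a made-up struct, not Flux's actual layer:

# Sketch only: `Frozen` is hypothetical. With @leaf, fmap passes the whole struct
# to `f` instead of recursing into it and rebuilding it field by field.
using Functors

struct Frozen
  note::String
end
Functors.@leaf Frozen

model = (layer = Frozen("keep me"), w = [1.0f0, 2.0f0])
fmap(x -> x isa AbstractArray ? 2 .* x : x, model)
# (layer = Frozen("keep me"), w = Float32[2.0, 4.0])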
test/ext_cuda/layers.jl (5 changes: 2 additions & 3 deletions)
@@ -6,11 +6,10 @@

 # generic movement tests
 @testset "Basic GPU Movement" begin
-  @test gradient(x -> sum(gpu(x)), rand(3,3)) isa Tuple
-  @test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
+  @test gradient(x -> sum(gpu(x)), rand(Float32, 3, 3))[1] isa Matrix{Float32}
+  @test gradient(x -> sum(cpu(x)), gpu(rand(Float32, 3, 3)))[1] isa CuMatrix{Float32}
 end

-
 const ACTIVATIONS = [identity, tanh]

 function gpu_gradtest(name::String, layers::Vector, x_cpu, args...;
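The new assertions check the concrete gradient types instead of merely that gradient returns a tuple. A self-contained version of the same movement check (assuming a working CUDA device; not a verbatim copy of the test file):

# Gradients through device movement: a CPU input should get a CPU gradient even if
# the loss moves data to the GPU, and a GPU input should get a GPU gradient.
using Flux, CUDA, Test

x = rand(Float32, 3, 3)
@test gradient(x -> sum(gpu(x)), x)[1] isa Matrix{Float32}
@test gradient(x -> sum(cpu(x)), gpu(x))[1] isa CuMatrix{Float32}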
