diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index c59a63c94f..dbce3d11e3 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,16 +1,16 @@ steps: - - label: "GPU integration with julia v1.6" - plugins: - - JuliaCI/julia#v1: - # Drop default "registries" directory, so it is not persisted from execution to execution - # Taken from https://github.com/JuliaLang/julia/blob/v1.7.2/.buildkite/pipelines/main/platforms/package_linux.yml#L11-L12 - persist_depot_dirs: packages,artifacts,compiled - version: "1.6" - - JuliaCI/julia-test#v1: ~ - agents: - queue: "juliagpu" - cuda: "*" - timeout_in_minutes: 60 + # - label: "GPU integration with julia v1.9" + # plugins: + # - JuliaCI/julia#v1: + # # Drop default "registries" directory, so it is not persisted from execution to execution + # # Taken from https://github.com/JuliaLang/julia/blob/v1.7.2/.buildkite/pipelines/main/platforms/package_linux.yml#L11-L12 + # persist_depot_dirs: packages,artifacts,compiled + # version: "1.9" + # - JuliaCI/julia-test#v1: ~ + # agents: + # queue: "juliagpu" + # cuda: "*" + # timeout_in_minutes: 60 - label: "GPU integration with julia v1" plugins: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fe97c0e1c0..e377282549 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: version: - - '1.6' # Replace this with the minimum Julia version that your package supports. + # - '1.9' # Uncomment when 1.10 is out. Replace this with the minimum Julia version that your package supports. - '1' os: [ubuntu-latest] arch: [x64] @@ -47,17 +47,17 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - name: "Run test without coverage report" uses: julia-actions/julia-runtest@v1 - if: ${{ !contains(fromJson('["1", "1.6"]'), matrix.version) || matrix.os != 'ubuntu-latest' }} + if: ${{ !contains(fromJson('["1", "1.9"]'), matrix.version) || matrix.os != 'ubuntu-latest' }} with: coverage: false - name: "Run test with coverage report" uses: julia-actions/julia-runtest@v1 - if: contains(fromJson('["1", "1.6"]'), matrix.version) && matrix.os == 'ubuntu-latest' + if: contains(fromJson('["1", "1.9"]'), matrix.version) && matrix.os == 'ubuntu-latest' - uses: julia-actions/julia-processcoverage@v1 - if: contains(fromJson('["1", "1.6"]'), matrix.version) && matrix.os == 'ubuntu-latest' + if: contains(fromJson('["1", "1.9"]'), matrix.version) && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v3 - if: contains(fromJson('["1", "1.6"]'), matrix.version) && matrix.os == 'ubuntu-latest' + if: contains(fromJson('["1", "1.9"]'), matrix.version) && matrix.os == 'ubuntu-latest' with: file: lcov.info diff --git a/NEWS.md b/NEWS.md index b31eebac39..af3e0de984 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,12 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. +## v0.14.0 +* Flux now requires julia v1.9 or later. +* CUDA.jl is no longer a hard dependency. CUDA support is now provided through the extension mechanism. To unlock the CUDA +functionality, users are required to load CUDA, e.g. with `using CUDA`. +The package `cuDNN.jl` also needs to be installed in the environment. + ## v0.13.17 * Apple's Metal GPU acceleration preliminary support via the extension mechanism.
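The NEWS.md entry above is the user-facing summary of this change. A minimal sketch of the new workflow, assuming CUDA.jl and cuDNN.jl have been added to the active environment:

```julia
# Sketch of the extension-based CUDA workflow described in NEWS.md.
# Loading CUDA activates FluxCUDAExt; loading CUDA together with cuDNN
# activates FluxCUDAcuDNNExt (see the [extensions] table in Project.toml below).
using Flux, CUDA, cuDNN

model = Dense(2 => 3) |> gpu   # parameters move to the GPU when CUDA is functional
x = rand(Float32, 2, 5) |> gpu
y = model(x)                   # otherwise `gpu` is a no-op and this runs on the CPU
```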
diff --git a/Project.toml b/Project.toml index 38285573a5..3768eb7308 100644 --- a/Project.toml +++ b/Project.toml @@ -4,14 +4,12 @@ version = "0.13.17" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Preferences = "21216c6a-2e73-6563-6e65-726566657250" @@ -22,41 +20,44 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" [extensions] -AMDGPUExt = "AMDGPU" +FluxAMDGPUExt = "AMDGPU" +FluxCUDAExt = "CUDA" +FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] FluxMetalExt = "Metal" [compat] AMDGPU = "0.4.13" Adapt = "3.0" -CUDA = "3, 4" +CUDA = "4" ChainRulesCore = "1.12" -Functors = "0.3, 0.4" -MLUtils = "0.2, 0.3.1, 0.4" +Functors = "0.4" +MLUtils = "0.4" MacroTools = "0.5" Metal = "0.4" -NNlib = "0.8.19" -NNlibCUDA = "0.2.6" -OneHotArrays = "0.1, 0.2" +NNlib = "0.9.1" +OneHotArrays = "0.2.4" Optimisers = "0.2.12" Preferences = "1" ProgressLogging = "0.1" -Reexport = "0.2, 1.0" -SpecialFunctions = "1.8.2, 2.1.2" +Reexport = "1.0" +SpecialFunctions = "2.1.2" Zygote = "0.6.49" cuDNN = "1" -julia = "1.6" +julia = "1.9" [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" @@ -64,8 +65,9 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] test = ["Test", "Documenter", "IterTools", "LinearAlgebra", - "FillArrays", "ComponentArrays", "BSON", - "Pkg"] + "FillArrays", "ComponentArrays", "BSON", "Pkg", + "CUDA", "cuDNN", "Metal"] diff --git a/cuda.jl b/cuda.jl new file mode 100644 index 0000000000..cb47466120 --- /dev/null +++ b/cuda.jl @@ -0,0 +1,9 @@ +using Flux, CUDA + +BN = BatchNorm(3) |> gpu; +x = randn(2, 2, 3, 4) |> gpu; + +NNlib.batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; + alpha=1, beta=0, eps=BN.ϵ, + training=Flux._isactive(BN, x)) + diff --git a/docs/src/index.md b/docs/src/index.md index 833c85e5e8..48364f74a8 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,7 +8,7 @@ Flux is a library for machine learning. It comes "batteries-included" with many ### Installation -Download [Julia 1.6](https://julialang.org/downloads/) or later, preferably the current stable release. You can add Flux using Julia's package manager, by typing `] add Flux` in the Julia prompt. 
This will automatically install several other packages, including [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) for Nvidia GPU support. +Download [Julia 1.9](https://julialang.org/downloads/) or later, preferably the current stable release. You can add Flux using Julia's package manager, by typing `] add Flux` in the Julia prompt. This will automatically install several other packages. For Nvidia GPU support you will also need to install and load [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) (with `using CUDA`), together with the cuDNN.jl package. ### Learning Flux diff --git a/ext/AMDGPUExt/AMDGPUExt.jl b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl similarity index 78% rename from ext/AMDGPUExt/AMDGPUExt.jl rename to ext/FluxAMDGPUExt/FluxAMDGPUExt.jl index a8c768f332..0fbd8a04a4 100644 --- a/ext/AMDGPUExt/AMDGPUExt.jl +++ b/ext/FluxAMDGPUExt/FluxAMDGPUExt.jl @@ -1,9 +1,9 @@ -module AMDGPUExt +module FluxAMDGPUExt import ChainRulesCore import ChainRulesCore: NoTangent import Flux -import Flux: FluxCPUAdaptor, FluxAMDAdaptor, _amd, _isleaf, adapt_storage, fmap +import Flux: FluxCPUAdaptor, FluxAMDAdaptor, _amd, adapt_storage, fmap import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias import NNlib @@ -13,10 +13,14 @@ using Random using Zygote const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat + +# Set to a Bool on the first call to check_use_amdgpu const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing) function check_use_amdgpu() - isnothing(USE_AMDGPU[]) || return + if !isnothing(USE_AMDGPU[]) + return + end USE_AMDGPU[] = AMDGPU.functional() if USE_AMDGPU[] @@ -25,12 +29,13 @@ function check_use_amdgpu() end else @info """ - The AMDGPU function is being called but the AMDGPU is not functional. + The AMDGPU functionality is being called but AMDGPU.jl is not functional. Defaulting back to the CPU. (No action is required if you want to run on the CPU). """ maxlog=1 end return end + ChainRulesCore.@non_differentiable check_use_amdgpu() include("functor.jl") diff --git a/ext/AMDGPUExt/batchnorm.jl b/ext/FluxAMDGPUExt/batchnorm.jl similarity index 100% rename from ext/AMDGPUExt/batchnorm.jl rename to ext/FluxAMDGPUExt/batchnorm.jl diff --git a/ext/AMDGPUExt/conv.jl b/ext/FluxAMDGPUExt/conv.jl similarity index 100% rename from ext/AMDGPUExt/conv.jl rename to ext/FluxAMDGPUExt/conv.jl diff --git a/ext/AMDGPUExt/functor.jl b/ext/FluxAMDGPUExt/functor.jl similarity index 98% rename from ext/AMDGPUExt/functor.jl rename to ext/FluxAMDGPUExt/functor.jl index 27327bfebb..dc3d3cbcce 100644 --- a/ext/AMDGPUExt/functor.jl +++ b/ext/FluxAMDGPUExt/functor.jl @@ -42,7 +42,7 @@ _conv_basetype(::ConvTranspose) = ConvTranspose Flux._isleaf(::AMD_CONV) = true -_exclude(x) = _isleaf(x) +_exclude(x) = Flux._isleaf(x) _exclude(::CPU_CONV) = true function _amd(x) diff --git a/ext/FluxCUDAExt/FluxCUDAExt.jl b/ext/FluxCUDAExt/FluxCUDAExt.jl new file mode 100644 index 0000000000..9f0dae1aa9 --- /dev/null +++ b/ext/FluxCUDAExt/FluxCUDAExt.jl @@ -0,0 +1,50 @@ +module FluxCUDAExt + +using Flux +import Flux: _cuda +using Flux: FluxCPUAdaptor, FluxCUDAAdaptor, fmap +using CUDA +using NNlib +using Zygote +using ChainRulesCore +using Random +using Adapt +import Adapt: adapt_storage + + +const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing) + +function check_use_cuda() + if !isnothing(USE_CUDA[]) + return + end + + USE_CUDA[] = CUDA.functional() + if !USE_CUDA[] + @info """ + The CUDA functionality is being called but CUDA.jl is not functional. + Defaulting back to the CPU. (No action is required if you want to run on the CPU).
+ """ maxlog=1 + end + return +end + +ChainRulesCore.@non_differentiable check_use_cuda() + +include("functor.jl") + +function __init__() + Flux.CUDA_LOADED[] = true + + try + Base.require(Main, :cuDNN) + catch + @warn """Package cuDNN not found in current path. + - Run `import Pkg; Pkg.add(\"cuDNN\")` to install the cuDNN package, then restart julia. + - If cuDNN is not installed, some Flux functionalities will not be available when running on the GPU. + """ + end +end + +end diff --git a/ext/FluxCUDAExt/functor.jl b/ext/FluxCUDAExt/functor.jl new file mode 100644 index 0000000000..347cfce372 --- /dev/null +++ b/ext/FluxCUDAExt/functor.jl @@ -0,0 +1,32 @@ + +adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) +adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x)) +adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng() +adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x +adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) = + error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().") + +# TODO: figure out the correct design for OneElement +adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x)) + +adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x) +adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng() + +function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray) + adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx))) +end + +ChainRulesCore.rrule(::typeof(adapt), a::FluxCPUAdaptor, x::AnyCuArray) = + adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCUDAAdaptor(), unthunk(Δ))) + +ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) = + adapt(a, x), Δ -> (NoTangent(), NoTangent(), Δ) + +ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) = + adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ))) + +function _cuda(x) + check_use_cuda() + USE_CUDA[] || return x + fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude=Flux._isleaf) +end diff --git a/ext/FluxCUDAExt/utils.jl b/ext/FluxCUDAExt/utils.jl new file mode 100644 index 0000000000..f6ba3751ad --- /dev/null +++ b/ext/FluxCUDAExt/utils.jl @@ -0,0 +1 @@ +rng_from_array(::CuArray) = CUDA.default_rng() \ No newline at end of file diff --git a/src/cuda/cudnn.jl b/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl similarity index 50% rename from src/cuda/cudnn.jl rename to ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl index 24226ab4b1..1f808709c2 100644 --- a/src/cuda/cudnn.jl +++ b/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl @@ -1,4 +1,24 @@ -import NNlibCUDA: batchnorm, ∇batchnorm +module FluxCUDAcuDNNExt + +using Flux +using CUDA, cuDNN +using NNlib + +const USE_CUDNN = Ref{Union{Nothing, Bool}}(nothing) + +function check_use_cudnn() + if !isnothing(USE_CUDNN[]) + return + end + + USE_CUDNN[] = cuDNN.has_cudnn() + if !USE_CUDNN[] + @warn """ + cuDNN.jl didn't found libcudnn, some Flux functionality will not be available. 
+ """ maxlog=1 + end + return +end function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache=nothing) where T<:Union{Float32, Float64} @@ -6,16 +26,12 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, @assert BN.affine "BatchNorm: only affine=true supported on gpu" @assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu" @assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels" - return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; + + return BN.λ.(NNlib.batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache=cache, alpha=1, beta=0, eps=BN.ϵ, training=Flux._isactive(BN, x))) end -function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...) - y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...) - function batchnorm_pullback(Δ) - grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...) - (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent()) - end - y, batchnorm_pullback + + end diff --git a/src/Flux.jl b/src/Flux.jl index 5ec715dae2..132231c9ad 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -45,10 +45,6 @@ include("train.jl") using .Train using .Train: setup -using CUDA -import cuDNN -const use_cuda = Ref{Union{Nothing,Bool}}(nothing) - using Adapt, Functors, OneHotArrays include("utils.jl") include("functor.jl") @@ -75,6 +71,5 @@ include("deprecations.jl") include("losses/Losses.jl") using .Losses -include("cuda/cuda.jl") end # module diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl deleted file mode 100644 index 6e18a066af..0000000000 --- a/src/cuda/cuda.jl +++ /dev/null @@ -1,11 +0,0 @@ -module CUDAint - -using ..CUDA - -import ..Flux: Flux -using ChainRulesCore -import NNlib, NNlibCUDA - -include("cudnn.jl") - -end diff --git a/src/functor.jl b/src/functor.jl index a371a1e130..7e4d552753 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -133,51 +133,22 @@ end # From @macroexpand Zygote.@non_differentiable params(m...) and https://github.com/FluxML/Zygote.jl/pull/1248 Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing -struct FluxCUDAAdaptor end -adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) -adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x)) -if VERSION >= v"1.7" - adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng() -else - adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng() -end -adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x -adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) = - error("Cannot map RNG of type $(typeof(x)) to GPU. 
GPU execution only supports Random.default_rng().") - -# TODO: figure out the correct design for OneElement -adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x)) - struct FluxCPUAdaptor end # define rules for handling structured arrays adapt_storage(to::FluxCPUAdaptor, x::AbstractArray) = adapt(Array, x) adapt_storage(to::FluxCPUAdaptor, x::AbstractRange) = x adapt_storage(to::FluxCPUAdaptor, x::Zygote.FillArrays.AbstractFill) = x -adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x) adapt_storage(to::FluxCPUAdaptor, x::Zygote.OneElement) = x adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x -adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng() adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x -function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray) - adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx))) -end # The following rrules for adapt are here to avoid double wrapping issues # as seen in https://github.com/FluxML/Flux.jl/pull/2117#discussion_r1027321801 - -ChainRulesCore.rrule(::typeof(adapt), a::FluxCPUAdaptor, x::AnyCuArray) = - adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCUDAAdaptor(), unthunk(Δ))) - ChainRulesCore.rrule(::typeof(adapt), a::FluxCPUAdaptor, x::AbstractArray) = adapt(a, x), Δ -> (NoTangent(), NoTangent(), Δ) -ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) = - adapt(a, x), Δ -> (NoTangent(), NoTangent(), Δ) - -ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) = - adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ))) # CPU/GPU movement conveniences @@ -286,26 +257,6 @@ function gpu(x) end end -function gpu(::FluxCUDAAdaptor, x) - check_use_cuda() - use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x -end - -function check_use_cuda() - if use_cuda[] === nothing - use_cuda[] = CUDA.functional() - if use_cuda[] && !cuDNN.has_cudnn() - @warn "CUDA.jl found cuda, but did not find libcudnn. Some functionality will not be available." maxlog=1 - end - if !(use_cuda[]) - @info """The GPU function is being called but the GPU is not accessible. - Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1 - end - end -end - -ChainRulesCore.@non_differentiable check_use_cuda() - # Precision struct FluxEltypeAdaptor{T} end @@ -375,7 +326,27 @@ f16(m) = _paramtype(Float16, m) @functor Cholesky trainable(c::Cholesky) = () -# AMDGPU extension. +# CUDA extension. ######## + +struct FluxCUDAAdaptor end + +const CUDA_LOADED = Ref{Bool}(false) + +function gpu(::FluxCUDAAdaptor, x) + if CUDA_LOADED[] + return _cuda(x) + else + @info """ + The CUDA functionality is being called but + `CUDA.jl` must be loaded to access it. + Add `using CUDA` or `import CUDA` to your code. + """ maxlog=1 + end +end + +function _cuda end + +# AMDGPU extension. ######## struct FluxAMDAdaptor end @@ -386,15 +357,16 @@ function gpu(::FluxAMDAdaptor, x) return _amd(x) else @info """ - The AMDGPU functionality is being called via `Flux.amd` but - `AMDGPU` must be loaded to access it. + The AMDGPU functionality is being called but + `AMDGPU.jl` must be loaded to access it. + Add `using AMDGPU` or `import AMDGPU` to your code. """ maxlog=1 end end function _amd end -# Metal extension. +# Metal extension. 
###### struct FluxMetalAdaptor end @@ -413,6 +385,7 @@ end function _metal end +################################ """ gpu(data::DataLoader) diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 7f9fcbe429..34315baadd 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -5,7 +5,6 @@ using Zygote using Zygote: @adjoint using ChainRulesCore using ..Flux: ofeltype, epseltype, _greek_ascii_depwarn -using CUDA using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss import Base.Broadcast: broadcasted diff --git a/src/utils.jl b/src/utils.jl index d7e733ea59..00d8df0cba 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -44,7 +44,6 @@ The current defaults are: - Julia version is >= 1.7: `Random.default_rng()` """ rng_from_array(::AbstractArray) = default_rng_value() -rng_from_array(::CuArray) = CUDA.default_rng() @non_differentiable rng_from_array(::Any) diff --git a/test/cuda/runtests.jl b/test/cuda/runtests.jl deleted file mode 100644 index 5e25829999..0000000000 --- a/test/cuda/runtests.jl +++ /dev/null @@ -1,19 +0,0 @@ -using Flux, Test, CUDA -using Zygote -using Zygote: pullback -using Random, LinearAlgebra, Statistics - -@info "Testing GPU Support" -CUDA.allowscalar(false) - -include("cuda.jl") -include("losses.jl") -include("layers.jl") - -if CUDA.functional() - @info "Testing Flux/CUDNN" - include("cudnn.jl") - include("curnn.jl") -else - @warn "CUDNN unavailable, not testing GPU DNN support" -end diff --git a/test/amd/basic.jl b/test/ext_amdgpu/basic.jl similarity index 100% rename from test/amd/basic.jl rename to test/ext_amdgpu/basic.jl diff --git a/test/ext_amdgpu/runtests.jl b/test/ext_amdgpu/runtests.jl new file mode 100644 index 0000000000..ec4f04663f --- /dev/null +++ b/test/ext_amdgpu/runtests.jl @@ -0,0 +1,10 @@ + +@assert AMDGPU.functional() +AMDGPU.allowscalar(false) + +include("../test_utils.jl") +include("test_utils.jl") + +@testset "Basic" begin + include("basic.jl") +end diff --git a/test/amd/runtests.jl b/test/ext_amdgpu/test_utils.jl similarity index 88% rename from test/amd/runtests.jl rename to test/ext_amdgpu/test_utils.jl index fa3f22d2ec..3c84f01048 100644 --- a/test/amd/runtests.jl +++ b/test/ext_amdgpu/test_utils.jl @@ -13,7 +13,3 @@ function check_grad( end check_type(x::ROCArray{Float32}) = true - -@testset "Basic" begin - include("basic.jl") -end diff --git a/test/ctc-gpu.jl b/test/ext_cuda/ctc.jl similarity index 100% rename from test/ctc-gpu.jl rename to test/ext_cuda/ctc.jl diff --git a/test/cuda/cuda.jl b/test/ext_cuda/cuda.jl similarity index 99% rename from test/cuda/cuda.jl rename to test/ext_cuda/cuda.jl index c42baa7076..b52fa6c296 100644 --- a/test/cuda/cuda.jl +++ b/test/ext_cuda/cuda.jl @@ -1,5 +1,4 @@ using Flux, Test -using Flux.CUDA using Flux: cpu, gpu using Statistics: mean using LinearAlgebra: I, cholesky, Cholesky diff --git a/test/cuda/cudnn.jl b/test/ext_cuda/cudnn.jl similarity index 100% rename from test/cuda/cudnn.jl rename to test/ext_cuda/cudnn.jl diff --git a/test/cuda/curnn.jl b/test/ext_cuda/curnn.jl similarity index 100% rename from test/cuda/curnn.jl rename to test/ext_cuda/curnn.jl diff --git a/test/cuda/layers.jl b/test/ext_cuda/layers.jl similarity index 97% rename from test/cuda/layers.jl rename to test/ext_cuda/layers.jl index 90c7ab0b40..e59ff35aa4 100644 --- a/test/cuda/layers.jl +++ b/test/ext_cuda/layers.jl @@ -45,7 +45,11 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te # test if test_cpu - @test y_gpu ≈ y_cpu rtol=1f-3 atol=1f-3 + if layer === GroupedConvTranspose 
+ @test y_gpu ≈ y_cpu rtol=1f-2 atol=1f-3 + else + @test y_gpu ≈ y_cpu rtol=1f-3 atol=1f-3 + end if isnothing(xg_cpu) @test isnothing(xg_gpu) else @@ -61,7 +65,7 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te if isnothing(gs_cpu[p_cpu]) @test isnothing(gs_gpu[p_gpu]) else - @test gs_gpu[p_gpu] isa Flux.CUDA.CuArray + @test gs_gpu[p_gpu] isa CuArray if test_cpu @test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3 end @@ -259,7 +263,7 @@ end input = randn(10, 10, 10, 10) |> gpu layer_gpu = Parallel(+, zero, identity) |> gpu @test layer_gpu(input) == input - @test layer_gpu(input) isa Flux.CUDA.CuArray + @test layer_gpu(input) isa CuArray end @testset "vararg input" begin diff --git a/test/cuda/losses.jl b/test/ext_cuda/losses.jl similarity index 100% rename from test/cuda/losses.jl rename to test/ext_cuda/losses.jl diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl new file mode 100644 index 0000000000..65dc51dbb0 --- /dev/null +++ b/test/ext_cuda/runtests.jl @@ -0,0 +1,30 @@ +using CUDA +using Flux, Test +using Zygote +using Zygote: pullback +using Random, LinearAlgebra, Statistics + +@assert CUDA.functional() +CUDA.allowscalar(false) + +# include("../test_utils.jl") +include("test_utils.jl") + +@testset "cuda" begin + include("cuda.jl") +end +@testset "losses" begin + include("losses.jl") +end +@testset "layers" begin + include("layers.jl") +end +@testset "cudnn" begin + include("cudnn.jl") +end +@testset "curnn" begin + include("curnn.jl") +end +@testset "ctc" begin + include("ctc.jl") +end diff --git a/test/ext_cuda/test_utils.jl b/test/ext_cuda/test_utils.jl new file mode 100644 index 0000000000..10a8d0dfdf --- /dev/null +++ b/test/ext_cuda/test_utils.jl @@ -0,0 +1,4 @@ +check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) = + @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol + +check_type(x::CuArray{Float32}) = true diff --git a/test/ext_cuda/utils.jl b/test/ext_cuda/utils.jl new file mode 100644 index 0000000000..3cd2a048d4 --- /dev/null +++ b/test/ext_cuda/utils.jl @@ -0,0 +1,13 @@ + + +@testset "Rrule" begin + @testset "issue 2033" begin + if CUDA.functional() + struct Wrapped{T} + x::T + end + y, _ = Flux.pullback(Wrapped, cu(randn(3,3))) + @test y isa Wrapped{<:CuArray} + end + end +end diff --git a/test/losses.jl b/test/losses.jl index a7f23a06c4..a5ce1139df 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -5,19 +5,6 @@ using Statistics: mean using Flux.Losses: mse, label_smoothing, crossentropy, logitcrossentropy, binarycrossentropy, logitbinarycrossentropy using Flux.Losses: xlogx, xlogy -# group here all losses, used in tests -const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle, - Flux.Losses.crossentropy, Flux.Losses.logitcrossentropy, - Flux.Losses.binarycrossentropy, Flux.Losses.logitbinarycrossentropy, - Flux.Losses.kldivergence, - Flux.Losses.huber_loss, - Flux.Losses.tversky_loss, - Flux.Losses.dice_coeff_loss, - Flux.Losses.poisson_loss, - Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss, - Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss] - - @testset "xlogx & xlogy" begin @test iszero(xlogx(0)) @test isnan(xlogx(NaN)) diff --git a/test/runtests.jl b/test/runtests.jl index 3ee1ebf649..50c9ea8b01 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,17 +6,12 @@ using Random, Statistics, LinearAlgebra using IterTools: ncycle using Zygote using CUDA +using cuDNN - +# 
ENV["FLUX_TEST_AMDGPU"] = "true" +ENV["FLUX_TEST_CUDA"] = "true" # ENV["FLUX_TEST_METAL"] = "true" -if VERSION >= v"1.9" && get(ENV, "FLUX_TEST_METAL", "false") == "true" - using Pkg - Pkg.add("Metal") # Hack to allow testing on julia 1.6 - # since Metal is not registered for julia < 1.8 - # When 1.6 is dropped, remove this and add Metal to test targets in Project.toml -end - include("test_utils.jl") Random.seed!(0) @@ -43,7 +38,6 @@ Random.seed!(0) @testset "Losses" begin include("losses.jl") include("ctc.jl") - CUDA.functional() && include("ctc-gpu.jl") end @testset "Layers" begin @@ -62,30 +56,29 @@ Random.seed!(0) include("outputsize.jl") end - @testset "CUDA" begin - if CUDA.functional() - include("cuda/runtests.jl") - else - @warn "CUDA unavailable, not testing GPU support" - end - end - @static if VERSION == v"1.6" - using Documenter - @testset "Docs" begin - DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) - doctest(Flux) + if get(ENV, "FLUX_TEST_CUDA", "false") == "true" + using CUDA + Flux.gpu_backend!("CUDA") + @testset "CUDA" begin + if CUDA.functional() + @info "Testing CUDA Support" + include("ext_cuda/runtests.jl") + else + @warn "CUDA.jl package is not functional. Skipping CUDA tests." + end end + else + @info "Skipping CUDA tests, set FLUX_TEST_CUDA=true to run them." end if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true" using AMDGPU Flux.gpu_backend!("AMD") - AMDGPU.allowscalar(false) if AMDGPU.functional() && AMDGPU.functional(:MIOpen) @testset "AMDGPU" begin - include("amd/runtests.jl") + include("ext_amdgpu/runtests.jl") end else @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests." @@ -108,4 +101,12 @@ Random.seed!(0) else @info "Skipping Metal tests, set FLUX_TEST_METAL=true to run them." end + + @static if VERSION == v"1.9" + using Documenter + @testset "Docs" begin + DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) + doctest(Flux) + end + end end diff --git a/test/test_utils.jl b/test/test_utils.jl index 14cfd61774..004d3035ad 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,3 +1,17 @@ +# group here all losses, used in tests +const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle, + Flux.Losses.crossentropy, Flux.Losses.logitcrossentropy, + Flux.Losses.binarycrossentropy, Flux.Losses.logitbinarycrossentropy, + Flux.Losses.kldivergence, + Flux.Losses.huber_loss, + Flux.Losses.tversky_loss, + Flux.Losses.dice_coeff_loss, + Flux.Losses.poisson_loss, + Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss, + Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss] + + + function check_grad(g_gpu, g_cpu; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) @@ -16,9 +30,6 @@ check_grad(g_gpu::Nothing, g_cpu::Nothing; rtol=1e-4, atol=1e-4, allow_nothing:: check_grad(g_gpu::Float32, g_cpu::Float32; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) = @test g_cpu ≈ g_gpu rtol=rtol atol=atol -check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) = - @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol - function check_grad(g_gpu::Tuple, g_cpu::Tuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) for (v1, v2) in zip(g_gpu, g_cpu) check_grad(v1, v2; rtol, atol, allow_nothing) @@ -34,7 +45,6 @@ end check_type(x) = false check_type(x::Float32) = true -check_type(x::CuArray{Float32}) = true check_type(x::Array{Float32}) = true function gpu_autodiff_test( diff --git a/test/utils.jl b/test/utils.jl index 
bac8deefa6..620a4d40b4 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -660,18 +660,6 @@ end end end -@testset "Rrule" begin - @testset "issue 2033" begin - if CUDA.functional() - struct Wrapped{T} - x::T - end - y, _ = Flux.pullback(Wrapped, cu(randn(3,3))) - @test y isa Wrapped{<:CuArray} - end - end -end - # make sure rng_from_array is non_differentiable @testset "rng_from_array" begin m(x) = (rand(rng_from_array(x)) * x)[1]
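A closing usage note on the test layout introduced above: the AMDGPU and Metal suites in `test/runtests.jl` are gated behind the `FLUX_TEST_AMDGPU` and `FLUX_TEST_METAL` environment variables (the CUDA suite is switched on in the script itself in this revision). A minimal sketch of enabling one of the opt-in suites from a local checkout; the path is hypothetical, and it assumes the environment variable is inherited by the test process:

```julia
# Sketch: run Flux's tests with the opt-in AMDGPU suite enabled.
# The flag names match test/runtests.jl above; the checkout path is hypothetical.
using Pkg

ENV["FLUX_TEST_AMDGPU"] = "true"   # read via get(ENV, "FLUX_TEST_AMDGPU", "false")
# ENV["FLUX_TEST_METAL"] = "true"  # likewise for the Metal suite

Pkg.activate("path/to/Flux.jl")    # local checkout of the repository
Pkg.test()                         # the test process inherits ENV, so the flag is visible
```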