diff --git a/Project.toml b/Project.toml
index f62e954..2df85f1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -19,7 +19,6 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -32,7 +31,6 @@ LuxDeviceUtilsAMDGPUExt = "AMDGPU"
 LuxDeviceUtilsCUDAExt = "CUDA"
 LuxDeviceUtilsFillArraysExt = "FillArrays"
 LuxDeviceUtilsGPUArraysExt = "GPUArrays"
-LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
 LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
 LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"]
 LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
@@ -53,10 +51,10 @@ FastClosures = "0.3.2"
 FillArrays = "1"
 Functors = "0.4.4"
 GPUArrays = "10"
-LuxAMDGPU = "0.2.2"
 LuxCUDA = "0.3.2"
 LuxCore = "0.1.4"
 Metal = "1"
+Pkg = "1.10"
 PrecompileTools = "1.2"
 Preferences = "1.4"
 Random = "1.10"
diff --git a/ext/LuxDeviceUtilsAMDGPUExt.jl b/ext/LuxDeviceUtilsAMDGPUExt.jl
index 842bbcb..6d8147c 100644
--- a/ext/LuxDeviceUtilsAMDGPUExt.jl
+++ b/ext/LuxDeviceUtilsAMDGPUExt.jl
@@ -2,9 +2,31 @@ module LuxDeviceUtilsAMDGPUExt
 
 using Adapt: Adapt
 using AMDGPU: AMDGPU
-using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCPUDevice
+using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCPUDevice, reset_gpu_device!
 using Random: Random
 
+__init__() = reset_gpu_device!()
+
+# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package.
+const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing)
+
+function _check_use_amdgpu!()
+    USE_AMD_GPU[] === nothing || return
+
+    USE_AMD_GPU[] = AMDGPU.functional()
+    if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen)
+        @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \
+               available." maxlog=1
+    end
+    return
+end
+
+LuxDeviceUtils.loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true
+function LuxDeviceUtils.functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}})::Bool
+    _check_use_amdgpu!()
+    return USE_AMD_GPU[]
+end
+
 function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing)
     return LuxAMDGPUDevice(nothing)
 end
diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl
deleted file mode 100644
index 15fcb9f..0000000
--- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl
+++ /dev/null
@@ -1,13 +0,0 @@
-module LuxDeviceUtilsLuxAMDGPUExt
-
-using LuxAMDGPU: LuxAMDGPU
-using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, reset_gpu_device!
-
-__init__() = reset_gpu_device!()
-
-LuxDeviceUtils.__is_loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true
-function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}})
-    return LuxAMDGPU.functional()
-end
-
-end
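Note for reviewers: the check relocated from `LuxAMDGPU.jl` is memoized, so `AMDGPU.functional()` only runs on the first query. A minimal sketch of the resulting call pattern (an illustrative REPL session, assuming a machine with `AMDGPU.jl` installed; it is not part of the diff):

```julia
using LuxDeviceUtils, AMDGPU  # AMDGPU.jl itself is now the trigger package

LuxDeviceUtils.loaded(LuxAMDGPUDevice)      # true: the extension is active
LuxDeviceUtils.functional(LuxAMDGPUDevice)  # first call runs AMDGPU.functional()
                                            # and caches the result in USE_AMD_GPU
LuxDeviceUtils.functional(LuxAMDGPUDevice)  # later calls just read the cache
```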
diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl
index 4e386ad..4870710 100644
--- a/ext/LuxDeviceUtilsLuxCUDAExt.jl
+++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl
@@ -5,8 +5,8 @@ using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, reset_gpu_device!
 
 __init__() = reset_gpu_device!()
 
-LuxDeviceUtils.__is_loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true
-function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}})
+LuxDeviceUtils.loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true
+function LuxDeviceUtils.functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}})
     return LuxCUDA.functional()
 end
 
diff --git a/ext/LuxDeviceUtilsMetalExt.jl b/ext/LuxDeviceUtilsMetalExt.jl
index 2db6866..f53e7c5 100644
--- a/ext/LuxDeviceUtilsMetalExt.jl
+++ b/ext/LuxDeviceUtilsMetalExt.jl
@@ -7,8 +7,8 @@ using Metal: Metal, MtlArray
 
 __init__() = reset_gpu_device!()
 
-LuxDeviceUtils.__is_loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true
-function LuxDeviceUtils.__is_functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}})
+LuxDeviceUtils.loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true
+function LuxDeviceUtils.functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}})
     return Metal.functional()
 end
 
diff --git a/ext/LuxDeviceUtilsoneAPIExt.jl b/ext/LuxDeviceUtilsoneAPIExt.jl
index 881eb66..00b8faa 100644
--- a/ext/LuxDeviceUtilsoneAPIExt.jl
+++ b/ext/LuxDeviceUtilsoneAPIExt.jl
@@ -16,8 +16,8 @@ function __init__()
     end
 end
 
-LuxDeviceUtils.__is_loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true
-function LuxDeviceUtils.__is_functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}})
+LuxDeviceUtils.loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true
+function LuxDeviceUtils.functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}})
     return oneAPI.functional()
 end
 
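All four extensions now implement the same two-method trait: `loaded` is unconditionally `true` once the trigger package is imported, while `functional` defers to the backend's own runtime check. A condensed sketch of what a hypothetical new backend extension would add (`SomeGPU` and `LuxSomeGPUDevice` are illustrative names, not part of this diff):

```julia
module LuxDeviceUtilsSomeGPUExt  # hypothetical package extension

using SomeGPU: SomeGPU  # hypothetical backend package
using LuxDeviceUtils: LuxDeviceUtils, LuxSomeGPUDevice, reset_gpu_device!

# Invalidate any cached device selection once this backend becomes available.
__init__() = reset_gpu_device!()

# Loaded: true as soon as this extension is compiled in.
LuxDeviceUtils.loaded(::Union{LuxSomeGPUDevice, Type{<:LuxSomeGPUDevice}}) = true

# Functional: defer to the backend's own runtime check.
function LuxDeviceUtils.functional(::Union{LuxSomeGPUDevice, Type{<:LuxSomeGPUDevice}})
    return SomeGPU.functional()
end

end
```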
"Lux$(name)" : string(name)) ldev = eval(Symbol(:Lux, name, :Device)) @eval begin @inline _get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) @@ -173,13 +195,13 @@ function _get_gpu_device(; force_gpu_usage::Bool) @debug "Using GPU backend set in preferences: $backend." idx = findfirst(isequal(backend), allowed_backends) device = GPU_DEVICES[idx] - if !__is_loaded(device) + if !loaded(device) @warn "Trying to use backend: $(_get_device_name(device)) but the trigger \ package $(_get_triggerpkg_name(device)) is not loaded. Ignoring the \ Preferences backend!!! Please load the package and call this \ function again to respect the Preferences backend." maxlog=1 else - if __is_functional(device) + if functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device else @@ -193,9 +215,9 @@ function _get_gpu_device(; force_gpu_usage::Bool) @debug "Running automatic GPU backend selection..." for device in GPU_DEVICES - if __is_loaded(device) + if loaded(device) @debug "Trying backend: $(_get_device_name(device))." - if __is_functional(device) + if functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device end @@ -214,7 +236,7 @@ function _get_gpu_device(; force_gpu_usage::Bool) 1. If no GPU is available, nothing needs to be done. 2. If GPU is available, load the corresponding trigger package. a. `LuxCUDA.jl` for NVIDIA CUDA Support. - b. `LuxAMDGPU.jl` for AMD GPU ROCM Support. + b. `AMDGPU.jl` for AMD GPU ROCM Support. c. `Metal.jl` for Apple Metal GPU Support. d. `oneAPI.jl` for Intel oneAPI GPU Support.""" maxlog=1 return LuxCPUDevice diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 9247fdb..be58ccd 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -7,17 +7,17 @@ using LuxDeviceUtils, Random force_gpu_usage=true) end -using LuxAMDGPU +using AMDGPU @testset "Loaded Trigger Package" begin @test LuxDeviceUtils.GPU_DEVICE[] === nothing - if LuxAMDGPU.functional() - @info "LuxAMDGPU is functional" + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + @info "AMDGPU is functional" @test gpu_device() isa LuxAMDGPUDevice @test gpu_device(; force_gpu_usage=true) isa LuxAMDGPUDevice else - @info "LuxAMDGPU is NOT functional" + @info "AMDGPU is NOT functional" @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) @@ -33,8 +33,9 @@ using FillArrays, Zygote # Extensions one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxAMDGPU.functional() ? ROCArray : Array - rngType = LuxAMDGPU.functional() ? AMDGPU.rocRAND.RNG : Random.AbstractRNG + aType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? ROCArray : Array + rngType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? 
diff --git a/test/amdgpu.jl b/test/amdgpu.jl
index 9247fdb..be58ccd 100644
--- a/test/amdgpu.jl
+++ b/test/amdgpu.jl
@@ -7,17 +7,17 @@ using LuxDeviceUtils, Random
         force_gpu_usage=true)
 end
 
-using LuxAMDGPU
+using AMDGPU
 
 @testset "Loaded Trigger Package" begin
     @test LuxDeviceUtils.GPU_DEVICE[] === nothing
 
-    if LuxAMDGPU.functional()
-        @info "LuxAMDGPU is functional"
+    if LuxDeviceUtils.functional(LuxAMDGPUDevice)
+        @info "AMDGPU is functional"
         @test gpu_device() isa LuxAMDGPUDevice
         @test gpu_device(; force_gpu_usage=true) isa LuxAMDGPUDevice
     else
-        @info "LuxAMDGPU is NOT functional"
+        @info "AMDGPU is NOT functional"
         @test gpu_device() isa LuxCPUDevice
         @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
             force_gpu_usage=true)
@@ -33,8 +33,9 @@ using FillArrays, Zygote # Extensions
         one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))
     device = gpu_device()
 
-    aType = LuxAMDGPU.functional() ? ROCArray : Array
-    rngType = LuxAMDGPU.functional() ? AMDGPU.rocRAND.RNG : Random.AbstractRNG
+    aType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? ROCArray : Array
+    rngType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? AMDGPU.rocRAND.RNG :
+              Random.AbstractRNG
 
     ps_xpu = ps |> device
     @test ps_xpu.a.c isa aType
@@ -45,7 +46,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.rng_default isa rngType
     @test ps_xpu.rng == ps.rng
 
-    if LuxAMDGPU.functional()
+    if LuxDeviceUtils.functional(LuxAMDGPUDevice)
         @test ps_xpu.one_elem isa ROCArray
         @test ps_xpu.farray isa ROCArray
     else
@@ -64,7 +65,7 @@ using FillArrays, Zygote # Extensions
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
     @test ps_cpu.rng == ps.rng
 
-    if LuxAMDGPU.functional()
+    if LuxDeviceUtils.functional(LuxAMDGPUDevice)
         @test ps_cpu.one_elem isa Array
         @test ps_cpu.farray isa Array
     else
@@ -73,7 +74,7 @@ using FillArrays, Zygote # Extensions
     end
 end
 
-if LuxAMDGPU.functional()
+if LuxDeviceUtils.functional(LuxAMDGPUDevice)
     ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
     ps_cpu = deepcopy(ps)
     cdev = cpu_device()
diff --git a/test/cuda.jl b/test/cuda.jl
index e0dc343..694f14b 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -12,7 +12,7 @@ using LuxCUDA
 @testset "Loaded Trigger Package" begin
     @test LuxDeviceUtils.GPU_DEVICE[] === nothing
 
-    if LuxCUDA.functional()
+    if LuxDeviceUtils.functional(LuxCUDADevice)
         @info "LuxCUDA is functional"
         @test gpu_device() isa LuxCUDADevice
         @test gpu_device(; force_gpu_usage=true) isa LuxCUDADevice
@@ -33,8 +33,8 @@ using FillArrays, Zygote # Extensions
         one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))
     device = gpu_device()
 
-    aType = LuxCUDA.functional() ? CuArray : Array
-    rngType = LuxCUDA.functional() ? CUDA.RNG : Random.AbstractRNG
+    aType = LuxDeviceUtils.functional(LuxCUDADevice) ? CuArray : Array
+    rngType = LuxDeviceUtils.functional(LuxCUDADevice) ? CUDA.RNG : Random.AbstractRNG
 
     ps_xpu = ps |> device
     @test ps_xpu.a.c isa aType
@@ -45,7 +45,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.rng_default isa rngType
     @test ps_xpu.rng == ps.rng
 
-    if LuxCUDA.functional()
+    if LuxDeviceUtils.functional(LuxCUDADevice)
         @test ps_xpu.one_elem isa CuArray
         @test ps_xpu.farray isa CuArray
     else
@@ -64,7 +64,7 @@ using FillArrays, Zygote # Extensions
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
     @test ps_cpu.rng == ps.rng
 
-    if LuxCUDA.functional()
+    if LuxDeviceUtils.functional(LuxCUDADevice)
         @test ps_cpu.one_elem isa Array
         @test ps_cpu.farray isa Array
     else
@@ -73,7 +73,7 @@ using FillArrays, Zygote # Extensions
     end
 end
 
-if LuxCUDA.functional()
+if LuxDeviceUtils.functional(LuxCUDADevice)
     ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
     ps_cpu = deepcopy(ps)
     cdev = cpu_device()
diff --git a/test/explicit_imports.jl b/test/explicit_imports.jl
index 1e2846f..6cf767e 100644
--- a/test/explicit_imports.jl
+++ b/test/explicit_imports.jl
@@ -1,6 +1,5 @@
 # Load all trigger packages
-import LuxAMDGPU, LuxCUDA, FillArrays, Metal, RecursiveArrayTools, SparseArrays, Zygote,
-    oneAPI
+import FillArrays, RecursiveArrayTools, SparseArrays, Zygote
 using ExplicitImports, LuxDeviceUtils
 
 @test check_no_implicit_imports(LuxDeviceUtils) === nothing
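For reviewers skimming the remaining test files: the substitution is mechanical throughout, and the Metal and oneAPI diffs below follow the identical pattern. The shape of the change, extracted from the hunks above:

```julia
# before: asks the trigger package directly
aType = LuxCUDA.functional() ? CuArray : Array

# after: the same decision expressed through the device trait, which also
# works when the trigger package's own API is not imported by the test
aType = LuxDeviceUtils.functional(LuxCUDADevice) ? CuArray : Array
```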
diff --git a/test/metal.jl b/test/metal.jl
index 96c930e..9da2402 100644
--- a/test/metal.jl
+++ b/test/metal.jl
@@ -12,7 +12,7 @@ using Metal
 @testset "Loaded Trigger Package" begin
     @test LuxDeviceUtils.GPU_DEVICE[] === nothing
 
-    if Metal.functional()
+    if LuxDeviceUtils.functional(LuxMetalDevice)
         @info "Metal is functional"
         @test gpu_device() isa LuxMetalDevice
         @test gpu_device(; force_gpu_usage=true) isa LuxMetalDevice
@@ -33,8 +33,9 @@ using FillArrays, Zygote # Extensions
         one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))
     device = gpu_device()
 
-    aType = Metal.functional() ? MtlArray : Array
-    rngType = Metal.functional() ? Metal.GPUArrays.RNG : Random.AbstractRNG
+    aType = LuxDeviceUtils.functional(LuxMetalDevice) ? MtlArray : Array
+    rngType = LuxDeviceUtils.functional(LuxMetalDevice) ? Metal.GPUArrays.RNG :
+              Random.AbstractRNG
 
     ps_xpu = ps |> device
     @test ps_xpu.a.c isa aType
@@ -45,7 +46,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.rng_default isa rngType
     @test ps_xpu.rng == ps.rng
 
-    if Metal.functional()
+    if LuxDeviceUtils.functional(LuxMetalDevice)
         @test ps_xpu.one_elem isa MtlArray
         @test ps_xpu.farray isa MtlArray
     else
@@ -64,7 +65,7 @@ using FillArrays, Zygote # Extensions
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
     @test ps_cpu.rng == ps.rng
 
-    if Metal.functional()
+    if LuxDeviceUtils.functional(LuxMetalDevice)
         @test ps_cpu.one_elem isa Array
         @test ps_cpu.farray isa Array
     else
diff --git a/test/oneapi.jl b/test/oneapi.jl
index 418830a..0694171 100644
--- a/test/oneapi.jl
+++ b/test/oneapi.jl
@@ -12,7 +12,7 @@ using oneAPI
 @testset "Loaded Trigger Package" begin
     @test LuxDeviceUtils.GPU_DEVICE[] === nothing
 
-    if oneAPI.functional()
+    if LuxDeviceUtils.functional(LuxoneAPIDevice)
         @info "oneAPI is functional"
         @test gpu_device() isa LuxoneAPIDevice
         @test gpu_device(; force_gpu_usage=true) isa LuxoneAPIDevice
@@ -33,8 +33,9 @@ using FillArrays, Zygote # Extensions
         one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))
     device = gpu_device()
 
-    aType = oneAPI.functional() ? oneArray : Array
-    rngType = oneAPI.functional() ? oneAPI.GPUArrays.RNG : Random.AbstractRNG
+    aType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneArray : Array
+    rngType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneAPI.GPUArrays.RNG :
+              Random.AbstractRNG
 
     ps_xpu = ps |> device
     @test ps_xpu.a.c isa aType
@@ -45,7 +46,7 @@ using FillArrays, Zygote # Extensions
     @test ps_xpu.rng_default isa rngType
     @test ps_xpu.rng == ps.rng
 
-    if oneAPI.functional()
+    if LuxDeviceUtils.functional(LuxoneAPIDevice)
         @test ps_xpu.one_elem isa oneArray
         @test ps_xpu.farray isa oneArray
     else
@@ -64,7 +65,7 @@ using FillArrays, Zygote # Extensions
     @test ps_cpu.rng_default isa Random.TaskLocalRNG
     @test ps_cpu.rng == ps.rng
 
-    if oneAPI.functional()
+    if LuxDeviceUtils.functional(LuxoneAPIDevice)
         @test ps_cpu.one_elem isa Array
         @test ps_cpu.farray isa Array
     else
diff --git a/test/runtests.jl b/test/runtests.jl
index 35e34d6..cb5e92c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -10,7 +10,7 @@ const GROUP = get(ENV, "GROUP", "NONE")
     end
 
     if GROUP == "AMDGPU" || GROUP == "ALL"
-        Pkg.add("LuxAMDGPU")
+        Pkg.add("AMDGPU")
         @safetestset "AMDGPU" include("amdgpu.jl")
     end
 
@@ -29,7 +29,7 @@ const GROUP = get(ENV, "GROUP", "NONE")
 
     @safetestset "Component Arrays" include("component_arrays.jl")
 
-    Pkg.add(["LuxCUDA", "LuxAMDGPU", "Metal", "oneAPI"])
+    Pkg.add(["LuxCUDA", "AMDGPU", "Metal", "oneAPI"])
     @safetestset "Explicit Imports" include("explicit_imports.jl")
     end
 end
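Since `runtests.jl` installs trigger packages on demand via `Pkg.add` (hence the new `Pkg = "1.10"` compat entry), individual backend groups can be exercised locally by setting `GROUP`. A sketch of one way to invoke the AMDGPU group (illustrative, assuming environment variables propagate to the test process as usual):

```julia
# Run only the AMDGPU test group; runtests.jl will Pkg.add("AMDGPU") itself.
# Rough shell equivalent: GROUP=AMDGPU julia --project -e 'using Pkg; Pkg.test()'
using Pkg
ENV["GROUP"] = "AMDGPU"
Pkg.test("LuxDeviceUtils")
```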