
Commit

Remove uses of LuxAMDGPU.jl
avik-pal committed Jun 6, 2024
1 parent b90de03 commit 278ab07
Showing 13 changed files with 93 additions and 62 deletions.
4 changes: 1 addition & 3 deletions Project.toml
@@ -19,7 +19,6 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -32,7 +31,6 @@ LuxDeviceUtilsAMDGPUExt = "AMDGPU"
LuxDeviceUtilsCUDAExt = "CUDA"
LuxDeviceUtilsFillArraysExt = "FillArrays"
LuxDeviceUtilsGPUArraysExt = "GPUArrays"
LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"]
LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
@@ -53,10 +51,10 @@ FastClosures = "0.3.2"
FillArrays = "1"
Functors = "0.4.4"
GPUArrays = "10"
LuxAMDGPU = "0.2.2"
LuxCUDA = "0.3.2"
LuxCore = "0.1.4"
Metal = "1"
Pkg = "1.10"
PrecompileTools = "1.2"
Preferences = "1.4"
Random = "1.10"
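
With LuxAMDGPU.jl gone from [weakdeps], [extensions], and [compat], AMDGPU.jl itself is
now the trigger package for the AMD extension. A minimal sketch of how the extension
activates under Julia's package-extension mechanism (assuming both packages are
installed):

using LuxDeviceUtils  # loads only the core package
using AMDGPU          # Julia auto-loads LuxDeviceUtilsAMDGPUExt on top of it

gpu_device()  # may now return a LuxAMDGPUDevice if the hardware is functional
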
24 changes: 23 additions & 1 deletion ext/LuxDeviceUtilsAMDGPUExt.jl
@@ -2,9 +2,31 @@ module LuxDeviceUtilsAMDGPUExt

using Adapt: Adapt
using AMDGPU: AMDGPU
using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCPUDevice
using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCPUDevice, reset_gpu_device!
using Random: Random

__init__() = reset_gpu_device!()

# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package.
const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing)

function _check_use_amdgpu!()
    USE_AMD_GPU[] === nothing || return

    USE_AMD_GPU[] = AMDGPU.functional()
    if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen)
        @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \
               available." maxlog=1
    end
    return
end

LuxDeviceUtils.loaded(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = true
function LuxDeviceUtils.functional(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}})::Bool
    _check_use_amdgpu!()
    return USE_AMD_GPU[]
end

function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing)
    return LuxAMDGPUDevice(nothing)
end
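
The inlined check memoizes its result in USE_AMD_GPU, so the potentially expensive
AMDGPU.functional() query runs at most once per session. A hedged usage sketch, assuming
AMDGPU.jl is installed and loaded:

using LuxDeviceUtils, AMDGPU

LuxDeviceUtils.loaded(LuxAMDGPUDevice)      # true: the trigger package is loaded
LuxDeviceUtils.functional(LuxAMDGPUDevice)  # first call runs the check and caches it
LuxDeviceUtils.functional(LuxAMDGPUDevice)  # later calls return the cached Bool
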
13 changes: 0 additions & 13 deletions ext/LuxDeviceUtilsLuxAMDGPUExt.jl

This file was deleted.

4 changes: 2 additions & 2 deletions ext/LuxDeviceUtilsLuxCUDAExt.jl
@@ -5,8 +5,8 @@ using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, reset_gpu_device!

__init__() = reset_gpu_device!()

LuxDeviceUtils.__is_loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true
function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}})
LuxDeviceUtils.loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true
function LuxDeviceUtils.functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}})
    return LuxCUDA.functional()
end

4 changes: 2 additions & 2 deletions ext/LuxDeviceUtilsMetalExt.jl
@@ -7,8 +7,8 @@ using Metal: Metal, MtlArray

__init__() = reset_gpu_device!()

LuxDeviceUtils.__is_loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true
function LuxDeviceUtils.__is_functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}})
LuxDeviceUtils.loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true
function LuxDeviceUtils.functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}})
    return Metal.functional()
end

4 changes: 2 additions & 2 deletions ext/LuxDeviceUtilsoneAPIExt.jl
@@ -16,8 +16,8 @@ function __init__()
end
end

LuxDeviceUtils.__is_loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true
function LuxDeviceUtils.__is_functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}})
LuxDeviceUtils.loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true
function LuxDeviceUtils.functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}})
    return oneAPI.functional()
end

42 changes: 32 additions & 10 deletions src/LuxDeviceUtils.jl
@@ -24,8 +24,30 @@ export get_device
abstract type AbstractLuxDevice <: Function end
abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end

@inline __is_functional(x) = false
@inline __is_loaded(x) = false
"""
functional(x::AbstractLuxDevice) -> Bool
functional(::Type{<:AbstractLuxDevice}) -> Bool
Checks if the device is functional. This is used to determine if the device can be used for
computation. Note that even if the backend is loaded (as checked via
[`LuxDeviceUtils.loaded`](@ref)), the device may not be functional.
Note that while this function is not exported, it is considered part of the public API.
"""
@inline functional(x) = false

"""
loaded(x::AbstractLuxDevice) -> Bool
loaded(::Type{<:AbstractLuxDevice}) -> Bool
Checks if the trigger package for the device is loaded. Trigger packages are as follows:
- `LuxCUDA.jl` for NVIDIA CUDA Support.
- `AMDGPU.jl` for AMD GPU ROCM Support.
- `Metal.jl` for Apple Metal GPU Support.
- `oneAPI.jl` for Intel oneAPI GPU Support.
"""
@inline loaded(x) = false

struct LuxCPUDevice <: AbstractLuxDevice end
@kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice
@@ -47,11 +69,11 @@ for dev in (LuxCPUDevice, LuxMetalDevice, LuxoneAPIDevice)
end
end

@inline __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true
@inline __is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true
@inline functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true
@inline loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true

for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
tpkg = name === :CPU ? "" : (name ∈ (:CUDA, :AMDGPU) ? "Lux$(name)" : string(name))
tpkg = name === :CPU ? "" : (name == :CUDA ? "Lux$(name)" : string(name))
ldev = eval(Symbol(:Lux, name, :Device))
@eval begin
@inline _get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name))
@@ -173,13 +195,13 @@ function _get_gpu_device(; force_gpu_usage::Bool)
@debug "Using GPU backend set in preferences: $backend."
idx = findfirst(isequal(backend), allowed_backends)
device = GPU_DEVICES[idx]
if !__is_loaded(device)
if !loaded(device)
@warn "Trying to use backend: $(_get_device_name(device)) but the trigger \
package $(_get_triggerpkg_name(device)) is not loaded. Ignoring the \
Preferences backend!!! Please load the package and call this \
function again to respect the Preferences backend." maxlog=1
else
if __is_functional(device)
if functional(device)
@debug "Using GPU backend: $(_get_device_name(device))."
return device
else
@@ -193,9 +215,9 @@

@debug "Running automatic GPU backend selection..."
for device in GPU_DEVICES
if __is_loaded(device)
if loaded(device)
@debug "Trying backend: $(_get_device_name(device))."
if __is_functional(device)
if functional(device)
@debug "Using GPU backend: $(_get_device_name(device))."
return device
end
@@ -214,7 +236,7 @@
1. If no GPU is available, nothing needs to be done.
2. If GPU is available, load the corresponding trigger package.
a. `LuxCUDA.jl` for NVIDIA CUDA Support.
b. `LuxAMDGPU.jl` for AMD GPU ROCM Support.
b. `AMDGPU.jl` for AMD GPU ROCM Support.
c. `Metal.jl` for Apple Metal GPU Support.
d. `oneAPI.jl` for Intel oneAPI GPU Support.""" maxlog=1
return LuxCPUDevice
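
Taken together, backend selection now asks loaded(device) first and functional(device)
second for each candidate. A short sketch of the resulting user-facing behavior, based
on the tests in this commit:

using LuxDeviceUtils

dev = gpu_device()              # falls back to LuxCPUDevice (with the setup hint above)
                                # when no GPU backend is loaded and functional
x = rand(Float32, 2, 3) |> dev  # devices are callable and move data onto themselves

# With force_gpu_usage=true the fallback becomes an error instead:
# gpu_device(; force_gpu_usage=true) throws LuxDeviceUtils.LuxDeviceSelectionException
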
19 changes: 10 additions & 9 deletions test/amdgpu.jl
@@ -7,17 +7,17 @@ using LuxDeviceUtils, Random
force_gpu_usage=true)
end

using LuxAMDGPU
using AMDGPU

@testset "Loaded Trigger Package" begin
@test LuxDeviceUtils.GPU_DEVICE[] === nothing

if LuxAMDGPU.functional()
@info "LuxAMDGPU is functional"
if LuxDeviceUtils.functional(LuxAMDGPUDevice)
@info "AMDGPU is functional"
@test gpu_device() isa LuxAMDGPUDevice
@test gpu_device(; force_gpu_usage=true) isa LuxAMDGPUDevice
else
@info "LuxAMDGPU is NOT functional"
@info "AMDGPU is NOT functional"
@test gpu_device() isa LuxCPUDevice
@test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(;
force_gpu_usage=true)
@@ -33,8 +33,9 @@ using FillArrays, Zygote # Extensions
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

device = gpu_device()
aType = LuxAMDGPU.functional() ? ROCArray : Array
rngType = LuxAMDGPU.functional() ? AMDGPU.rocRAND.RNG : Random.AbstractRNG
aType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? ROCArray : Array
rngType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? AMDGPU.rocRAND.RNG :
Random.AbstractRNG

ps_xpu = ps |> device
@test ps_xpu.a.c isa aType
@@ -45,7 +46,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.rng_default isa rngType
@test ps_xpu.rng == ps.rng

if LuxAMDGPU.functional()
if LuxDeviceUtils.functional(LuxAMDGPUDevice)
@test ps_xpu.one_elem isa ROCArray
@test ps_xpu.farray isa ROCArray
else
@@ -64,7 +65,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test ps_cpu.rng == ps.rng

if LuxAMDGPU.functional()
if LuxDeviceUtils.functional(LuxAMDGPUDevice)
@test ps_cpu.one_elem isa Array
@test ps_cpu.farray isa Array
else
@@ -73,7 +74,7 @@ using FillArrays, Zygote # Extensions
end
end

if LuxAMDGPU.functional()
if LuxDeviceUtils.functional(LuxAMDGPUDevice)
ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
ps_cpu = deepcopy(ps)
cdev = cpu_device()
12 changes: 6 additions & 6 deletions test/cuda.jl
@@ -12,7 +12,7 @@ using LuxCUDA
@testset "Loaded Trigger Package" begin
@test LuxDeviceUtils.GPU_DEVICE[] === nothing

if LuxCUDA.functional()
if LuxDeviceUtils.functional(LuxCUDADevice)
@info "LuxCUDA is functional"
@test gpu_device() isa LuxCUDADevice
@test gpu_device(; force_gpu_usage=true) isa LuxCUDADevice
@@ -33,8 +33,8 @@ using FillArrays, Zygote # Extensions
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

device = gpu_device()
aType = LuxCUDA.functional() ? CuArray : Array
rngType = LuxCUDA.functional() ? CUDA.RNG : Random.AbstractRNG
aType = LuxDeviceUtils.functional(LuxCUDADevice) ? CuArray : Array
rngType = LuxDeviceUtils.functional(LuxCUDADevice) ? CUDA.RNG : Random.AbstractRNG

ps_xpu = ps |> device
@test ps_xpu.a.c isa aType
@@ -45,7 +45,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.rng_default isa rngType
@test ps_xpu.rng == ps.rng

if LuxCUDA.functional()
if LuxDeviceUtils.functional(LuxCUDADevice)
@test ps_xpu.one_elem isa CuArray
@test ps_xpu.farray isa CuArray
else
@@ -64,7 +64,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test ps_cpu.rng == ps.rng

if LuxCUDA.functional()
if LuxDeviceUtils.functional(LuxCUDADevice)
@test ps_cpu.one_elem isa Array
@test ps_cpu.farray isa Array
else
@@ -73,7 +73,7 @@ using FillArrays, Zygote # Extensions
end
end

if LuxCUDA.functional()
if LuxDeviceUtils.functional(LuxCUDADevice)
ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
ps_cpu = deepcopy(ps)
cdev = cpu_device()
3 changes: 1 addition & 2 deletions test/explicit_imports.jl
@@ -1,6 +1,5 @@
# Load all trigger packages
import LuxAMDGPU, LuxCUDA, FillArrays, Metal, RecursiveArrayTools, SparseArrays, Zygote,
oneAPI
import FillArrays, RecursiveArrayTools, SparseArrays, Zygote
using ExplicitImports, LuxDeviceUtils

@test check_no_implicit_imports(LuxDeviceUtils) === nothing
11 changes: 6 additions & 5 deletions test/metal.jl
@@ -12,7 +12,7 @@ using Metal
@testset "Loaded Trigger Package" begin
@test LuxDeviceUtils.GPU_DEVICE[] === nothing

if Metal.functional()
if LuxDeviceUtils.functional(LuxMetalDevice)
@info "Metal is functional"
@test gpu_device() isa LuxMetalDevice
@test gpu_device(; force_gpu_usage=true) isa LuxMetalDevice
@@ -33,8 +33,9 @@ using FillArrays, Zygote # Extensions
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

device = gpu_device()
aType = Metal.functional() ? MtlArray : Array
rngType = Metal.functional() ? Metal.GPUArrays.RNG : Random.AbstractRNG
aType = LuxDeviceUtils.functional(LuxMetalDevice) ? MtlArray : Array
rngType = LuxDeviceUtils.functional(LuxMetalDevice) ? Metal.GPUArrays.RNG :
Random.AbstractRNG

ps_xpu = ps |> device
@test ps_xpu.a.c isa aType
@@ -45,7 +46,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.rng_default isa rngType
@test ps_xpu.rng == ps.rng

if Metal.functional()
if LuxDeviceUtils.functional(LuxMetalDevice)
@test ps_xpu.one_elem isa MtlArray
@test ps_xpu.farray isa MtlArray
else
@@ -64,7 +65,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test ps_cpu.rng == ps.rng

if Metal.functional()
if LuxDeviceUtils.functional(LuxMetalDevice)
@test ps_cpu.one_elem isa Array
@test ps_cpu.farray isa Array
else
11 changes: 6 additions & 5 deletions test/oneapi.jl
@@ -12,7 +12,7 @@ using oneAPI
@testset "Loaded Trigger Package" begin
@test LuxDeviceUtils.GPU_DEVICE[] === nothing

if oneAPI.functional()
if LuxDeviceUtils.functional(LuxoneAPIDevice)
@info "oneAPI is functional"
@test gpu_device() isa LuxoneAPIDevice
@test gpu_device(; force_gpu_usage=true) isa LuxoneAPIDevice
@@ -33,8 +33,9 @@ using FillArrays, Zygote # Extensions
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

device = gpu_device()
aType = oneAPI.functional() ? oneArray : Array
rngType = oneAPI.functional() ? oneAPI.GPUArrays.RNG : Random.AbstractRNG
aType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneArray : Array
rngType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneAPI.GPUArrays.RNG :
Random.AbstractRNG

ps_xpu = ps |> device
@test ps_xpu.a.c isa aType
@@ -45,7 +46,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.rng_default isa rngType
@test ps_xpu.rng == ps.rng

if oneAPI.functional()
if LuxDeviceUtils.functional(LuxoneAPIDevice)
@test ps_xpu.one_elem isa oneArray
@test ps_xpu.farray isa oneArray
else
@@ -64,7 +65,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test ps_cpu.rng == ps.rng

if oneAPI.functional()
if LuxDeviceUtils.functional(LuxoneAPIDevice)
@test ps_cpu.one_elem isa Array
@test ps_cpu.farray isa Array
else