From 35030c18a6c851bb6cbf7e1baaba225d4a759870 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Feb 2024 23:43:00 -0500 Subject: [PATCH] Add setup for multiGPU setups --- Project.toml | 3 +- ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 15 +++- ext/LuxDeviceUtilsLuxCUDAExt.jl | 15 +++- ext/LuxDeviceUtilsMetalGPUArraysExt.jl | 6 +- ext/LuxDeviceUtilsSparseArraysExt.jl | 9 +++ src/LuxDeviceUtils.jl | 107 ++++++++++++++----------- 6 files changed, 100 insertions(+), 55 deletions(-) create mode 100644 ext/LuxDeviceUtilsSparseArraysExt.jl diff --git a/Project.toml b/Project.toml index da0cab4..8e83cce 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -20,6 +19,7 @@ LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -29,6 +29,7 @@ LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"] LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" +LuxDeviceUtilsSparseArraysExt = "SparseArrays" LuxDeviceUtilsZygoteExt = "Zygote" [compat] diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index ac951f1..f061fcb 100644 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -5,8 +5,19 @@ import Adapt: adapt_storage, adapt __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true -LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() +LuxDeviceUtils.__is_loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true +function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) + return LuxAMDGPU.functional() +end + +function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, device_id) + id = ifelse(device_id === nothing, 0, device_id) + old_id = AMDGPU.device_id(AMDGPU.device()) - 1 + AMDGPU.device!(AMDGPU.devices()[id + 1]) + device = LuxAMDGPUDevice(AMDGPU.device()) + AMDGPU.device!(AMDGPU.devices()[old_id + 1]) + return device +end # Default RNG LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index 4edf554..d57fc97 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -5,8 +5,19 @@ import Adapt: adapt_storage, adapt __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true -LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() +LuxDeviceUtils.__is_loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true +function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) + return LuxCUDA.functional() +end + +function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, device_id) + id = ifelse(device_id === nothing, 0, device_id) + old_id = CUDA.device().handle + CUDA.device!(id) + device = LuxCUDADevice(CUDA.device()) + CUDA.device!(old_id) + return device +end # Default RNG LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() diff --git a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/ext/LuxDeviceUtilsMetalGPUArraysExt.jl index 836ab07..8272d6c 100644 --- a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl +++ b/ext/LuxDeviceUtilsMetalGPUArraysExt.jl @@ -5,8 +5,10 @@ import Adapt: adapt_storage, adapt __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true -LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional() +LuxDeviceUtils.__is_loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true +function LuxDeviceUtils.__is_functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) + return Metal.functional() +end # Default RNG LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) diff --git a/ext/LuxDeviceUtilsSparseArraysExt.jl b/ext/LuxDeviceUtilsSparseArraysExt.jl new file mode 100644 index 0000000..80f5e35 --- /dev/null +++ b/ext/LuxDeviceUtilsSparseArraysExt.jl @@ -0,0 +1,9 @@ +module LuxDeviceUtilsSparseArraysExt + +import Adapt: adapt_storage +import LuxDeviceUtils: LuxCPUAdaptor +import SparseArrays: AbstractSparseArray + +adapt_storage(::LuxCPUAdaptor, x::AbstractSparseArray) = x + +end diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 04347dc..2b9d020 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -3,7 +3,7 @@ module LuxDeviceUtils import PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays + using ChainRulesCore, Functors, LuxCore, Preferences, Random import Adapt: adapt, adapt_storage import ChainRulesCore as CRC end @@ -21,33 +21,49 @@ __is_functional(::AbstractLuxDevice) = false __is_loaded(::AbstractLuxDevice) = false struct LuxCPUDevice <: AbstractLuxDevice end -struct LuxCUDADevice <: AbstractLuxGPUDevice end -struct LuxAMDGPUDevice <: AbstractLuxGPUDevice end +@kwdef struct LuxCUDADevice{ID} <: AbstractLuxGPUDevice + device_id::ID = nothing +end +@kwdef struct LuxAMDGPUDevice{ID} <: AbstractLuxGPUDevice + device_id::ID = nothing +end struct LuxMetalDevice <: AbstractLuxGPUDevice end -__is_functional(::LuxCPUDevice) = true -__is_loaded(::LuxCPUDevice) = true +_with_device_id(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice() +function _with_device_id(::Type{LuxCPUDevice}, device_id) + @warn "`device_id` is not applicable for `LuxCPUDevice`." maxlog=1 + return LuxCPUDevice() +end + +_with_device_id(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice() +function _with_device_id(::Type{LuxMetalDevice}, device_id) + @warn "`device_id` is not applicable for `LuxMetalDevice`." maxlog=1 + return LuxMetalDevice() +end + +__is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true +__is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true -_get_device_name(::LuxCPUDevice) = "CPU" -_get_device_name(::LuxCUDADevice) = "CUDA" -_get_device_name(::LuxAMDGPUDevice) = "AMDGPU" -_get_device_name(::LuxMetalDevice) = "Metal" +_get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU" +_get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA" +_get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU" +_get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" -_get_triggerpkg_name(::LuxCPUDevice) = "" -_get_triggerpkg_name(::LuxCUDADevice) = "LuxCUDA" -_get_triggerpkg_name(::LuxAMDGPUDevice) = "LuxAMDGPU" -_get_triggerpkg_name(::LuxMetalDevice) = "Metal" +_get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "" +_get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA" +_get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU" +_get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) struct LuxDeviceSelectionException <: Exception end -function Base.showerror(io::IO, e::LuxDeviceSelectionException) +function Base.showerror(io::IO, ::LuxDeviceSelectionException) return print(io, "LuxDeviceSelectionException(No functional GPU device found!!)") end # Order is important here -const GPU_DEVICES = (LuxCUDADevice(), LuxAMDGPUDevice(), LuxMetalDevice()) +const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice) const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) @@ -57,27 +73,22 @@ const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) Resets the selected GPU device. This is useful when automatic GPU selection needs to be run again. """ -function reset_gpu_device!() - return GPU_DEVICE[] = nothing -end +reset_gpu_device!() = (GPU_DEVICE[] = nothing) """ supported_gpu_backends() -> Tuple{String, ...} Return a tuple of supported GPU backends. -::: warning - -This is not the list of functional backends on the system, but rather backends which -`Lux.jl` supports. +!!! warning -::: + This is not the list of functional backends on the system, but rather backends which + `Lux.jl` supports. -::: danger +!!! danger -`Metal.jl` support is **extremely** experimental and most things are not expected to work. - -::: + `Metal.jl` support is **extremely** experimental and most things are not expected to + work. """ supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) @@ -95,14 +106,15 @@ Selects GPU device based on the following criteria: invoked. 4. If nothing works, an error is thrown. """ -function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice +function gpu_device(device_id=nothing; force_gpu_usage::Bool=false)::AbstractLuxDevice if GPU_DEVICE[] !== nothing force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && throw(LuxDeviceSelectionException()) return GPU_DEVICE[] end - device = _get_gpu_device(; force_gpu_usage) + device_type = _get_gpu_device(; force_gpu_usage) + device = _with_device_id(device_type, device_id) GPU_DEVICE[] = device return device @@ -116,25 +128,25 @@ function _get_gpu_device(; force_gpu_usage::Bool) allowed_backends = supported_gpu_backends() idx = findfirst(isequal(backend), allowed_backends) if backend ∉ allowed_backends - @warn """ - `gpu_backend` preference is set to $backend, which is not a valid backend. - Valid backends are $allowed_backends. - Defaulting to automatic GPU Backend selection. - """ maxlog=1 + @warn "`gpu_backend` preference is set to $backend, which is not a valid \ + backend. Valid backends are $allowed_backends. Defaulting to automatic \ + GPU Backend selection." maxlog=1 else @debug "Using GPU backend set in preferences: $backend." device = GPU_DEVICES[idx] if !__is_loaded(device) - @warn """Trying to use backend: $(_get_device_name(device)) but the trigger package $(device.pkgid) is not loaded. - Ignoring the Preferences backend!!! - Please load the package and call this function again to respect the Preferences backend.""" maxlog=1 + @warn "Trying to use backend: $(_get_device_name(device)) but the trigger \ + package $(device.pkgid) is not loaded. Ignoring the Preferences \ + backend!!! Please load the package and call this function again to \ + respect the Preferences backend." maxlog=1 else if __is_functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device else - @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. - Defaulting to automatic GPU Backend selection." maxlog=1 + @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl \ + is not functional. Defaulting to automatic GPU Backend \ + selection." maxlog=1 end end end @@ -150,7 +162,8 @@ function _get_gpu_device(; force_gpu_usage::Bool) end @debug "GPU backend: $(_get_device_name(device)) is not functional." else - @debug "Trigger package for backend ($(_get_device_name(device))): $(_get_trigger_pkgname(device)) not loaded." + @debug "Trigger package for backend ($(_get_device_name(device))): \ + $(_get_trigger_pkgname(device)) not loaded." end end @@ -188,7 +201,8 @@ gpu_backend!() = gpu_backend!("") function gpu_backend!(backend::String) if backend == "" @delete_preferences!("gpu_backend") - @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend." + @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the \ + new backend." return end @@ -250,8 +264,8 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) @warn "Lux layers are stateless and hence don't participate in device \ - transfers. Apply this function on the parameters and states generated \ - using `Lux.setup`." maxlog=1 + transfers. Apply this function on the parameters and states generated \ + using `Lux.setup`." maxlog=1 return NN end end @@ -264,7 +278,7 @@ end Returns the device of the array `x`. Trigger Packages must be loaded for this to return the correct device. """ -get_device(x::AbstractArray) = LuxCPUDevice() +get_device(::AbstractArray) = LuxCPUDevice() # Adapt Interface abstract type AbstractLuxDeviceAdaptor end @@ -274,10 +288,7 @@ struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end -function adapt_storage(::LuxCPUAdaptor, - x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) - return x -end +adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng