Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Add setup for multi-GPU setups
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 24, 2024
1 parent 6137fb7 commit 35030c1
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 55 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -20,6 +19,7 @@ LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
Expand All @@ -29,6 +29,7 @@ LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"]
LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
LuxDeviceUtilsSparseArraysExt = "SparseArrays"
LuxDeviceUtilsZygoteExt = "Zygote"

[compat]
Expand Down
15 changes: 13 additions & 2 deletions ext/LuxDeviceUtilsLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,19 @@ import Adapt: adapt_storage, adapt

__init__() = reset_gpu_device!()

LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true
LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional()
# NOTE(review): the CUDA and Metal extensions spell this signature as
# `Union{Device, Type{<:Device}}`. The original `<:Type{LuxAMDGPUDevice}` form is
# not equivalent: `Type` is invariant, so it does not match parameterized types
# such as `Type{LuxAMDGPUDevice{Int}}`, while `Type{<:LuxAMDGPUDevice}` matches
# both the bare `LuxAMDGPUDevice` and any `LuxAMDGPUDevice{ID}`. Align with the
# sibling extensions so dispatch covers the parametric device type too.

# Loaded: trivially true once this extension module is active.
LuxDeviceUtils.__is_loaded(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = true

# Functional: defer the runtime check (driver/ROCm availability) to LuxAMDGPU.
function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}})
    return LuxAMDGPU.functional()
end

# Build a `LuxAMDGPUDevice` bound to the requested device id WITHOUT permanently
# switching the globally active AMDGPU device: the previously active device is
# restored before returning.
#
# `device_id === nothing` falls back to id 0 (the first device). Ids are 0-based
# at this API's surface, while `AMDGPU.devices()` is a 1-based collection — hence
# the `+ 1` adjustments.
function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, device_id)
# Normalize: treat `nothing` as device 0. (`ifelse` evaluates both arms, which is
# harmless here since both are plain values.)
id = ifelse(device_id === nothing, 0, device_id)
# Remember the currently active device as a 0-based id so it can be restored.
old_id = AMDGPU.device_id(AMDGPU.device()) - 1
# Temporarily activate the requested device...
AMDGPU.device!(AMDGPU.devices()[id + 1])
# ...capture it in the wrapper. NOTE(review): the positional constructor stores the
# AMDGPU device *object* in the `device_id` field — confirm downstream consumers
# expect the object rather than the integer id.
device = LuxAMDGPUDevice(AMDGPU.device())
# ...and restore the previously active device before returning.
AMDGPU.device!(AMDGPU.devices()[old_id + 1])
return device
end

# Default RNG
LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()
Expand Down
15 changes: 13 additions & 2 deletions ext/LuxDeviceUtilsLuxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,19 @@ import Adapt: adapt_storage, adapt

__init__() = reset_gpu_device!()

LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true
LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional()
# Loaded: trivially true once this extension module is active.
LuxDeviceUtils.__is_loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true

# Functional: defer the runtime check (driver/toolkit availability) to LuxCUDA.
LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) =
    LuxCUDA.functional()

# Build a `LuxCUDADevice` bound to the requested device id WITHOUT permanently
# switching the globally active CUDA device: the previously active device is
# restored before returning. `device_id === nothing` falls back to id 0.
function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, device_id)
# Normalize: treat `nothing` as device 0.
id = ifelse(device_id === nothing, 0, device_id)
# Remember the currently active device via its integer handle for restoration.
old_id = CUDA.device().handle
# Temporarily activate the requested device (CUDA.device! accepts a 0-based id).
CUDA.device!(id)
# Capture it in the wrapper. NOTE(review): the positional constructor stores the
# CUDA device *object* in the `device_id` field — confirm downstream consumers
# expect the object rather than the integer id.
device = LuxCUDADevice(CUDA.device())
# Restore the previously active device before returning.
CUDA.device!(old_id)
return device
end

# Default RNG
LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng()
Expand Down
6 changes: 4 additions & 2 deletions ext/LuxDeviceUtilsMetalGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import Adapt: adapt_storage, adapt

__init__() = reset_gpu_device!()

LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true
LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional()
# Loaded: trivially true once this extension module is active.
LuxDeviceUtils.__is_loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true

# Functional: defer the runtime check to Metal.jl.
LuxDeviceUtils.__is_functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) =
    Metal.functional()

# Default RNG
LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray)
Expand Down
9 changes: 9 additions & 0 deletions ext/LuxDeviceUtilsSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Extension module loaded when SparseArrays is present: keeps sparse arrays intact
# when moving data to the CPU with `LuxCPUAdaptor`, instead of letting the generic
# `adapt(Array, x)` fallback densify them.
module LuxDeviceUtilsSparseArraysExt

import Adapt: adapt_storage
import LuxDeviceUtils: LuxCPUAdaptor
import SparseArrays: AbstractSparseArray

# Sparse arrays already live on the CPU; return them unchanged (no copy, no
# densification).
adapt_storage(::LuxCPUAdaptor, x::AbstractSparseArray) = x

end
107 changes: 59 additions & 48 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module LuxDeviceUtils
import PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
using ChainRulesCore, Functors, LuxCore, Preferences, Random
import Adapt: adapt, adapt_storage
import ChainRulesCore as CRC
end
Expand All @@ -21,33 +21,49 @@ __is_functional(::AbstractLuxDevice) = false
__is_loaded(::AbstractLuxDevice) = false

struct LuxCPUDevice <: AbstractLuxDevice end
struct LuxCUDADevice <: AbstractLuxGPUDevice end
struct LuxAMDGPUDevice <: AbstractLuxGPUDevice end
# GPU device handles. `device_id` selects which physical accelerator to target:
# `nothing` (the default) means "use the backend's default device"; any other value
# is interpreted by `_with_device_id` in the corresponding backend extension.
@kwdef struct LuxCUDADevice{ID} <: AbstractLuxGPUDevice
device_id::ID = nothing
end
@kwdef struct LuxAMDGPUDevice{ID} <: AbstractLuxGPUDevice
device_id::ID = nothing
end
# Metal has no per-device selection here, so this device carries no id.
struct LuxMetalDevice <: AbstractLuxGPUDevice end

__is_functional(::LuxCPUDevice) = true
__is_loaded(::LuxCPUDevice) = true
# `device_id` has no meaning for the CPU: accept `nothing` silently, warn (at most
# once, via maxlog) on any other value, and always return the singleton device.
_with_device_id(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice()
function _with_device_id(::Type{LuxCPUDevice}, device_id)
@warn "`device_id` is not applicable for `LuxCPUDevice`." maxlog=1
return LuxCPUDevice()
end

# Same contract for Metal: no per-device selection is supported.
_with_device_id(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice()
function _with_device_id(::Type{LuxMetalDevice}, device_id)
@warn "`device_id` is not applicable for `LuxMetalDevice`." maxlog=1
return LuxMetalDevice()
end

# The CPU backend needs no trigger package, so it is always loaded and functional.
# Signatures accept both device instances and device types (matching the GPU ones).
__is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true
__is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true

_get_device_name(::LuxCPUDevice) = "CPU"
_get_device_name(::LuxCUDADevice) = "CUDA"
_get_device_name(::LuxAMDGPUDevice) = "AMDGPU"
_get_device_name(::LuxMetalDevice) = "Metal"
# Human-readable backend names; dispatch accepts device instances and device types.
_get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU"
_get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA"
_get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU"
_get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal"

_get_triggerpkg_name(::LuxCPUDevice) = ""
_get_triggerpkg_name(::LuxCUDADevice) = "LuxCUDA"
_get_triggerpkg_name(::LuxAMDGPUDevice) = "LuxAMDGPU"
_get_triggerpkg_name(::LuxMetalDevice) = "Metal"
# Name of the package whose loading activates each backend extension ("" = none,
# i.e. the CPU backend is always available without a trigger package).
_get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = ""
_get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA"
_get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU"
_get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal"

# Print the device's type name (e.g. `LuxCUDADevice`). Base `nameof` has methods
# for `Module`/`Function`/`DataType`/`UnionAll` but NOT for arbitrary struct
# instances, so `nameof(dev)` would throw a MethodError; look up the type first.
# `nameof(typeof(dev))` also drops type parameters such as `ID` in
# `LuxCUDADevice{ID}`, keeping the printed form short.
Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(typeof(dev)))

struct LuxDeviceSelectionException <: Exception end

function Base.showerror(io::IO, e::LuxDeviceSelectionException)
# Render the device-selection failure as a single diagnostic line.
Base.showerror(io::IO, ::LuxDeviceSelectionException) =
    print(io, "LuxDeviceSelectionException(No functional GPU device found!!)")

# Order is important here
const GPU_DEVICES = (LuxCUDADevice(), LuxAMDGPUDevice(), LuxMetalDevice())
# Stored as *types* (not instances); `_with_device_id` instantiates the chosen one.
const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)

const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing)

Expand All @@ -57,27 +73,22 @@ const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing)
Resets the selected GPU device. This is useful when automatic GPU selection needs to be
run again.
"""
function reset_gpu_device!()
return GPU_DEVICE[] = nothing
end
# Clearing the cache forces the next `gpu_device()` call to redo backend selection.
reset_gpu_device!() = (GPU_DEVICE[] = nothing)

"""
    supported_gpu_backends() -> Tuple{String, ...}

Return a tuple of supported GPU backends.

!!! warning

    This is not the list of functional backends on the system, but rather backends which
    `Lux.jl` supports.

!!! danger

    `Metal.jl` support is **extremely** experimental and most things are not expected to
    work.
"""
supported_gpu_backends() = map(_get_device_name, GPU_DEVICES)

Expand All @@ -95,14 +106,15 @@ Selects GPU device based on the following criteria:
invoked.
4. If nothing works, an error is thrown.
"""
function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice
function gpu_device(device_id=nothing; force_gpu_usage::Bool=false)::AbstractLuxDevice
if GPU_DEVICE[] !== nothing
force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
throw(LuxDeviceSelectionException())
return GPU_DEVICE[]
end

device = _get_gpu_device(; force_gpu_usage)
device_type = _get_gpu_device(; force_gpu_usage)
device = _with_device_id(device_type, device_id)
GPU_DEVICE[] = device

return device
Expand All @@ -116,25 +128,25 @@ function _get_gpu_device(; force_gpu_usage::Bool)
allowed_backends = supported_gpu_backends()
idx = findfirst(isequal(backend), allowed_backends)
if backend allowed_backends
@warn """
`gpu_backend` preference is set to $backend, which is not a valid backend.
Valid backends are $allowed_backends.
Defaulting to automatic GPU Backend selection.
""" maxlog=1
@warn "`gpu_backend` preference is set to $backend, which is not a valid \
backend. Valid backends are $allowed_backends. Defaulting to automatic \
GPU Backend selection." maxlog=1
else
@debug "Using GPU backend set in preferences: $backend."
device = GPU_DEVICES[idx]
if !__is_loaded(device)
@warn """Trying to use backend: $(_get_device_name(device)) but the trigger package $(device.pkgid) is not loaded.
Ignoring the Preferences backend!!!
Please load the package and call this function again to respect the Preferences backend.""" maxlog=1
@warn "Trying to use backend: $(_get_device_name(device)) but the trigger \
package $(device.pkgid) is not loaded. Ignoring the Preferences \
backend!!! Please load the package and call this function again to \
respect the Preferences backend." maxlog=1
else
if __is_functional(device)
@debug "Using GPU backend: $(_get_device_name(device))."
return device
else
@warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional.
Defaulting to automatic GPU Backend selection." maxlog=1
@warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl \
is not functional. Defaulting to automatic GPU Backend \
selection." maxlog=1
end
end
end
Expand All @@ -150,7 +162,8 @@ function _get_gpu_device(; force_gpu_usage::Bool)
end
@debug "GPU backend: $(_get_device_name(device)) is not functional."
else
@debug "Trigger package for backend ($(_get_device_name(device))): $(_get_trigger_pkgname(device)) not loaded."
@debug "Trigger package for backend ($(_get_device_name(device))): \
$(_get_trigger_pkgname(device)) not loaded."
end
end

Expand Down Expand Up @@ -188,7 +201,8 @@ gpu_backend!() = gpu_backend!("")
function gpu_backend!(backend::String)
if backend == ""
@delete_preferences!("gpu_backend")
@info "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend."
@info "Deleted the local preference for `gpu_backend`. Restart Julia to use the \
new backend."
return
end

Expand Down Expand Up @@ -250,8 +264,8 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
end
function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer)
@warn "Lux layers are stateless and hence don't participate in device \
transfers. Apply this function on the parameters and states generated \
using `Lux.setup`." maxlog=1
transfers. Apply this function on the parameters and states generated \
using `Lux.setup`." maxlog=1
return NN
end
end
Expand All @@ -264,7 +278,7 @@ end
Returns the device of the array `x`. Trigger Packages must be loaded for this to return the
correct device.
"""
get_device(x::AbstractArray) = LuxCPUDevice()
# Fallback method: any array type not claimed by a loaded GPU extension is assumed
# to be CPU-resident.
get_device(::AbstractArray) = LuxCPUDevice()

# Adapt Interface
abstract type AbstractLuxDeviceAdaptor end
Expand All @@ -274,10 +288,7 @@ struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end
# Singleton adaptors handed to `Adapt.adapt_storage` to move leaf arrays to the
# corresponding backend's memory.
struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end
struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end

function adapt_storage(::LuxCPUAdaptor,
x::Union{AbstractRange, SparseArrays.AbstractSparseArray})
return x
end
# Ranges are cheap lazy objects; keep them as-is rather than materializing a Vector.
adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x
# Generic fallback: materialize any other array as a plain CPU `Array`.
adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
# RNGs carry no device-resident data; pass them through untouched.
adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng

Expand Down

0 comments on commit 35030c1

Please sign in to comment.