Commit 6293e2a: Handle using MLUtils

avik-pal committed Apr 6, 2024
1 parent 1d0b3b6 commit 6293e2a
Showing 5 changed files with 24 additions and 20 deletions.
8 changes: 4 additions & 4 deletions examples/ImageNet/main.jl
@@ -11,11 +11,11 @@ using Formatting # TODO: Get rid of this

# Distributed Training: NCCL for NVIDIA GPUs and MPI for anything else
if LuxCUDA.functional()
-    DistributedUtils.initialize(Val(:NCCL))
-    backend = DistributedUtils.get_distributed_backend(Val(:NCCL))
+    DistributedUtils.initialize(NCCLBackend)
+    backend = DistributedUtils.get_distributed_backend(NCCLBackend)
else
-    DistributedUtils.initialize(Val(:MPI))
-    backend = DistributedUtils.get_distributed_backend(Val(:MPI))
+    DistributedUtils.initialize(MPIBackend)
+    backend = DistributedUtils.get_distributed_backend(MPIBackend)
end

const local_rank = DistributedUtils.local_rank(backend)
6 changes: 5 additions & 1 deletion ext/LuxMLUtilsExt.jl
@@ -1,7 +1,7 @@
module LuxMLUtilsExt

using Lux: DistributedUtils
-using MLUtils: numobs
+using MLUtils: MLUtils, numobs

function DistributedUtils.__construct_distributed_data_container(
backend::DistributedUtils.AbstractLuxDistributedBackend, data)
@@ -15,4 +15,8 @@ function DistributedUtils.__construct_distributed_data_container(
return DistributedUtils.DistributedDataContainer(data, idxs)
end

+function MLUtils.getobs(dc::DistributedUtils.DistributedDataContainer, idx)
+    return MLUtils.getobs(dc.data, dc.idxs[idx])
+end

end
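The new `getobs` method is what makes the commit title concrete: together with the existing `numobs` support, a `DistributedDataContainer` can now be consumed by `MLUtils`-based pipelines such as `DataLoader`. A minimal usage sketch follows; the toy dataset, the unqualified `MPIBackend`, and the `DistributedDataContainer(backend, data)` constructor call are illustrative assumptions drawn from the wider `DistributedUtils` API rather than from this diff.

```julia
using Lux, MLUtils

# Backend setup as in the ImageNet example above (MPI chosen for illustration).
DistributedUtils.initialize(MPIBackend)
backend = DistributedUtils.get_distributed_backend(MPIBackend)

# Toy dataset: 128 observations with 10 features each.
data = rand(Float32, 10, 128)

# Each rank holds only its shard of the observation indices.
ddata = DistributedUtils.DistributedDataContainer(backend, data)

MLUtils.numobs(ddata)     # number of observations local to this rank
MLUtils.getobs(ddata, 1)  # first observation of the local shard (uses the new method)

# Because `numobs`/`getobs` are defined, the container composes with DataLoader.
for x in DataLoader(ddata; batchsize=16)
    # `x` is a batch drawn from this rank's shard of `data`.
end
```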
4 changes: 2 additions & 2 deletions ext/LuxMPIExt.jl
@@ -7,7 +7,7 @@ using LuxDeviceUtils: AbstractLuxDevice, LuxCUDADevice, LuxAMDGPUDevice, cpu_dev
using MPI: MPI

function DistributedUtils.__initialize(
-    ::Val{:MPI}; cuda_devices=nothing, amdgpu_devices=nothing)
+    ::Type{MPIBackend}; cuda_devices=nothing, amdgpu_devices=nothing)
!MPI.Initialized() && MPI.Init()
DistributedUtils.MPI_Initialized[] = true

@@ -32,7 +32,7 @@ function DistributedUtils.__initialize(
return
end

-DistributedUtils.__get_distributed_backend(::Val{:MPI}) = MPIBackend(MPI.COMM_WORLD)
+DistributedUtils.__get_distributed_backend(::Type{MPIBackend}) = MPIBackend(MPI.COMM_WORLD)

DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm)

4 changes: 2 additions & 2 deletions ext/LuxMPINCCLExt.jl
@@ -7,14 +7,14 @@ using NCCL: NCCL
using Setfield: @set!

function DistributedUtils.__initialize(
-    ::Val{:NCCL}; cuda_devices=nothing, amdgpu_devices=missing)
+    ::Type{NCCLBackend}; cuda_devices=nothing, amdgpu_devices=missing)
DistributedUtils.NCCL_Initialized[] = true
@assert amdgpu_devices===missing "`AMDGPU` is not supported by `NCCL`."
DistributedUtils.__initialize(Val(:MPI); cuda_devices, amdgpu_devices)
return
end

-function DistributedUtils.__get_distributed_backend(::Val{:NCCL})
+function DistributedUtils.__get_distributed_backend(::Type{NCCLBackend})
unique_id = NCCL.UniqueID() # Generate on all ranks to know the type
mpi_backend = DistributedUtils.__get_distributed_backend(Val(:MPI))
buf = [unique_id.internal...]
22 changes: 11 additions & 11 deletions src/distributed/public_api.jl
@@ -13,15 +13,15 @@ const NCCL_Initialized = Ref(false)
const MPI_Initialized = Ref(false)

"""
-    initialized(backend::Val)
+    initialized(backend::Type{<:AbstractLuxDistributedBackend})
Check if the given backend is initialized.
"""
-initialized(::Val{:MPI}) = MPI_Initialized[]
-initialized(::Val{:NCCL}) = NCCL_Initialized[]
+initialized(::Type{<:MPIBackend}) = MPI_Initialized[]
+initialized(::Type{<:NCCLBackend}) = NCCL_Initialized[]

"""
-    initialize(backend::Val; kwargs...)
+    initialize(backend::Type{<:AbstractLuxDistributedBackend}; kwargs...)
Initialize the given backend. Users can supply `cuda_devices` and `amdgpu_devices` to
initialize the backend with the given devices. These can be set to `missing` to prevent
@@ -31,12 +31,12 @@ initialize the backend with the given devices.
Possible values for `backend` are:
-- `Val(:MPI)`: MPI backend for distributed training. Requires `MPI.jl` to be installed.
-- `Val(:NCCL)`: NCCL backend for CUDA distributed training. Requires `CUDA.jl`,
+- `MPIBackend`: MPI backend for distributed training. Requires `MPI.jl` to be installed.
+- `NCCLBackend`: NCCL backend for CUDA distributed training. Requires `CUDA.jl`,
`MPI.jl`, and `NCCL.jl` to be installed. This also wraps `MPI` backend for non-CUDA
communications.
"""
-function initialize(backend::Val; kwargs...)
+function initialize(backend::Type{<:AbstractLuxDistributedBackend}; kwargs...)
initialized(backend) && return
__initialize(backend; kwargs...)
return
@@ -45,20 +45,20 @@ end
function __initialize end

"""
-    get_distributed_backend(backend::Val)
+    get_distributed_backend(backend::Type{<:AbstractLuxDistributedBackend})
Get the distributed backend for the given backend type. Possible values are:
-- `Val(:MPI)`: MPI backend for distributed training. Requires `MPI.jl` to be installed.
-- `Val(:NCCL)`: NCCL backend for CUDA distributed training. Requires `CUDA.jl`,
+- `MPIBackend`: MPI backend for distributed training. Requires `MPI.jl` to be installed.
+- `NCCLBackend`: NCCL backend for CUDA distributed training. Requires `CUDA.jl`,
`MPI.jl`, and `NCCL.jl` to be installed. This also wraps `MPI` backend for non-CUDA
communications.
!!! danger
`initialize(backend; kwargs...)` must be called before calling this function.
"""
-function get_distributed_backend(backend::Val)
+function get_distributed_backend(backend::Type{<:AbstractLuxDistributedBackend})
initialized(backend) ||
error("Backend `$(backend)` is not initialized. Call `DistributedUtils.initialize` first.")
return __get_distributed_backend(backend)
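With these docstring changes, all public entry points dispatch on the backend type instead of a `Val` symbol. A short sketch of the resulting call pattern, mirroring the updated ImageNet example; `total_workers` is assumed from the broader `DistributedUtils` API and is not part of this diff.

```julia
using Lux, LuxCUDA

# Pick NCCL for NVIDIA GPUs, otherwise fall back to MPI.
backend_type = LuxCUDA.functional() ? NCCLBackend : MPIBackend

# `initialize` is a no-op once `initialized(backend_type)` returns true.
DistributedUtils.initialize(backend_type)

# Must be called after `initialize`; otherwise it errors with
# "Backend `...` is not initialized. Call `DistributedUtils.initialize` first."
backend = DistributedUtils.get_distributed_backend(backend_type)

rank   = DistributedUtils.local_rank(backend)      # 0-based rank of this process
nranks = DistributedUtils.total_workers(backend)   # total number of processes
```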
