Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
refactor: move internal functions into separate modules
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 19, 2024
1 parent e759c2c commit 877e7aa
Show file tree
Hide file tree
Showing 16 changed files with 551 additions and 559 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.1"
version = "1.0.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
20 changes: 9 additions & 11 deletions ext/MLDataDevicesAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ module MLDataDevicesAMDGPUExt

using Adapt: Adapt
using AMDGPU: AMDGPU
using MLDataDevices: MLDataDevices, AMDGPUDevice, CPUDevice, reset_gpu_device!
using MLDataDevices: MLDataDevices, Internal, AMDGPUDevice, CPUDevice, reset_gpu_device!
using Random: Random

__init__() = reset_gpu_device!()

# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package.
const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing)

function _check_use_amdgpu!()
function check_use_amdgpu!()
USE_AMD_GPU[] === nothing || return

USE_AMD_GPU[] = AMDGPU.functional()
Expand All @@ -23,14 +23,12 @@ end

MLDataDevices.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true
function MLDataDevices.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool
_check_use_amdgpu!()
check_use_amdgpu!()
return USE_AMD_GPU[]
end

function MLDataDevices._with_device(::Type{AMDGPUDevice}, ::Nothing)
return AMDGPUDevice(nothing)
end
function MLDataDevices._with_device(::Type{AMDGPUDevice}, id::Integer)
Internal.with_device(::Type{AMDGPUDevice}, ::Nothing) = AMDGPUDevice(nothing)
function Internal.with_device(::Type{AMDGPUDevice}, id::Integer)
id > length(AMDGPU.devices()) &&
throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))"))
old_dev = AMDGPU.device()
Expand All @@ -40,19 +38,19 @@ function MLDataDevices._with_device(::Type{AMDGPUDevice}, id::Integer)
return device
end

MLDataDevices._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device)
Internal.get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device)

# Default RNG
MLDataDevices.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng()

# Query Device from Array
function MLDataDevices._get_device(x::AMDGPU.AnyROCArray)
function Internal.get_device(x::AMDGPU.AnyROCArray)
parent_x = parent(x)
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
return MLDataDevices._get_device(parent_x)
return Internal.get_device(parent_x)
end

MLDataDevices._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice
Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice

# Set Device
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)
Expand Down
28 changes: 9 additions & 19 deletions ext/MLDataDevicesCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ module MLDataDevicesCUDAExt

using Adapt: Adapt
using CUDA: CUDA
using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector
using MLDataDevices: MLDataDevices, CUDADevice, CPUDevice
using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector, AbstractCuSparseArray
using MLDataDevices: MLDataDevices, Internal, CUDADevice, CPUDevice
using Random: Random

function MLDataDevices._with_device(::Type{CUDADevice}, id::Integer)
Internal.with_device(::Type{CUDADevice}, ::Nothing) = CUDADevice(nothing)
function Internal.with_device(::Type{CUDADevice}, id::Integer)
id > length(CUDA.devices()) &&
throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))"))
old_dev = CUDA.device()
Expand All @@ -16,34 +17,23 @@ function MLDataDevices._with_device(::Type{CUDADevice}, id::Integer)
return device
end

function MLDataDevices._with_device(::Type{CUDADevice}, ::Nothing)
return CUDADevice(nothing)
end

MLDataDevices._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1
Internal.get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1

# Default RNG
MLDataDevices.default_device_rng(::CUDADevice) = CUDA.default_rng()

# Query Device from Array
function MLDataDevices._get_device(x::CUDA.AnyCuArray)
function Internal.get_device(x::CUDA.AnyCuArray)
parent_x = parent(x)
parent_x === x && return CUDADevice(CUDA.device(x))
return MLDataDevices.get_device(parent_x)
end
function MLDataDevices._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray)
return CUDADevice(CUDA.device(x.nzVal))
end
Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal))

function MLDataDevices._get_device_type(::Union{
<:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray})
return CUDADevice
end
Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice

# Set Device
function MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice)
return CUDA.device!(dev)
end
MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev)
function MLDataDevices.set_device!(::Type{CUDADevice}, id::Integer)
return MLDataDevices.set_device!(CUDADevice, collect(CUDA.devices())[id])
end
Expand Down
10 changes: 4 additions & 6 deletions ext/MLDataDevicesMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@ module MLDataDevicesMetalExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using MLDataDevices: MLDataDevices, MetalDevice, reset_gpu_device!
using MLDataDevices: MLDataDevices, Internal, MetalDevice, reset_gpu_device!
using Metal: Metal, MtlArray

__init__() = reset_gpu_device!()

MLDataDevices.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true
function MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}})
return Metal.functional()
end
MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}}) = Metal.functional()

# Default RNG
MLDataDevices.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray)

# Query Device from Array
MLDataDevices._get_device(::MtlArray) = MetalDevice()
Internal.get_device(::MtlArray) = MetalDevice()

MLDataDevices._get_device_type(::MtlArray) = MetalDevice
Internal.get_device_type(::MtlArray) = MetalDevice

# Device Transfer
## To GPU
Expand Down
10 changes: 5 additions & 5 deletions ext/MLDataDevicesRecursiveArrayToolsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module MLDataDevicesRecursiveArrayToolsExt

using Adapt: Adapt, adapt
using MLDataDevices: MLDataDevices, AbstractDevice
using MLDataDevices: MLDataDevices, Internal, AbstractDevice
using RecursiveArrayTools: VectorOfArray, DiffEqArray

# We want to preserve the structure
Expand All @@ -14,10 +14,10 @@ function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray)
return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t)
end

for op in (:_get_device, :_get_device_type)
@eval function MLDataDevices.$op(x::Union{VectorOfArray, DiffEqArray})
length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing)
return mapreduce(MLDataDevices.$op, MLDataDevices.__combine_devices, x.u)
for op in (:get_device, :get_device_type)
@eval function Internal.$(op)(x::Union{VectorOfArray, DiffEqArray})
length(x.u) == 0 && return $(op == :get_device ? nothing : Nothing)
return mapreduce(Internal.$(op), Internal.combine_devices, x.u)
end
end

Expand Down
12 changes: 4 additions & 8 deletions ext/MLDataDevicesReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
module MLDataDevicesReverseDiffExt

using MLDataDevices: MLDataDevices
using MLDataDevices: Internal
using ReverseDiff: ReverseDiff

for op in (:_get_device, :_get_device_type)
for op in (:get_device, :get_device_type)
@eval begin
function MLDataDevices.$op(x::ReverseDiff.TrackedArray)
return MLDataDevices.$op(ReverseDiff.value(x))
end
function MLDataDevices.$op(x::AbstractArray{<:ReverseDiff.TrackedReal})
return MLDataDevices.$op(ReverseDiff.value.(x))
end
Internal.$(op)(x::ReverseDiff.TrackedArray) = Internal.$(op)(ReverseDiff.value(x))
Internal.$(op)(x::AbstractArray{<:ReverseDiff.TrackedReal}) = Internal.$(op)(ReverseDiff.value.(x))
end
end

Expand Down
14 changes: 5 additions & 9 deletions ext/MLDataDevicesTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
module MLDataDevicesTrackerExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice
using MLDataDevices: Internal, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice
using Tracker: Tracker

for op in (:_get_device, :_get_device_type)
@eval begin
MLDataDevices.$op(x::Tracker.TrackedArray) = MLDataDevices.$op(Tracker.data(x))
function MLDataDevices.$op(x::AbstractArray{<:Tracker.TrackedReal})
return MLDataDevices.$op(Tracker.data.(x))
end
end
for op in (:get_device, :get_device_type)
@eval Internal.$(op)(x::Tracker.TrackedArray) = Internal.$(op)(Tracker.data(x))
@eval Internal.$(op)(x::AbstractArray{<:Tracker.TrackedReal}) = Internal.$(op)(Tracker.data.(x))
end

MLDataDevices.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true
Internal.special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true

for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
Expand Down
6 changes: 3 additions & 3 deletions ext/MLDataDevicesoneAPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module MLDataDevicesoneAPIExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using MLDataDevices: MLDataDevices, oneAPIDevice, reset_gpu_device!
using MLDataDevices: MLDataDevices, Internal, oneAPIDevice, reset_gpu_device!
using oneAPI: oneAPI, oneArray, oneL0

const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}()
Expand All @@ -25,9 +25,9 @@ end
MLDataDevices.default_device_rng(::oneAPIDevice) = GPUArrays.default_rng(oneArray)

# Query Device from Array
MLDataDevices._get_device(::oneArray) = oneAPIDevice()
Internal.get_device(::oneArray) = oneAPIDevice()

MLDataDevices._get_device_type(::oneArray) = oneAPIDevice
Internal.get_device_type(::oneArray) = oneAPIDevice

# Device Transfer
## To GPU
Expand Down
Loading

0 comments on commit 877e7aa

Please sign in to comment.