feat: add support for OpenCL devices
avik-pal committed Sep 20, 2024
1 parent 3d49642 commit b232d1b
Showing 6 changed files with 57 additions and 13 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.1.1"
version = "1.2.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -17,6 +17,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -33,6 +34,7 @@ MLDataDevicesFillArraysExt = "FillArrays"
MLDataDevicesGPUArraysExt = "GPUArrays"
MLDataDevicesMLUtilsExt = "MLUtils"
MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
+MLDataDevicesOpenCLExt = ["GPUArrays", "OpenCL"]
MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
MLDataDevicesReverseDiffExt = "ReverseDiff"
MLDataDevicesSparseArraysExt = "SparseArrays"
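The new `MLDataDevicesOpenCLExt` entry is a package extension, so it is compiled and loaded only once both of its weak dependencies are in the session. A minimal sketch of how a user would activate it (assuming OpenCL.jl and GPUArrays.jl are installed):

    using MLDataDevices        # base package; no OpenCL methods defined yet
    using GPUArrays, OpenCL    # loading both weak deps triggers MLDataDevicesOpenCLExt

    MLDataDevices.loaded(OpenCLDevice)   # true once the extension is active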
1 change: 1 addition & 0 deletions README.md
@@ -21,6 +21,7 @@ Currently we provide support for the following backends:
2. `AMDGPU.jl` for AMD ROCM GPUs.
3. `Metal.jl` for Apple Metal GPUs. **(Experimental)**
4. `oneAPI.jl` for Intel GPUs. **(Experimental)**
+5. `OpenCL.jl` for OpenCL devices. **(Extremely Experimental)**

## Updating to v1.0

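Because `OpenCLDevice` is appended last to the backend priority list (see `GPU_DEVICES` in src/public.jl below), `gpu_device()` should only pick it when no other functional backend is available. A hedged usage sketch, assuming an OpenCL device is visible and no higher-priority backend is loaded:

    using MLDataDevices, GPUArrays, OpenCL

    dev = gpu_device()             # walks GPU_DEVICES in order; OpenCL is checked last
    x = dev(rand(Float32, 3, 3))   # a CLArray when OpenCLDevice was selected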
36 changes: 36 additions & 0 deletions ext/MLDataDevicesOpenCLExt.jl
@@ -0,0 +1,36 @@
module MLDataDevicesOpenCLExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, Internal, OpenCLDevice, reset_gpu_device!
using GPUArrays: GPUArrays
using OpenCL: OpenCL, CLArray

__init__() = reset_gpu_device!()

MLDataDevices.loaded(::Union{OpenCLDevice, Type{<:OpenCLDevice}}) = true
# TODO: Check if OpenCL can provide a `functional` function.
MLDataDevices.functional(::Union{OpenCLDevice, Type{<:OpenCLDevice}}) = true

# Default RNG
MLDataDevices.default_device_rng(::OpenCLDevice) = GPUArrays.default_rng(CLArray)

# Query Device from Array
Internal.get_device(::CLArray) = OpenCLDevice()

Internal.get_device_type(::CLArray) = OpenCLDevice

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{OpenCLDevice}, ::AbstractArray)
# TODO: Implement this
@warn "Support for `unsafe_free!` for OpenCL is not implemented yet. This is a no-op." maxlog=1
return
end

# Device Transfer
Adapt.adapt_storage(::OpenCLDevice, x::AbstractArray) = CLArray(x)

# TODO: Eventually we want to do robust device management, since it is possible users
# change the device after creating the OpenCLDevice and that might cause unwanted
# behavior.
end
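A sketch of how the methods above compose, assuming a functional OpenCL runtime (every name used here is defined or exported in this commit):

    using MLDataDevices, GPUArrays, OpenCL

    dev = OpenCLDevice()
    x = dev(rand(Float32, 4, 4))   # Adapt.adapt_storage(::OpenCLDevice, ::AbstractArray) -> CLArray
    get_device(x)                  # OpenCLDevice()
    get_device_type(x)             # OpenCLDevice
    default_device_rng(dev)        # GPUArrays.default_rng(CLArray)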
4 changes: 3 additions & 1 deletion src/MLDataDevices.jl
@@ -6,6 +6,7 @@ using Preferences: @delete_preferences!, @load_preference, @set_preferences!
using Random: AbstractRNG, Random

abstract type AbstractDevice <: Function end
+abstract type AbstractCPUDevice <: AbstractDevice end
abstract type AbstractGPUDevice <: AbstractDevice end

include("public.jl")
@@ -16,7 +17,8 @@ export gpu_backend!, supported_gpu_backends, reset_gpu_device!
export default_device_rng
export gpu_device, cpu_device

-export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
+export CPUDevice
+export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice
export get_device, get_device_type

export DeviceIterator
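The new `AbstractCPUDevice` layer splits the hierarchy so downstream code can dispatch on CPU-vs-GPU without enumerating concrete device types. An illustrative sketch; `on_gpu` is a hypothetical helper, not part of the package, and the abstract types are imported explicitly since the diff does not show them being exported:

    using MLDataDevices
    using MLDataDevices: AbstractCPUDevice, AbstractGPUDevice

    # Hypothetical helper, shown only to illustrate dispatch on the new abstract types.
    on_gpu(::AbstractCPUDevice) = false
    on_gpu(::AbstractGPUDevice) = true

    on_gpu(CPUDevice())      # false
    on_gpu(OpenCLDevice())   # true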
14 changes: 8 additions & 6 deletions src/internal.jl
@@ -5,10 +5,10 @@ using Preferences: load_preference
using Random: AbstractRNG

using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
-MetalDevice, oneAPIDevice, supported_gpu_backends, GPU_DEVICES,
-loaded, functional
+MetalDevice, oneAPIDevice, OpenCLDevice, supported_gpu_backends,
+GPU_DEVICES, loaded, functional

-for dev in (CPUDevice, MetalDevice, oneAPIDevice)
+for dev in (CPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice)
msg = "`device_id` is not applicable for `$dev`."
@eval begin
with_device(::Type{$dev}, ::Nothing) = $dev()
@@ -19,7 +19,7 @@ end
end
end

-for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
+for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :OpenCL)
tpkg = name === :CPU ? "" : string(name)
ldev = Symbol(name, :Device)
@eval begin
@@ -28,7 +28,8 @@ for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
end
end

-for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice)
+for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing},
+MetalDevice, oneAPIDevice, OpenCLDevice)
@eval get_device_id(::$(T)) = nothing
end

@@ -93,7 +94,8 @@ function get_gpu_device(; force_gpu_usage::Bool)
a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support.
b. `AMDGPU.jl` for AMD GPU ROCM Support.
c. `Metal.jl` for Apple Metal GPU Support. (Experimental)
-d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1
+d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)
+e. `OpenCL.jl` for OpenCL Support. (Extremely Experimental)""" maxlog=1
return CPUDevice
end

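For backends without device ids — now including `OpenCLDevice` — the visible part of the first loop defines `with_device(::Type{OpenCLDevice}, ::Nothing) = OpenCLDevice()`; the collapsed lines presumably use `msg` to reject concrete ids. A sketch of the expected behavior (note that `Internal` is not public API):

    using MLDataDevices

    MLDataDevices.Internal.with_device(OpenCLDevice, nothing)   # OpenCLDevice()
    # With a concrete id, the collapsed branch presumably complains with
    # "`device_id` is not applicable for `OpenCLDevice`." before returning or erroring.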
11 changes: 6 additions & 5 deletions src/public.jl
@@ -1,4 +1,4 @@
-struct CPUDevice <: AbstractDevice end
+struct CPUDevice <: AbstractCPUDevice end
@kwdef struct CUDADevice{D} <: AbstractGPUDevice
device::D = nothing
end
@@ -7,6 +7,7 @@ end
end
struct MetalDevice <: AbstractGPUDevice end
struct oneAPIDevice <: AbstractGPUDevice end
+struct OpenCLDevice <: AbstractGPUDevice end

"""
functional(x::AbstractDevice) -> Bool
@@ -36,7 +37,7 @@ loaded(x) = false
loaded(::Union{CPUDevice, Type{<:CPUDevice}}) = true

# Order is important here
-const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
+const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice)

const GPU_DEVICE = Ref{Union{Nothing, AbstractDevice}}(nothing)

@@ -292,7 +293,7 @@ end
# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
# For all other types we rely on fmap which means we lose type stability.
# For Lux, models typically use only these 3 data structures, so we should be mostly fine.
-for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
+for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :OpenCL)
ldev = Symbol(dev, :Device)
@eval begin
function (D::$(ldev))(x::AbstractArray{T}) where {T}
@@ -318,7 +319,7 @@ end
Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x)
Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng

-for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice)
+for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, OpenCLDevice)
@eval begin
function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG)
return default_device_rng(to)
@@ -330,6 +331,6 @@ end
Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x
# Prevent Ambiguity
for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
-CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
+CUDADevice{Nothing}, MetalDevice, oneAPIDevice, OpenCLDevice)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end
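Per the last hunk, ranges stay lazy on the CPU but are materialized before GPU transfer to resolve the adapt ambiguity; with this commit `OpenCLDevice` follows the same rule as the other GPU backends. A sketch, assuming a working OpenCL device:

    using MLDataDevices, GPUArrays, OpenCL

    cpu_device()(1:10)        # 1:10 — ranges pass through unchanged on the CPU
    OpenCLDevice()(1:10)      # CLArray holding collect(1:10) — collected before transfer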
