Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: handle RNGs and undef arrays gracefully
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 18, 2024
1 parent 6c3a4a7 commit 7360263
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
steps:
- label: "Triggering Pipelines (Pull Request)"
if: "build.pull_request.base_branch == 'main'"
if: build.branch != "main" && build.tag == null
agents:
queue: "juliagpu"
plugins:
Expand Down
2 changes: 2 additions & 0 deletions ext/MLDataDevicesAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ function Internal.get_device(x::AMDGPU.AnyROCArray)
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
return Internal.get_device(parent_x)
end
Internal.get_device(::AMDGPU.rocRAND.RNG) = AMDGPUDevice(AMDGPU.device())

Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice
Internal.get_device_type(::AMDGPU.rocRAND.RNG) = AMDGPUDevice

# Set Device
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)
Expand Down
2 changes: 2 additions & 0 deletions ext/MLDataDevicesCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ function Internal.get_device(x::CUDA.AnyCuArray)
return MLDataDevices.get_device(parent_x)
end
Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal))
Internal.get_device(::CUDA.RNG) = CUDADevice(CUDA.device())

Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice
Internal.get_device_type(::CUDA.RNG) = CUDADevice

# Set Device
MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev)
Expand Down
5 changes: 4 additions & 1 deletion ext/MLDataDevicesGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ module MLDataDevicesGPUArraysExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using MLDataDevices: CPUDevice
using MLDataDevices: Internal, CPUDevice
using Random: Random

Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng()

Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state)
Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.state)

end
9 changes: 6 additions & 3 deletions src/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,16 @@ end

for op in (:get_device, :get_device_type)
cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice
unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice
not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \
$(cpu_ret_val)..."
$(unknown_ret_val)..."

@eval begin
function $(op)(x::AbstractArray{T}) where {T}
if recursive_array_eltype(T)
if any(!isassigned(x, i) for i in eachindex(x))
@warn $(not_assigned_msg)
return $(cpu_ret_val)
return $(unknown_ret_val)
end
return mapreduce(MLDataDevices.$(op), combine_devices, x)
end
Expand All @@ -160,9 +161,11 @@ for op in (:get_device, :get_device_type)
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices,
map(Base.Fix1(getfield, f), fieldnames(F)))
end

$(op)(::AbstractRNG) = $(cpu_ret_val)
end

for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
for T in (Number, Val, Symbol, String, Nothing, AbstractRange)
@eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing)
end
end
Expand Down
5 changes: 0 additions & 5 deletions src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,6 @@ const GET_DEVICE_ADMONITIONS = """
!!! note
Trigger Packages must be loaded for this to return the correct device.
!!! warning
RNG types currently don't participate in device determination. We will remove this
restriction in the future.
"""

# Query Device from Array
Expand Down
4 changes: 2 additions & 2 deletions test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,6 @@ end
@testset "undefined references array" begin
x = Matrix{Any}(undef, 10, 10)

@test get_device(x) isa CPUDevice
@test get_device_type(x) <: CPUDevice
@test get_device(x) isa MLDataDevices.UnknownDevice
@test get_device_type(x) <: MLDataDevices.UnknownDevice
end

0 comments on commit 7360263

Please sign in to comment.