From 0310b9d98fd3b468f9d9cb753760bf8d60d89fa5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 12:37:13 -0400 Subject: [PATCH] feat: handle RNGs and undef arrays gracefully --- .buildkite/pipeline.yml | 2 +- ext/MLDataDevicesAMDGPUExt.jl | 2 ++ ext/MLDataDevicesCUDAExt.jl | 4 ++++ ext/MLDataDevicesGPUArraysExt.jl | 5 ++++- src/internal.jl | 9 ++++++--- src/public.jl | 5 ----- test/misc_tests.jl | 4 ++-- 7 files changed, 19 insertions(+), 12 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 2c00e63..a8c37f0 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,6 +1,6 @@ steps: - label: "Triggering Pipelines (Pull Request)" - if: "build.pull_request.base_branch == 'main'" + if: build.branch != "main" && build.tag == null agents: queue: "juliagpu" plugins: diff --git a/ext/MLDataDevicesAMDGPUExt.jl b/ext/MLDataDevicesAMDGPUExt.jl index 4014b2e..ca275b5 100644 --- a/ext/MLDataDevicesAMDGPUExt.jl +++ b/ext/MLDataDevicesAMDGPUExt.jl @@ -49,8 +49,10 @@ function Internal.get_device(x::AMDGPU.AnyROCArray) parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) return Internal.get_device(parent_x) end +Internal.get_device(::AMDGPU.rocRAND.RNG) = AMDGPUDevice(AMDGPU.device()) Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice +Internal.get_device_type(::AMDGPU.rocRAND.RNG) = AMDGPUDevice # Set Device function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) diff --git a/ext/MLDataDevicesCUDAExt.jl b/ext/MLDataDevicesCUDAExt.jl index 3492440..9355b81 100644 --- a/ext/MLDataDevicesCUDAExt.jl +++ b/ext/MLDataDevicesCUDAExt.jl @@ -29,8 +29,12 @@ function Internal.get_device(x::CUDA.AnyCuArray) return MLDataDevices.get_device(parent_x) end Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal)) +Internal.get_device(::CUDA.RNG) = CUDADevice(CUDA.device()) +Internal.get_device(::CUDA.CURAND.RNG) = CUDADevice(CUDA.device()) Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice +Internal.get_device_type(::CUDA.RNG) = CUDADevice +Internal.get_device_type(::CUDA.CURAND.RNG) = CUDADevice # Set Device MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev) diff --git a/ext/MLDataDevicesGPUArraysExt.jl b/ext/MLDataDevicesGPUArraysExt.jl index daf7eb3..a09a386 100644 --- a/ext/MLDataDevicesGPUArraysExt.jl +++ b/ext/MLDataDevicesGPUArraysExt.jl @@ -2,9 +2,12 @@ module MLDataDevicesGPUArraysExt using Adapt: Adapt using GPUArrays: GPUArrays -using MLDataDevices: CPUDevice +using MLDataDevices: Internal, CPUDevice using Random: Random Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng() +Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state) +Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.state) + end diff --git a/src/internal.jl b/src/internal.jl index bcc8cab..387d3af 100644 --- a/src/internal.jl +++ b/src/internal.jl @@ -129,15 +129,16 @@ end for op in (:get_device, :get_device_type) cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice + unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \ - $(cpu_ret_val)..." + $(unknown_ret_val)..." @eval begin function $(op)(x::AbstractArray{T}) where {T} if recursive_array_eltype(T) if any(!isassigned(x, i) for i in eachindex(x)) @warn $(not_assigned_msg) - return $(cpu_ret_val) + return $(unknown_ret_val) end return mapreduce(MLDataDevices.$(op), combine_devices, x) end @@ -160,9 +161,11 @@ for op in (:get_device, :get_device_type) return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, map(Base.Fix1(getfield, f), fieldnames(F))) end + + $(op)(::AbstractRNG) = $(cpu_ret_val) end - for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) + for T in (Number, Val, Symbol, String, Nothing, AbstractRange) @eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing) end end diff --git a/src/public.jl b/src/public.jl index 07deeaa..1dc1646 100644 --- a/src/public.jl +++ b/src/public.jl @@ -232,11 +232,6 @@ const GET_DEVICE_ADMONITIONS = """ !!! note Trigger Packages must be loaded for this to return the correct device. - -!!! warning - - RNG types currently don't participate in device determination. We will remove this - restriction in the future. """ # Query Device from Array diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 1a3093d..f6ea454 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -154,6 +154,6 @@ end @testset "undefined references array" begin x = Matrix{Any}(undef, 10, 10) - @test get_device(x) isa CPUDevice - @test get_device_type(x) <: CPUDevice + @test get_device(x) isa MLDataDevices.UnknownDevice + @test get_device_type(x) <: MLDataDevices.UnknownDevice end