This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: add fallbacks for unknown objects #87

Merged (4 commits) on Oct 18, 2024
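
In short: `get_device` / `get_device_type` keep returning `nothing` / `Nothing` for device-agnostic values, but objects the package cannot classify now map to a dedicated `UnknownDevice` sentinel instead of erroring or being misreported as CPU. A minimal sketch of the new surface based on the diff below (`Opaque` is a hypothetical type used only for illustration):

```julia
using MLDataDevices

# Device-agnostic values: unchanged behaviour.
get_device(1.5f0)        # nothing
get_device_type(1:10)    # Nothing

# Objects with no recognisable device: new `UnknownDevice` fallback.
struct Opaque end
get_device(Opaque())         # MLDataDevices.UnknownDevice()
get_device_type(Opaque())    # MLDataDevices.UnknownDevice
```
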
Changes from all commits
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
@@ -1,6 +1,6 @@
steps:
- label: "Triggering Pipelines (Pull Request)"
if: "build.pull_request.base_branch == 'main'"
if: build.branch != "main" && build.tag == null
agents:
queue: "juliagpu"
plugins:
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.1"
version = "1.3.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2 changes: 2 additions & 0 deletions ext/MLDataDevicesAMDGPUExt.jl
@@ -49,8 +49,10 @@ function Internal.get_device(x::AMDGPU.AnyROCArray)
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
return Internal.get_device(parent_x)
end
Internal.get_device(::AMDGPU.rocRAND.RNG) = AMDGPUDevice(AMDGPU.device())

Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice
Internal.get_device_type(::AMDGPU.rocRAND.RNG) = AMDGPUDevice

# Set Device
function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)
4 changes: 4 additions & 0 deletions ext/MLDataDevicesCUDAExt.jl
@@ -29,8 +29,12 @@ function Internal.get_device(x::CUDA.AnyCuArray)
return MLDataDevices.get_device(parent_x)
end
Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal))
Internal.get_device(::CUDA.RNG) = CUDADevice(CUDA.device())
Internal.get_device(::CUDA.CURAND.RNG) = CUDADevice(CUDA.device())

Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice
Internal.get_device_type(::CUDA.RNG) = CUDADevice
Internal.get_device_type(::CUDA.CURAND.RNG) = CUDADevice

# Set Device
MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev)
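
Analogous to the `rocRAND.RNG` methods added in the AMDGPU extension above, these methods let device queries resolve GPU RNG objects to their owning device. A rough usage sketch, assuming a functional CUDA setup:

```julia
using CUDA, MLDataDevices, Random

if MLDataDevices.functional(CUDADevice)
    rng = CUDA.default_rng()               # CUDA.RNG
    get_device(rng)                        # CUDADevice(...)
    get_device_type(rng) <: CUDADevice     # true
end

# CPU RNGs remain device agnostic:
get_device(Random.default_rng()) === nothing   # true
```
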
11 changes: 8 additions & 3 deletions ext/MLDataDevicesChainRulesCoreExt.jl
@@ -3,15 +3,20 @@ module MLDataDevicesChainRulesCoreExt
using Adapt: Adapt
using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable

using MLDataDevices: AbstractDevice, get_device, get_device_type
using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type

@non_differentiable get_device(::Any)
@non_differentiable get_device_type(::Any)

function ChainRulesCore.rrule(
::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
∇adapt_storage = let x = x
Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ))
∇adapt_storage = let dev = get_device(x)
if dev === nothing || dev isa UnknownDevice
@warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
Δ -> (NoTangent(), NoTangent(), Δ)
else
Δ -> (NoTangent(), NoTangent(), dev(Δ))
end
end
return Adapt.adapt_storage(to, x), ∇adapt_storage
end
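
The pullback now captures the input's device once: if it is a real device, the cotangent is shipped back to it; if it is `nothing` or `UnknownDevice()`, it warns once and passes the cotangent through unchanged rather than erroring. A small sketch calling the rule directly (any AD backend that consumes ChainRules, e.g. Zygote, hits the same path):

```julia
using Adapt, ChainRulesCore, MLDataDevices

x = rand(Float32, 4)
y, pb = ChainRulesCore.rrule(Adapt.adapt_storage, cpu_device(), x)
_, _, dx = pb(ones(Float32, 4))
get_device(dx)    # CPUDevice(): cotangent moved back to `x`'s device

# If `get_device(x)` had returned `nothing` or `UnknownDevice()` (e.g. an
# array wrapper the package cannot classify), the pullback would warn once
# and return the cotangent as-is.
```
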
5 changes: 4 additions & 1 deletion ext/MLDataDevicesGPUArraysExt.jl
@@ -2,9 +2,12 @@ module MLDataDevicesGPUArraysExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using MLDataDevices: CPUDevice
using MLDataDevices: Internal, CPUDevice
using Random: Random

Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng()

Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state)
Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.state)

end
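
The generic `GPUArrays.RNG` simply forwards the query to its `state` array, so any backend that relies on the GPUArrays RNG (e.g. Metal, oneAPI) gets device detection without backend-specific methods. A sketch assuming a functional Metal setup:

```julia
using GPUArrays, Metal, MLDataDevices

if MLDataDevices.functional(MetalDevice)
    rng = GPUArrays.default_rng(MtlArray)   # GPUArrays.RNG with an MtlArray state
    get_device(rng)                         # MetalDevice(), derived from rng.state
    get_device_type(rng) <: MetalDevice     # true
end
```
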
39 changes: 32 additions & 7 deletions src/internal.jl
@@ -5,8 +5,8 @@ using Preferences: load_preference
using Random: AbstractRNG

using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends,
GPU_DEVICES, loaded, functional
MetalDevice, oneAPIDevice, XLADevice, UnknownDevice,
supported_gpu_backends, GPU_DEVICES, loaded, functional

for dev in (CPUDevice, MetalDevice, oneAPIDevice)
msg = "`device_id` is not applicable for `$dev`."
@@ -107,31 +107,38 @@ special_aos(::AbstractArray) = false
recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number)

combine_devices(::Nothing, ::Nothing) = nothing
combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
combine_devices(::Nothing, dev::AbstractDevice) = dev
combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(dev::AbstractDevice, ::Nothing) = dev
combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T
function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice)
dev1 == dev2 && return dev1
dev1 isa UnknownDevice && return dev2
dev2 isa UnknownDevice && return dev1
throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2)."))
end

combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T
combine_devices(::Type{T}, ::Type{UnknownDevice}) where {T <: AbstractDevice} = T
combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{UnknownDevice}, ::Type{T}) where {T <: AbstractDevice} = T
combine_devices(::Type{UnknownDevice}, ::Type{UnknownDevice}) = UnknownDevice
function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice})
throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2)."))
end

for op in (:get_device, :get_device_type)
cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice
unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice
not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \
$(cpu_ret_val)..."
$(unknown_ret_val)..."

@eval begin
function $(op)(x::AbstractArray{T}) where {T}
if recursive_array_eltype(T)
if any(!isassigned(x, i) for i in eachindex(x))
@warn $(not_assigned_msg)
return $(cpu_ret_val)
return $(unknown_ret_val)
end
return mapreduce(MLDataDevices.$(op), combine_devices, x)
end
@@ -147,13 +154,31 @@ for op in (:get_device, :get_device_type)
length(x) == 0 && return $(op == :get_device ? nothing : Nothing)
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x))
end

function $(op)(f::F) where {F <: Function}
Base.issingletontype(F) &&
return $(op == :get_device ? UnknownDevice() : UnknownDevice)
return unrolled_mapreduce(MLDataDevices.$(op), combine_devices,
map(Base.Fix1(getfield, f), fieldnames(F)))
end
end

for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing)
end
end

get_device(_) = UnknownDevice()
get_device_type(_) = UnknownDevice

fast_structure(::AbstractArray) = true
fast_structure(::Union{Tuple, NamedTuple}) = true
for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@eval fast_structure(::$(T)) = true
end
fast_structure(::Function) = true
fast_structure(_) = false

function unrolled_mapreduce(f::F, op::O, itr) where {F, O}
return unrolled_mapreduce(f, op, itr, static_length(itr))
end
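
Two behavioural consequences of these `internal.jl` changes, sketched below (`Opaque` and `make_loss` are illustrative, not part of the package): `UnknownDevice` acts as a neutral element in `combine_devices`, so an unknown leaf no longer prevents a container from resolving to a known device, and non-singleton functions are inspected field by field, so closures report the device of their captures.

```julia
using MLDataDevices

struct Opaque end   # hypothetical leaf with no known device

# Unknown leaves are absorbed by known devices when combining:
nt = (; weight=rand(Float32, 2), extra=Opaque())
get_device(nt)         # CPUDevice()
get_device_type(nt)    # CPUDevice

# Closures report the device of their captured fields;
# singleton functions fall back to `UnknownDevice`.
make_loss(w) = x -> sum(abs2, w .* x)
get_device(make_loss(rand(Float32, 4)))   # CPUDevice()
get_device(tanh)                          # MLDataDevices.UnknownDevice()
```
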
22 changes: 16 additions & 6 deletions src/public.jl
@@ -12,6 +12,9 @@ struct oneAPIDevice <: AbstractGPUDevice end
# TODO: Later we might want to add the client field here?
struct XLADevice <: AbstractAcceleratorDevice end

# Fallback for when we don't know the device type
struct UnknownDevice <: AbstractDevice end

"""
functional(x::AbstractDevice) -> Bool
functional(::Type{<:AbstractDevice}) -> Bool
@@ -229,11 +232,6 @@ const GET_DEVICE_ADMONITIONS = """
!!! note

Trigger Packages must be loaded for this to return the correct device.

!!! warning

RNG types currently don't participate in device determination. We will remove this
restriction in the future.
"""

# Query Device from Array
@@ -245,6 +243,12 @@ device. Otherwise, we throw an error. If the object is device agnostic, we return

$(GET_DEVICE_ADMONITIONS)

## Special Returned Values

- `nothing` -- denotes that the object is device agnostic. For example, scalars, abstract
  ranges, etc.
- `UnknownDevice()` -- denotes that the device is unknown.

See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch
based on device type.
"""
@@ -258,6 +262,12 @@ itself. This value is often a compile time constant and is recommended to be used instead
of [`get_device`](@ref) when defining dispatches based on the device type.

$(GET_DEVICE_ADMONITIONS)

## Special Returned Values

- `Nothing` -- denotes that the object is device agnostic. For example, scalars, abstract
  ranges, etc.
- `UnknownDevice` -- denotes that the device type is unknown.
"""
function get_device_type end

@@ -345,7 +355,7 @@ end

for op in (:get_device, :get_device_type)
@eval function $(op)(x)
hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x)
Internal.fast_structure(x) && return Internal.$(op)(x)
return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x))
end
end
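
Because `get_device_type` is usually a compile-time constant, the special return types documented above can be dispatched on directly. A hedged sketch of how downstream code might branch on them (`select_backend` is illustrative, not part of the package):

```julia
using MLDataDevices
using MLDataDevices: AbstractGPUDevice, UnknownDevice

select_backend(::Type{<:AbstractGPUDevice}) = :gpu_path
select_backend(::Type{CPUDevice})           = :cpu_path
select_backend(::Type{Nothing})             = :cpu_path   # device-agnostic input
select_backend(::Type{UnknownDevice})       = :cpu_path   # conservative fallback

select_backend(get_device_type(rand(Float32, 8)))   # :cpu_path
select_backend(get_device_type(1:4))                # :cpu_path (via Nothing)
```
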
29 changes: 29 additions & 0 deletions test/amdgpu_tests.jl
@@ -57,7 +57,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa AMDGPUDevice
@test get_device_type(ps_xpu.rng_default) <: AMDGPUDevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(AMDGPUDevice)
@test ps_xpu.one_elem isa ROCArray
@@ -83,7 +87,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(AMDGPUDevice)
@test ps_cpu.one_elem isa Array
@@ -118,6 +126,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(AMDGPUDevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> AMDGPUDevice()
@test get_device(ff_xpu) isa AMDGPUDevice
@test get_device_type(ff_xpu) <: AMDGPUDevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapped Arrays" begin
if MLDataDevices.functional(AMDGPUDevice)
x = rand(10, 10) |> AMDGPUDevice()
29 changes: 29 additions & 0 deletions test/cuda_tests.jl
@@ -56,7 +56,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa CUDADevice
@test get_device_type(ps_xpu.rng_default) <: CUDADevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(CUDADevice)
@test ps_xpu.one_elem isa CuArray
@@ -82,7 +86,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(CUDADevice)
@test ps_cpu.one_elem isa Array
@@ -143,6 +151,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(CUDADevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> CUDADevice()
@test get_device(ff_xpu) isa CUDADevice
@test get_device_type(ff_xpu) <: CUDADevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapped Arrays" begin
if MLDataDevices.functional(CUDADevice)
x = rand(10, 10) |> CUDADevice()
29 changes: 29 additions & 0 deletions test/metal_tests.jl
@@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@test get_device(ps_xpu.rng_default) isa MetalDevice
@test get_device_type(ps_xpu.rng_default) <: MetalDevice
@test ps_xpu.rng == ps.rng
@test get_device(ps_xpu.rng) === nothing
@test get_device_type(ps_xpu.rng) <: Nothing

if MLDataDevices.functional(MetalDevice)
@test ps_xpu.one_elem isa MtlArray
@@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
@test get_device(ps_cpu.rng_default) === nothing
@test get_device_type(ps_cpu.rng_default) <: Nothing
@test ps_cpu.rng == ps.rng
@test get_device(ps_cpu.rng) === nothing
@test get_device_type(ps_cpu.rng) <: Nothing

if MLDataDevices.functional(MetalDevice)
@test ps_cpu.one_elem isa Array
@@ -107,6 +115,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(MetalDevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> MetalDevice()
@test get_device(ff_xpu) isa MetalDevice
@test get_device_type(ff_xpu) <: MetalDevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapper Arrays" begin
if MLDataDevices.functional(MetalDevice)
x = rand(Float32, 10, 10) |> MetalDevice()
4 changes: 2 additions & 2 deletions test/misc_tests.jl
@@ -154,6 +154,6 @@ end
@testset "undefined references array" begin
x = Matrix{Any}(undef, 10, 10)

@test get_device(x) isa CPUDevice
@test get_device_type(x) <: CPUDevice
@test get_device(x) isa MLDataDevices.UnknownDevice
@test get_device_type(x) <: MLDataDevices.UnknownDevice
end