From 6c3a4a77139a1bd2dca33284c37bc4f41e6337ad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 11:35:08 -0400 Subject: [PATCH 1/4] feat: add fallbacks for unknown objects --- Project.toml | 2 +- ext/MLDataDevicesChainRulesCoreExt.jl | 11 ++++++--- src/internal.jl | 34 +++++++++++++++++++++++---- src/public.jl | 17 +++++++++++++- 4 files changed, 54 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 1cb1875..41f3134 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.2.1" +version = "1.3.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/MLDataDevicesChainRulesCoreExt.jl b/ext/MLDataDevicesChainRulesCoreExt.jl index c6b9560..6a770b8 100644 --- a/ext/MLDataDevicesChainRulesCoreExt.jl +++ b/ext/MLDataDevicesChainRulesCoreExt.jl @@ -3,15 +3,20 @@ module MLDataDevicesChainRulesCoreExt using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable -using MLDataDevices: AbstractDevice, get_device, get_device_type +using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type @non_differentiable get_device(::Any) @non_differentiable get_device_type(::Any) function ChainRulesCore.rrule( ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) - ∇adapt_storage = let x = x - Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + ∇adapt_storage = let dev = get_device(x) + if dev === nothing || dev isa UnknownDevice + @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." 
maxlog=1 + Δ -> (NoTangent(), NoTangent(), Δ) + else + Δ -> (NoTangent(), NoTangent(), dev(Δ)) + end end return Adapt.adapt_storage(to, x), ∇adapt_storage end diff --git a/src/internal.jl b/src/internal.jl index e13b716..bcc8cab 100644 --- a/src/internal.jl +++ b/src/internal.jl @@ -5,8 +5,8 @@ using Preferences: load_preference using Random: AbstractRNG using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, - MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends, - GPU_DEVICES, loaded, functional + MetalDevice, oneAPIDevice, XLADevice, UnknownDevice, + supported_gpu_backends, GPU_DEVICES, loaded, functional for dev in (CPUDevice, MetalDevice, oneAPIDevice) msg = "`device_id` is not applicable for `$dev`." @@ -107,16 +107,22 @@ special_aos(::AbstractArray) = false recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) combine_devices(::Nothing, ::Nothing) = nothing -combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing combine_devices(::Nothing, dev::AbstractDevice) = dev -combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T combine_devices(dev::AbstractDevice, ::Nothing) = dev -combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice) dev1 == dev2 && return dev1 + dev1 isa UnknownDevice && return dev2 + dev2 isa UnknownDevice && return dev1 throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) end + +combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T +combine_devices(::Type{T}, ::Type{UnknownDevice}) where {T <: AbstractDevice} = T +combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{UnknownDevice}, ::Type{T}) where {T <: AbstractDevice} = T 
+combine_devices(::Type{UnknownDevice}, ::Type{UnknownDevice}) = UnknownDevice function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice}) throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) end @@ -147,6 +153,13 @@ for op in (:get_device, :get_device_type) length(x) == 0 && return $(op == :get_device ? nothing : Nothing) return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x)) end + + function $(op)(f::F) where {F <: Function} + Base.issingletontype(F) && + return $(op == :get_device ? UnknownDevice() : UnknownDevice) + return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, + map(Base.Fix1(getfield, f), fieldnames(F))) + end end for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) @@ -154,6 +167,17 @@ for op in (:get_device, :get_device_type) end end +get_device(_) = UnknownDevice() +get_device_type(_) = UnknownDevice + +fast_structure(::AbstractArray) = true +fast_structure(::Union{Tuple, NamedTuple}) = true +for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) + @eval fast_structure(::$(T)) = true +end +fast_structure(::Function) = true +fast_structure(_) = false + function unrolled_mapreduce(f::F, op::O, itr) where {F, O} return unrolled_mapreduce(f, op, itr, static_length(itr)) end diff --git a/src/public.jl b/src/public.jl index 178c6f9..07deeaa 100644 --- a/src/public.jl +++ b/src/public.jl @@ -12,6 +12,9 @@ struct oneAPIDevice <: AbstractGPUDevice end # TODO: Later we might want to add the client field here? struct XLADevice <: AbstractAcceleratorDevice end +# Fallback for when we don't know the device type +struct UnknownDevice <: AbstractDevice end + """ functional(x::AbstractDevice) -> Bool functional(::Type{<:AbstractDevice}) -> Bool @@ -245,6 +248,12 @@ device. Otherwise, we throw an error. 
If the object is device agnostic, we return `nothing`. $(GET_DEVICE_ADMONITIONS) +## Special Returned Values + + - `nothing` -- denotes that the object is device agnostic. For example, scalar, abstract + range, etc. + - `UnknownDevice()` -- denotes that the device type is unknown See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch based on device type. """ function get_device end @@ -258,6 +267,12 @@ itself. This value is often a compile time constant and is recommended to be use of [`get_device`](@ref) where ever defining dispatches based on the device type. $(GET_DEVICE_ADMONITIONS) + +## Special Returned Values + + - `Nothing` -- denotes that the object is device agnostic. For example, scalar, abstract + range, etc. + - `UnknownDevice` -- denotes that the device type is unknown """ function get_device_type end @@ -345,7 +360,7 @@ end for op in (:get_device, :get_device_type) @eval function $(op)(x) - hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x) + Internal.fast_structure(x) && return Internal.$(op)(x) return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) end end From 46fccbb8c3cdfdaddf3ab2cd56d0fd6627dabaef Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 12:37:13 -0400 Subject: [PATCH 2/4] feat: handle RNGs and undef arrays gracefully --- .buildkite/pipeline.yml | 2 +- ext/MLDataDevicesAMDGPUExt.jl | 2 ++ ext/MLDataDevicesCUDAExt.jl | 4 ++++ ext/MLDataDevicesGPUArraysExt.jl | 5 ++++- src/internal.jl | 5 +++-- src/public.jl | 5 ----- test/misc_tests.jl | 4 ++-- 7 files changed, 16 insertions(+), 11 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 2c00e63..a8c37f0 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,6 +1,6 @@ steps: - label: "Triggering Pipelines (Pull Request)" - if: "build.pull_request.base_branch == 'main'" + if: build.branch != "main" && build.tag == null agents: queue: "juliagpu" plugins: diff --git 
a/ext/MLDataDevicesAMDGPUExt.jl b/ext/MLDataDevicesAMDGPUExt.jl index 4014b2e..ca275b5 100644 --- a/ext/MLDataDevicesAMDGPUExt.jl +++ b/ext/MLDataDevicesAMDGPUExt.jl @@ -49,8 +49,10 @@ function Internal.get_device(x::AMDGPU.AnyROCArray) parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) return Internal.get_device(parent_x) end +Internal.get_device(::AMDGPU.rocRAND.RNG) = AMDGPUDevice(AMDGPU.device()) Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice +Internal.get_device_type(::AMDGPU.rocRAND.RNG) = AMDGPUDevice # Set Device function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) diff --git a/ext/MLDataDevicesCUDAExt.jl b/ext/MLDataDevicesCUDAExt.jl index 3492440..9355b81 100644 --- a/ext/MLDataDevicesCUDAExt.jl +++ b/ext/MLDataDevicesCUDAExt.jl @@ -29,8 +29,12 @@ function Internal.get_device(x::CUDA.AnyCuArray) return MLDataDevices.get_device(parent_x) end Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal)) +Internal.get_device(::CUDA.RNG) = CUDADevice(CUDA.device()) +Internal.get_device(::CUDA.CURAND.RNG) = CUDADevice(CUDA.device()) Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice +Internal.get_device_type(::CUDA.RNG) = CUDADevice +Internal.get_device_type(::CUDA.CURAND.RNG) = CUDADevice # Set Device MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev) diff --git a/ext/MLDataDevicesGPUArraysExt.jl b/ext/MLDataDevicesGPUArraysExt.jl index daf7eb3..a09a386 100644 --- a/ext/MLDataDevicesGPUArraysExt.jl +++ b/ext/MLDataDevicesGPUArraysExt.jl @@ -2,9 +2,12 @@ module MLDataDevicesGPUArraysExt using Adapt: Adapt using GPUArrays: GPUArrays -using MLDataDevices: CPUDevice +using MLDataDevices: Internal, CPUDevice using Random: Random Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng() +Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state) +Internal.get_device_type(rng::GPUArrays.RNG) = 
Internal.get_device_type(rng.state) + end diff --git a/src/internal.jl b/src/internal.jl index bcc8cab..5da37ac 100644 --- a/src/internal.jl +++ b/src/internal.jl @@ -129,15 +129,16 @@ end for op in (:get_device, :get_device_type) cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice + unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \ - $(cpu_ret_val)..." + $(unknown_ret_val)..." @eval begin function $(op)(x::AbstractArray{T}) where {T} if recursive_array_eltype(T) if any(!isassigned(x, i) for i in eachindex(x)) @warn $(not_assigned_msg) - return $(cpu_ret_val) + return $(unknown_ret_val) end return mapreduce(MLDataDevices.$(op), combine_devices, x) end diff --git a/src/public.jl b/src/public.jl index 07deeaa..1dc1646 100644 --- a/src/public.jl +++ b/src/public.jl @@ -232,11 +232,6 @@ const GET_DEVICE_ADMONITIONS = """ !!! note Trigger Packages must be loaded for this to return the correct device. - -!!! warning - - RNG types currently don't participate in device determination. We will remove this - restriction in the future. 
""" # Query Device from Array diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 1a3093d..f6ea454 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -154,6 +154,6 @@ end @testset "undefined references array" begin x = Matrix{Any}(undef, 10, 10) - @test get_device(x) isa CPUDevice - @test get_device_type(x) <: CPUDevice + @test get_device(x) isa MLDataDevices.UnknownDevice + @test get_device_type(x) <: MLDataDevices.UnknownDevice end From b3bef2292a495b9371b137fde4fa399e708c74c3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 13:05:23 -0400 Subject: [PATCH 3/4] test: RNG movement --- test/amdgpu_tests.jl | 8 ++++++++ test/cuda_tests.jl | 8 ++++++++ test/metal_tests.jl | 8 ++++++++ test/oneapi_tests.jl | 8 ++++++++ test/xla_tests.jl | 8 ++++++++ 5 files changed, 40 insertions(+) diff --git a/test/amdgpu_tests.jl b/test/amdgpu_tests.jl index 67edff4..f29c279 100644 --- a/test/amdgpu_tests.jl +++ b/test/amdgpu_tests.jl @@ -57,7 +57,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa AMDGPUDevice + @test get_device_type(ps_xpu.rng_default) <: AMDGPUDevice @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(AMDGPUDevice) @test ps_xpu.one_elem isa ROCArray @@ -83,7 +87,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(AMDGPUDevice) @test ps_cpu.one_elem isa Array diff --git a/test/cuda_tests.jl b/test/cuda_tests.jl index 92c0a27..bd8a234 100644 --- a/test/cuda_tests.jl +++ b/test/cuda_tests.jl @@ 
-56,7 +56,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa CUDADevice + @test get_device_type(ps_xpu.rng_default) <: CUDADevice @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(CUDADevice) @test ps_xpu.one_elem isa CuArray @@ -82,7 +86,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(CUDADevice) @test ps_cpu.one_elem isa Array diff --git a/test/metal_tests.jl b/test/metal_tests.jl index 789fa49..a214ebd 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa MetalDevice + @test get_device_type(ps_xpu.rng_default) <: MetalDevice @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(MetalDevice) @test ps_xpu.one_elem isa MtlArray @@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(MetalDevice) @test ps_cpu.one_elem isa Array diff --git a/test/oneapi_tests.jl b/test/oneapi_tests.jl index 7731c43..d1720f0 
100644 --- a/test/oneapi_tests.jl +++ b/test/oneapi_tests.jl @@ -55,7 +55,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa oneAPIDevice + @test get_device_type(ps_xpu.rng_default) <: oneAPIDevice @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(oneAPIDevice) @test ps_xpu.one_elem isa oneArray @@ -81,7 +85,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(oneAPIDevice) @test ps_cpu.one_elem isa Array diff --git a/test/xla_tests.jl b/test/xla_tests.jl index 81ae929..138727f 100644 --- a/test/xla_tests.jl +++ b/test/xla_tests.jl @@ -54,7 +54,11 @@ using FillArrays, Zygote # Extensions @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) === nothing + @test get_device_type(ps_xpu.rng_default) <: Nothing @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing if MLDataDevices.functional(XLADevice) @test ps_xpu.one_elem isa Reactant.RArray @@ -80,7 +84,11 @@ using FillArrays, Zygote # Extensions @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing if MLDataDevices.functional(XLADevice) @test ps_cpu.one_elem isa Array From 
37429fbaf93ac3235be00724eaf07b23f7a3d98d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 13:25:57 -0400 Subject: [PATCH 4/4] test: functions and closures --- test/amdgpu_tests.jl | 21 +++++++++++++++++++++ test/cuda_tests.jl | 21 +++++++++++++++++++++ test/metal_tests.jl | 21 +++++++++++++++++++++ test/oneapi_tests.jl | 21 +++++++++++++++++++++ test/xla_tests.jl | 21 +++++++++++++++++++++ 5 files changed, 105 insertions(+) diff --git a/test/amdgpu_tests.jl b/test/amdgpu_tests.jl index f29c279..41a8797 100644 --- a/test/amdgpu_tests.jl +++ b/test/amdgpu_tests.jl @@ -126,6 +126,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(AMDGPUDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> AMDGPUDevice() + @test get_device(ff_xpu) isa AMDGPUDevice + @test get_device_type(ff_xpu) <: AMDGPUDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapped Arrays" begin if MLDataDevices.functional(AMDGPUDevice) x = rand(10, 10) |> AMDGPUDevice() diff --git a/test/cuda_tests.jl b/test/cuda_tests.jl index bd8a234..1f95831 100644 --- a/test/cuda_tests.jl +++ b/test/cuda_tests.jl @@ -151,6 +151,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(CUDADevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> CUDADevice() + @test get_device(ff_xpu) isa CUDADevice + @test get_device_type(ff_xpu) <: CUDADevice + + 
ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapped Arrays" begin if MLDataDevices.functional(CUDADevice) x = rand(10, 10) |> CUDADevice() diff --git a/test/metal_tests.jl b/test/metal_tests.jl index a214ebd..aeb596a 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -115,6 +115,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(MetalDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> MetalDevice() + @test get_device(ff_xpu) isa MetalDevice + @test get_device_type(ff_xpu) <: MetalDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapper Arrays" begin if MLDataDevices.functional(MetalDevice) x = rand(Float32, 10, 10) |> MetalDevice() diff --git a/test/oneapi_tests.jl b/test/oneapi_tests.jl index d1720f0..8bb6026 100644 --- a/test/oneapi_tests.jl +++ b/test/oneapi_tests.jl @@ -115,6 +115,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(oneAPIDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> oneAPIDevice() + @test get_device(ff_xpu) isa oneAPIDevice + @test get_device_type(ff_xpu) <: oneAPIDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapper Arrays" begin if 
MLDataDevices.functional(oneAPIDevice) x = rand(10, 10) |> oneAPIDevice() diff --git a/test/xla_tests.jl b/test/xla_tests.jl index 138727f..21466bd 100644 --- a/test/xla_tests.jl +++ b/test/xla_tests.jl @@ -114,6 +114,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(XLADevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> XLADevice() + @test get_device(ff_xpu) isa XLADevice + @test get_device_type(ff_xpu) <: XLADevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapped Arrays" begin if MLDataDevices.functional(XLADevice) x = rand(10, 10) |> XLADevice()