diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl index 0dab66bf37..5515864403 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl @@ -1,10 +1,13 @@ module MLDataDevicesChainRulesExt using Adapt: Adapt -using MLDataDevices: CPUDevice, AbstractDevice +using MLDataDevices: CPUDevice, AbstractDevice, Internal using ChainRules: OneElement Adapt.adapt_structure(::CPUDevice, x::OneElement) = x Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x)) +Internal.get_device(::OneElement) = CPUDevice() +Internal.get_device_type(::OneElement) = CPUDevice + end diff --git a/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl index 0abe03b116..2b9c1875c7 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl @@ -2,9 +2,12 @@ module MLDataDevicesFillArraysExt using Adapt: Adapt using FillArrays: AbstractFill -using MLDataDevices: CPUDevice, AbstractDevice +using MLDataDevices: CPUDevice, AbstractDevice, Internal Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) +Internal.get_device(::AbstractFill) = CPUDevice() +Internal.get_device_type(::AbstractFill) = CPUDevice + end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl index ce45d8f1eb..14f4de1280 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -1,10 +1,13 @@ module MLDataDevicesZygoteExt using Adapt: Adapt -using MLDataDevices: CPUDevice, AbstractDevice +using MLDataDevices: CPUDevice, AbstractDevice, Internal using Zygote: OneElement Adapt.adapt_structure(::CPUDevice, x::OneElement) = x Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x)) +Internal.get_device(::OneElement) = CPUDevice() +Internal.get_device_type(::OneElement) = CPUDevice + end diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index cffab87c3b..0dd936b4a4 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -125,15 +125,6 @@ function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice}) throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) end -function ancestor(x) - if applicable(parent, x) - px = parent(x) - px === x && return x - return ancestor(px) - end - return x -end - for op in (:get_device, :get_device_type) cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice @@ -155,7 +146,12 @@ for op in (:get_device, :get_device_type) end return mapreduce(MLDataDevices.$(op), combine_devices, x) end - return $(op)(ancestor(x)) + if hasmethod(parent, Tuple{typeof(x)}) + parent_x = parent(x) + parent_x === x && return $(cpu_ret_val) + return $(op)(parent_x) + end + return $(cpu_ret_val) end function $(op)(x::Union{Tuple, NamedTuple}) @@ -175,9 +171,6 @@ end get_device(_) = UnknownDevice() get_device_type(_) = UnknownDevice -get_device(::Array) = CPUDevice() -get_device_type(::Array) = CPUDevice - fast_structure(::AbstractArray) = true fast_structure(::Union{Tuple, NamedTuple}) = true for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index f8471a0c7b..6b87ef422d 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -363,9 +363,7 @@ function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDe end # Dispatches for Different Data Structures -for dev in (CPUDevice, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice) - @eval (D::$(dev))(x) = Functors.fmap(Base.Fix1(Adapt.adapt, D), x; exclude=isleaf) -end +(D::AbstractDevice)(x) = Functors.fmap(Base.Fix1(Adapt.adapt, D), x; exclude=isleaf) for op in (:get_device, :get_device_type) @eval function $(op)(x) @@ -377,15 +375,9 @@ for op in (:get_device, :get_device_type) end # Adapt Interface -Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) -Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng - -for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice) - @eval begin - Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) = default_device_rng(to) - Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng - end -end +Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Array(x) +Adapt.adapt_storage(to::AbstractDevice, ::Random.TaskLocalRNG) = default_device_rng(to) +Adapt.adapt_storage(::AbstractDevice, rng::AbstractRNG) = rng """ isleaf(x) -> Bool