Skip to content

Commit

Permalink
fix: partial revert of the ancestor patch
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 4, 2024
1 parent 7a35bdc commit 72b48e1
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 28 deletions.
5 changes: 4 additions & 1 deletion lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
module MLDataDevicesChainRulesExt

using Adapt: Adapt
using MLDataDevices: CPUDevice, AbstractDevice
using MLDataDevices: CPUDevice, AbstractDevice, Internal
using ChainRules: OneElement

Adapt.adapt_structure(::CPUDevice, x::OneElement) = x
Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x))

Internal.get_device(::OneElement) = CPUDevice()
Internal.get_device_type(::OneElement) = CPUDevice

end
5 changes: 4 additions & 1 deletion lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ module MLDataDevicesFillArraysExt

using Adapt: Adapt
using FillArrays: AbstractFill
using MLDataDevices: CPUDevice, AbstractDevice
using MLDataDevices: CPUDevice, AbstractDevice, Internal

Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x
Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x))

Internal.get_device(::AbstractFill) = CPUDevice()
Internal.get_device_type(::AbstractFill) = CPUDevice

end
5 changes: 4 additions & 1 deletion lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
module MLDataDevicesZygoteExt

using Adapt: Adapt
using MLDataDevices: CPUDevice, AbstractDevice
using MLDataDevices: CPUDevice, AbstractDevice, Internal
using Zygote: OneElement

Adapt.adapt_structure(::CPUDevice, x::OneElement) = x
Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x))

Internal.get_device(::OneElement) = CPUDevice()
Internal.get_device_type(::OneElement) = CPUDevice

end
19 changes: 6 additions & 13 deletions lib/MLDataDevices/src/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,6 @@ function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice})
throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2)."))
end

function ancestor(x)
if applicable(parent, x)
px = parent(x)
px === x && return x
return ancestor(px)
end
return x
end

for op in (:get_device, :get_device_type)
cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice
unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice
Expand All @@ -155,7 +146,12 @@ for op in (:get_device, :get_device_type)
end
return mapreduce(MLDataDevices.$(op), combine_devices, x)
end
return $(op)(ancestor(x))
if hasmethod(parent, Tuple{typeof(x)})
parent_x = parent(x)
parent_x === x && return $(cpu_ret_val)
return $(op)(parent_x)
end
return $(cpu_ret_val)
end

function $(op)(x::Union{Tuple, NamedTuple})
Expand All @@ -175,9 +171,6 @@ end
get_device(_) = UnknownDevice()
get_device_type(_) = UnknownDevice

get_device(::Array) = CPUDevice()
get_device_type(::Array) = CPUDevice

fast_structure(::AbstractArray) = true
fast_structure(::Union{Tuple, NamedTuple}) = true
for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
Expand Down
16 changes: 4 additions & 12 deletions lib/MLDataDevices/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,7 @@ function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDe
end

# Dispatches for Different Data Structures
for dev in (CPUDevice, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice)
@eval (D::$(dev))(x) = Functors.fmap(Base.Fix1(Adapt.adapt, D), x; exclude=isleaf)
end
(D::AbstractDevice)(x) = Functors.fmap(Base.Fix1(Adapt.adapt, D), x; exclude=isleaf)

for op in (:get_device, :get_device_type)
@eval function $(op)(x)
Expand All @@ -377,15 +375,9 @@ for op in (:get_device, :get_device_type)
end

# Adapt Interface
Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x)
Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng

for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice)
@eval begin
Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) = default_device_rng(to)
Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng
end
end
Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Array(x)
Adapt.adapt_storage(to::AbstractDevice, ::Random.TaskLocalRNG) = default_device_rng(to)
Adapt.adapt_storage(::AbstractDevice, rng::AbstractRNG) = rng

"""
isleaf(x) -> Bool
Expand Down

0 comments on commit 72b48e1

Please sign in to comment.