Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Nov 6, 2024
1 parent 4c3714f commit bfb5021
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 14 deletions.
10 changes: 3 additions & 7 deletions lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,10 @@ using ChainRules: OneElement

Adapt.adapt_storage(::CPUDevice, x::OneElement) = x

for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice)
for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice,
CUDADevice{Nothing}, AMDGPUDevice{Nothing})
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x))
end

for Dev in (CUDADevice, AMDGPUDevice)
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev{Nothing}, x::OneElement) = Adapt.adapt(to, collect(x))
@eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x))
end

end
10 changes: 3 additions & 7 deletions lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,10 @@ using Zygote: OneElement

Adapt.adapt_storage(::CPUDevice, x::OneElement) = x

for Dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice)
for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice,
CUDADevice{Nothing}, AMDGPUDevice{Nothing})
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev, x::OneElement) = Adapt.adapt(to, collect(x))
end

for Dev in (CUDADevice, AMDGPUDevice)
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice{Nothing}, ::AbstractArray)
@eval Adapt.adapt_storage(to::$Dev{Nothing}, x::OneElement) = Adapt.adapt(to, collect(x))
@eval Adapt.adapt_storage(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x))
end

end
Expand Down

0 comments on commit bfb5021

Please sign in to comment.