This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: add DeviceIterator (and support parallel Device DataLoader)
avik-pal committed Aug 28, 2024
1 parent 50df4fa commit 1b9d552
Showing 8 changed files with 102 additions and 6 deletions.
5 changes: 4 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.3"
version = "1.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -16,6 +16,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -30,6 +31,7 @@ MLDataDevicesAMDGPUExt = "AMDGPU"
MLDataDevicesCUDAExt = "CUDA"
MLDataDevicesFillArraysExt = "FillArrays"
MLDataDevicesGPUArraysExt = "GPUArrays"
MLDataDevicesMLUtilsExt = "MLUtils"
MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
MLDataDevicesReverseDiffExt = "ReverseDiff"
@@ -47,6 +49,7 @@ ChainRulesCore = "1.23"
FillArrays = "1"
Functors = "0.4.8"
GPUArrays = "10"
MLUtils = "0.4"
Metal = "1"
Preferences = "1.4"
Random = "1.10"
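MLUtils enters as a weak dependency, so the new MLDataDevicesMLUtilsExt extension above only loads once MLUtils itself is loaded into the session. A minimal sketch of that behaviour (assuming Julia >= 1.9, where package extensions exist, and a fresh session; Base.get_extension is used here purely to illustrate when the extension becomes available):

using MLDataDevices
Base.get_extension(MLDataDevices, :MLDataDevicesMLUtilsExt)  # nothing: extension not loaded yet

using MLUtils                                                # loading the weak dependency...
Base.get_extension(MLDataDevices, :MLDataDevicesMLUtilsExt)  # ...returns the extension module
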
1 change: 0 additions & 1 deletion ext/MLDataDevicesAMDGPUExt.jl
@@ -71,7 +71,6 @@ function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray)
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray)
    old_dev = AMDGPU.device() # remember the current device
60 changes: 59 additions & 1 deletion ext/MLDataDevicesMLUtilsExt.jl
@@ -1,5 +1,63 @@
module MLDataDevicesMLUtilsExt

using MLUtils: DataLoader
using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUDevice,
                     CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, DeviceIterator,
                     Internal
using MLUtils: MLUtils, DataLoader

for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
    ldev = Symbol(dev, :Device)
    @eval function (D::$(ldev))(dataloader::DataLoader)
        if dataloader.parallel
            if dataloader.buffer
                @warn "Using `buffer=true` for parallel DataLoader with automatic device \
                       transfer is currently not implemented. Ignoring `buffer=true`."
            end
            return ParallelDeviceDataLoader(D, dataloader)
        end
        return DeviceIterator(D, dataloader)
    end
end

# Parallel DataLoader that does the device transfer in the same task
struct ParallelDeviceDataLoader{D <: AbstractDevice, DL <: DataLoader} <:
       AbstractDeviceIterator{D, DL}
    dev::D
    iterator::DL
end

# Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
function Base.iterate(c::ParallelDeviceDataLoader)
    data = MLUtils.ObsView(c.iterator.data)

    data = c.iterator.shuffle ? MLUtils.shuffleobs(c.iterator.rng, data) : data
    data = if c.iterator.batchsize > 0
        MLUtils.BatchView(
            data; c.iterator.batchsize, c.iterator.partial, c.iterator.collate)
    else
        data
    end

    iter = eachobsparallel(c.dev, data)
    item = iterate(iter)
    item === nothing && return nothing
    dev_batch, next_state = item
    return dev_batch, ((iter, next_state), dev_batch)
end

function Base.iterate(::ParallelDeviceDataLoader, ((iter, state), prev_batch))
    item = iterate(iter, state)
    item === nothing && return nothing
    dev_batch, next_state = item
    Internal.unsafe_free!(prev_batch) # free the previous batch
    return dev_batch, ((iter, next_state), dev_batch)
end

function eachobsparallel(dev::AbstractDevice, data)
    return MLUtils.Loader(1:MLUtils.numobs(data)) do ch, i
        obs = MLUtils.getobs(data, i)
        put!(ch, dev(obs))
    end
end

end
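
To illustrate what this new extension enables, here is a hedged usage sketch. The data, batch size, and the availability of a functional GPU backend (e.g. CUDA.jl) are assumptions for illustration, and parallel=true additionally assumes Julia was started with multiple threads:

using MLDataDevices, MLUtils

X = rand(Float32, 64, 1024)   # hypothetical features
Y = rand(Float32, 1, 1024)    # hypothetical targets
loader = DataLoader((X, Y); batchsize=128, shuffle=true, parallel=true)

dev = gpu_device()            # e.g. CUDADevice() if CUDA.jl is loaded and functional, else CPUDevice()
for (x, y) in dev(loader)     # dispatches to the functor defined in this file
    # batches arrive already on `dev`; on GPU devices the previous batch is freed on the next step
    sum(abs2, x)
end

With parallel=true, the call dev(loader) returns a ParallelDeviceDataLoader, so the device transfer happens in the same task that loads each batch; with parallel=false it returns a plain DeviceIterator wrapping the DataLoader.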
1 change: 0 additions & 1 deletion ext/MLDataDevicesMetalExt.jl
@@ -25,7 +25,6 @@ function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray)
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x)

end
1 change: 0 additions & 1 deletion ext/MLDataDevicesoneAPIExt.jl
@@ -36,7 +36,6 @@ function Internal.unsafe_free_internal!(::Type{oneAPIDevice}, x::AbstractArray)
end

# Device Transfer
## To GPU
for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
    @eval function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray{$(T1)})
        if !SUPPORTS_FP64[oneAPI.device()]
3 changes: 3 additions & 0 deletions src/MLDataDevices.jl
@@ -12,6 +12,7 @@ abstract type AbstractDevice <: Function end
abstract type AbstractGPUDevice <: AbstractDevice end

include("public.jl")
include("iterator.jl")
include("internal.jl")

export gpu_backend!, supported_gpu_backends, reset_gpu_device!
@@ -21,4 +22,6 @@ export gpu_device, cpu_device
export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
export get_device, get_device_type

export DeviceIterator

end
35 changes: 35 additions & 0 deletions src/iterator.jl
@@ -0,0 +1,35 @@
abstract type AbstractDeviceIterator{D <: AbstractDevice, I} end

function Base.IteratorSize(::Type{AbstractDeviceIterator{D, I}}) where {D, I}
    return Base.IteratorSize(I)
end
Base.length(c::AbstractDeviceIterator) = length(c.iterator)
Base.axes(c::AbstractDeviceIterator) = axes(c.iterator)

function Base.IteratorEltype(::Type{AbstractDeviceIterator{D, I}}) where {D, I}
    return Base.IteratorEltype(I)
end
Base.eltype(c::AbstractDeviceIterator) = eltype(c.iterator)

# This is based on CuIterator but generalized to work with any device
struct DeviceIterator{D, I} <: AbstractDeviceIterator{D, I}
    dev::D
    iterator::I
end

function Base.iterate(c::DeviceIterator)
    item = iterate(c.iterator)
    item === nothing && return nothing
    batch, next_state = item
    dev_batch = c.dev(batch)
    return dev_batch, (next_state, dev_batch)
end

function Base.iterate(c::DeviceIterator, (state, prev_batch))
    item = iterate(c.iterator, state)
    item === nothing && return nothing
    batch, next_state = item
    Internal.unsafe_free!(prev_batch) # free the previous batch
    dev_batch = c.dev(batch)
    return dev_batch, (next_state, dev_batch)
end
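
A minimal sketch of DeviceIterator used directly, independent of MLUtils. The toy batches and the choice of cpu_device() are assumptions for illustration; with a GPU backend loaded, the corresponding device would be passed instead:

using MLDataDevices

batches = [(rand(Float32, 4, 8), rand(Float32, 1, 8)) for _ in 1:3]  # any iterator of batches
it = DeviceIterator(cpu_device(), batches)

for (x, y) in it
    # each batch is converted via the device functor as it is produced; on GPU devices the
    # previous batch is additionally freed via Internal.unsafe_free! on the next iteration
    @assert size(x) == (4, 8)
end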
2 changes: 1 addition & 1 deletion src/public.jl
@@ -293,7 +293,7 @@ end
# For all other types we rely on fmap, which means we lose type stability.
# For Lux, typically models only have these 3 data structures, so we should be mostly fine.
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
ldev = Symbol("$(dev)Device")
ldev = Symbol(dev, :Device)
@eval begin
function (D::$(ldev))(x::AbstractArray{T}) where {T}
return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) :
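The single changed line in this hunk swaps string interpolation for direct Symbol concatenation; both spellings construct the same symbol, as this small illustrative check shows:

@assert Symbol("$(:CUDA)Device") === :CUDADevice   # old spelling
@assert Symbol(:CUDA, :Device) === :CUDADevice     # new spelling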
