This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: device iterators #71

Merged · 6 commits · Aug 28, 2024
5 changes: 4 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.3"
version = "1.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -16,6 +16,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -30,6 +31,7 @@ MLDataDevicesAMDGPUExt = "AMDGPU"
MLDataDevicesCUDAExt = "CUDA"
MLDataDevicesFillArraysExt = "FillArrays"
MLDataDevicesGPUArraysExt = "GPUArrays"
MLDataDevicesMLUtilsExt = "MLUtils"
MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
MLDataDevicesReverseDiffExt = "ReverseDiff"
@@ -47,6 +49,7 @@ ChainRulesCore = "1.23"
FillArrays = "1"
Functors = "0.4.8"
GPUArrays = "10"
MLUtils = "0.4.4"
Metal = "1"
Preferences = "1.4"
Random = "1.10"
7 changes: 6 additions & 1 deletion ext/MLDataDevicesAMDGPUExt.jl
@@ -64,8 +64,13 @@ function MLDataDevices.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer)
return MLDataDevices.set_device!(AMDGPUDevice, id)
end

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray)
    AMDGPU.unsafe_free!(x)
    return
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray)
old_dev = AMDGPU.device() # remember the current device
6 changes: 6 additions & 0 deletions ext/MLDataDevicesCUDAExt.jl
@@ -42,6 +42,12 @@ function MLDataDevices.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer)
return MLDataDevices.set_device!(CUDADevice, id)
end

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{CUDADevice}, x::AbstractArray)
    CUDA.unsafe_free!(x)
    return
end

# Device Transfer
Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)
function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray)
38 changes: 38 additions & 0 deletions ext/MLDataDevicesMLUtilsExt.jl
@@ -0,0 +1,38 @@
module MLDataDevicesMLUtilsExt

using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, DeviceIterator
using MLUtils: MLUtils, DataLoader

for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
    @eval function (D::$(dev))(dataloader::DataLoader)
        if dataloader.parallel
            if dataloader.buffer
                @warn "Using `buffer=true` for parallel DataLoader with automatic device \
                       transfer is currently not implemented. Ignoring `buffer=true`."
            end

            # Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
            data = MLUtils.ObsView(dataloader.data)
            data = dataloader.shuffle ? MLUtils.shuffleobs(data) : data
            data = if dataloader.batchsize > 0
                MLUtils.BatchView(
                    data; dataloader.batchsize, dataloader.partial, dataloader.collate)
            else
                data
            end

            return DeviceIterator(D, eachobsparallel(D, data))
        end
        return DeviceIterator(D, dataloader)
    end
end

function eachobsparallel(dev::AbstractDevice, data)
    return MLUtils.Loader(1:MLUtils.numobs(data)) do ch, i
        obs = MLUtils.getobs(data, i)
        put!(ch, dev(obs))  # transfer happens on the loader task
    end
end

end
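A minimal usage sketch of the extension above, assuming a functional GPU backend and a Julia session started with multiple threads (which `parallel=true` requires):

```julia
using MLDataDevices, MLUtils

X = rand(Float32, 3, 33)
loader = DataLoader(X; batchsize=13, parallel=true)

dev = gpu_device()        # e.g. CUDADevice() on an NVIDIA machine
for x in dev(loader)      # returns a DeviceIterator; batches arrive on-device
    @assert get_device_type(x) !== CPUDevice
end
```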
7 changes: 6 additions & 1 deletion ext/MLDataDevicesMetalExt.jl
@@ -18,8 +18,13 @@ Internal.get_device(::MtlArray) = MetalDevice()

Internal.get_device_type(::MtlArray) = MetalDevice

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray)
    Metal.unsafe_free!(x)
    return
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x)

end
7 changes: 6 additions & 1 deletion ext/MLDataDevicesoneAPIExt.jl
@@ -29,8 +29,13 @@ Internal.get_device(::oneArray) = oneAPIDevice()

Internal.get_device_type(::oneArray) = oneAPIDevice

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{oneAPIDevice}, x::AbstractArray)
    oneAPI.unsafe_free!(x)
    return
end

# Device Transfer
## To GPU
for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
    @eval function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray{$(T1)})
        if !SUPPORTS_FP64[oneAPI.device()]
3 changes: 3 additions & 0 deletions src/MLDataDevices.jl
@@ -12,6 +12,7 @@ abstract type AbstractDevice <: Function end
abstract type AbstractGPUDevice <: AbstractDevice end

include("public.jl")
include("iterator.jl")
include("internal.jl")

export gpu_backend!, supported_gpu_backends, reset_gpu_device!
@@ -21,4 +22,6 @@ export gpu_device, cpu_device
export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
export get_device, get_device_type

export DeviceIterator

end
13 changes: 13 additions & 0 deletions src/internal.jl
@@ -1,5 +1,6 @@
module Internal

using Functors: fmap
using Preferences: load_preference
using Random: AbstractRNG
using UnrolledUtilities: unrolled_mapreduce
@@ -149,4 +150,16 @@ for op in (:get_device, :get_device_type)
end
end

function unsafe_free_internal!(x::AbstractArray)
    unsafe_free_internal!(MLDataDevices.get_device_type(x), x)
    return
end
# Fallbacks: freeing is a no-op for backends without explicit memory management
# (e.g. CPU) and for non-array leaves.
unsafe_free_internal!(::Type, x::AbstractArray) = nothing
unsafe_free_internal!(_) = nothing

function unsafe_free!(x)
    fmap(unsafe_free_internal!, x)  # recurse through arbitrarily nested structures
    return
end

end
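A rough usage sketch (not part of the diff; assumes CUDA.jl is loaded and functional): `unsafe_free!` uses `Functors.fmap` to reach every array leaf of a nested batch and dispatches on its device type, so CPU arrays and non-array leaves fall through to the no-op fallbacks:

```julia
using CUDA, MLDataDevices

batch = (x=CUDA.rand(3, 13), y=CUDA.rand(1, 13))
MLDataDevices.Internal.unsafe_free!(batch)    # frees both GPU arrays
MLDataDevices.Internal.unsafe_free!(rand(3))  # no-op: CPU arrays hit the fallback
```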
73 changes: 73 additions & 0 deletions src/iterator.jl
@@ -0,0 +1,73 @@
"""
DeviceIterator(dev::AbstractDevice, iterator)

Create a `DeviceIterator` that iterates through the provided `iterator` via `iterate`. Upon
each iteration, the current batch is copied to the device `dev`, and the previous batch is
marked as freeable from GPU memory (via `unsafe_free!`; a no-op for CPU devices).

The conversion follows the same semantics as `dev(<item from iterator>)`.

!!! tip "Similarity to `CUDA.CuIterator`"

The design inspiration was taken from `CUDA.CuIterator` and was generalized to work with
other backends and more complex iterators (using `Functors`).

!!! tip "`MLUtils.DataLoader`"

Calling `dev(::MLUtils.DataLoader)` will automatically convert the dataloader to use the
same semantics as `DeviceIterator`. This is generally preferred over looping over the
dataloader directly and transferring the data to the device.

## Examples

The following was run on a computer with an NVIDIA GPU.

```julia-repl
julia> using MLDataDevices, MLUtils

julia> X = rand(Float64, 3, 33);

julia> dataloader = DataLoader(X; batchsize=13, shuffle=false);

julia> for (i, x) in enumerate(dataloader)
           @show i, summary(x)
       end
(i, summary(x)) = (1, "3×13 Matrix{Float64}")
(i, summary(x)) = (2, "3×13 Matrix{Float64}")
(i, summary(x)) = (3, "3×7 Matrix{Float64}")

julia> for (i, x) in enumerate(CUDADevice()(dataloader))
           @show i, summary(x)
       end
(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}")
```
"""
struct DeviceIterator{D <: AbstractDevice, I}
    dev::D
    iterator::I
end

function Base.iterate(c::DeviceIterator)
    item = iterate(c.iterator)
    item === nothing && return nothing
    batch, next_state = item
    dev_batch = c.dev(batch)
    return dev_batch, (next_state, dev_batch)  # carry the device batch in the state
end

function Base.iterate(c::DeviceIterator, (state, prev_batch))
    item = iterate(c.iterator, state)
    item === nothing && return nothing
    batch, next_state = item
    Internal.unsafe_free!(prev_batch)  # free the previous batch
    dev_batch = c.dev(batch)
    return dev_batch, (next_state, dev_batch)
end

Base.IteratorSize(::Type{DeviceIterator{D, I}}) where {D, I} = Base.IteratorSize(I)
Base.length(c::DeviceIterator) = length(c.iterator)
Base.axes(c::DeviceIterator) = axes(c.iterator)

Base.IteratorEltype(::Type{DeviceIterator{D, I}}) where {D, I} = Base.EltypeUnknown()
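For intuition, a sketch (assuming a functional GPU device) of how the two-element state `(next_state, dev_batch)` lets the iterator free a batch exactly when the next one is requested:

```julia
it = DeviceIterator(gpu_device(), [rand(Float32, 3) for _ in 1:3])

x1, st = iterate(it)      # x1 is transferred to the device
x2, st = iterate(it, st)  # x1 is freed via Internal.unsafe_free! before x2 is returned
```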
2 changes: 1 addition & 1 deletion src/public.jl
@@ -293,7 +293,7 @@ end
# For all other types we rely on fmap which means we lose type stability.
# For Lux, models typically only have these 3 data structures, so we should be mostly fine.
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
-    ldev = Symbol("$(dev)Device")
+    ldev = Symbol(dev, :Device)
@eval begin
function (D::$(ldev))(x::AbstractArray{T}) where {T}
return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) :
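The switch to `Symbol(dev, :Device)` builds the same name without an intermediate interpolated string; for example:

```julia-repl
julia> Symbol(:CPU, :Device)
:CPUDevice
```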
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -8,6 +8,7 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -28,6 +29,7 @@ ExplicitImports = "1.9.0"
FillArrays = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.8"
MLUtils = "0.4"
Pkg = "1.10"
Random = "1.10"
RecursiveArrayTools = "3.8"
106 changes: 106 additions & 0 deletions test/iterator_tests.jl
@@ -0,0 +1,106 @@
using MLDataDevices, MLUtils

const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none"))

if BACKEND_GROUP == "cuda" || BACKEND_GROUP == "all"
using LuxCUDA
end

if BACKEND_GROUP == "amdgpu" || BACKEND_GROUP == "all"
using AMDGPU
end

if BACKEND_GROUP == "metal" || BACKEND_GROUP == "all"
using Metal
end

if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all"
using oneAPI
end

DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice]

freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x)
freed_if_can_be_freed(::Type{CPUDevice}, x) = true
function freed_if_can_be_freed(::Type, x)
    # Copying a freed GPU array back to the host throws an ArgumentError;
    # use that to detect whether the array was actually freed.
    try
        Array(x)
        return false
    catch err
        err isa ArgumentError && return true
        rethrow()
    end
end

@testset "Device Iterator: $(dev_type)" for dev_type in DEVICES
dev = dev_type()

!MLDataDevices.functional(dev) && continue

@info "Testing Device Iterator for $(dev)..."

@testset "Basic Device Iterator" begin
datalist = [rand(10) for _ in 1:10]

prev_batch = nothing
for data in DeviceIterator(dev, datalist)
prev_batch === nothing || @test freed_if_can_be_freed(prev_batch)
prev_batch = data
@test size(data) == (10,)
@test get_device_type(data) == dev_type
end
end

@testset "DataLoader: parallel=$parallel" for parallel in (true, false)
X = rand(Float64, 3, 33)
pre = DataLoader(dev(X); batchsize=13, shuffle=false)
post = DataLoader(X; batchsize=13, shuffle=false) |> dev

for epoch in 1:2
prev_pre, prev_post = nothing, nothing
for (p, q) in zip(pre, post)
@test get_device_type(p) == dev_type
@test get_device_type(q) == dev_type
@test p ≈ q

dev_type === CPUDevice && continue

prev_pre === nothing || @test !freed_if_can_be_freed(prev_pre)
prev_pre = p

prev_post === nothing || @test freed_if_can_be_freed(prev_post)
prev_post = q
end
end

Y = rand(Float64, 1, 33)
pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false)
post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false) |> dev

for epoch in 1:2
prev_pre, prev_post = nothing, nothing
for (p, q) in zip(pre, post)
@test get_device_type(p.x) == dev_type
@test get_device_type(p.y) == dev_type
@test get_device_type(q.x) == dev_type
@test get_device_type(q.y) == dev_type
@test p.x ≈ q.x
@test p.y ≈ q.y

dev_type === CPUDevice && continue

if prev_pre !== nothing
@test !freed_if_can_be_freed(prev_pre.x)
@test !freed_if_can_be_freed(prev_pre.y)
end
prev_pre = p

if prev_post !== nothing
@test freed_if_can_be_freed(prev_post.x)
@test freed_if_can_be_freed(prev_post.y)
end
prev_post = q
end
end
end
end
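These tests gate the GPU backends behind the `BACKEND_GROUP` environment variable read at the top of the file; a hypothetical local invocation to exercise the CUDA path:

```julia
# from a shell, roughly: BACKEND_GROUP=cuda julia --project=test --threads=4
ENV["BACKEND_GROUP"] = "cuda"   # must be set before `iterator_tests.jl` is included
include("iterator_tests.jl")
```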
8 changes: 5 additions & 3 deletions test/qa_tests.jl
@@ -11,7 +11,9 @@ import FillArrays, RecursiveArrayTools, SparseArrays, Zygote
@test check_no_stale_explicit_imports(MLDataDevices) === nothing
@test check_no_self_qualified_accesses(MLDataDevices) === nothing
@test check_all_explicit_imports_via_owners(MLDataDevices) === nothing
-@test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing
-@test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing # mostly upstream problems
-@test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing # mostly upstream problem
+@test check_all_qualified_accesses_via_owners(
+    MLDataDevices; ignore=(:SparseArrays,)) === nothing
+# mostly upstream problems
+@test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing
+@test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing
end
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -28,7 +28,7 @@ end
Test.@test true
end

@safetestset "Iterator Tests" include("iterator_tests.jl")
@safetestset "Misc Tests" include("misc_tests.jl")

@safetestset "QA Tests" include("qa_tests.jl")
end