From 50df4fa300d274d70ef3c801a5cb0267b4d1e624 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 28 Aug 2024 13:03:15 -0400
Subject: [PATCH 1/6] feat: add `unsafe_free!`

---
 ext/MLDataDevicesAMDGPUExt.jl  |  6 ++++++
 ext/MLDataDevicesCUDAExt.jl    |  6 ++++++
 ext/MLDataDevicesMLUtilsExt.jl |  5 +++++
 ext/MLDataDevicesMetalExt.jl   |  6 ++++++
 ext/MLDataDevicesoneAPIExt.jl  |  6 ++++++
 src/internal.jl                | 13 +++++++++++++
 6 files changed, 42 insertions(+)
 create mode 100644 ext/MLDataDevicesMLUtilsExt.jl

diff --git a/ext/MLDataDevicesAMDGPUExt.jl b/ext/MLDataDevicesAMDGPUExt.jl
index e539a15..53bda67 100644
--- a/ext/MLDataDevicesAMDGPUExt.jl
+++ b/ext/MLDataDevicesAMDGPUExt.jl
@@ -64,6 +64,12 @@ function MLDataDevices.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Intege
     return MLDataDevices.set_device!(AMDGPUDevice, id)
 end
 
+# unsafe_free!
+function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray)
+    AMDGPU.unsafe_free!(x)
+    return
+end
+
 # Device Transfer
 ## To GPU
 Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
diff --git a/ext/MLDataDevicesCUDAExt.jl b/ext/MLDataDevicesCUDAExt.jl
index cc4cde4..3492440 100644
--- a/ext/MLDataDevicesCUDAExt.jl
+++ b/ext/MLDataDevicesCUDAExt.jl
@@ -42,6 +42,12 @@ function MLDataDevices.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer)
     return MLDataDevices.set_device!(CUDADevice, id)
 end
 
+# unsafe_free!
+function Internal.unsafe_free_internal!(::Type{CUDADevice}, x::AbstractArray)
+    CUDA.unsafe_free!(x)
+    return
+end
+
 # Device Transfer
 Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)
 function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray)
diff --git a/ext/MLDataDevicesMLUtilsExt.jl b/ext/MLDataDevicesMLUtilsExt.jl
new file mode 100644
index 0000000..a54da03
--- /dev/null
+++ b/ext/MLDataDevicesMLUtilsExt.jl
@@ -0,0 +1,5 @@
+module MLDataDevicesMLUtilsExt
+
+using MLUtils: DataLoader
+
+end
diff --git a/ext/MLDataDevicesMetalExt.jl b/ext/MLDataDevicesMetalExt.jl
index 87d0b0e..ffc4bc9 100644
--- a/ext/MLDataDevicesMetalExt.jl
+++ b/ext/MLDataDevicesMetalExt.jl
@@ -18,6 +18,12 @@ Internal.get_device(::MtlArray) = MetalDevice()
 
 Internal.get_device_type(::MtlArray) = MetalDevice
 
+# unsafe_free!
+function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray)
+    Metal.unsafe_free!(x)
+    return
+end
+
 # Device Transfer
 ## To GPU
 Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x)
diff --git a/ext/MLDataDevicesoneAPIExt.jl b/ext/MLDataDevicesoneAPIExt.jl
index 4bda871..130bad2 100644
--- a/ext/MLDataDevicesoneAPIExt.jl
+++ b/ext/MLDataDevicesoneAPIExt.jl
@@ -29,6 +29,12 @@ Internal.get_device(::oneArray) = oneAPIDevice()
 
 Internal.get_device_type(::oneArray) = oneAPIDevice
 
+# unsafe_free!
+function Internal.unsafe_free_internal!(::Type{oneAPIDevice}, x::AbstractArray)
+    oneAPI.unsafe_free!(x)
+    return
+end
+
 # Device Transfer
 ## To GPU
 for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
diff --git a/src/internal.jl b/src/internal.jl
index e894649..f2c807e 100644
--- a/src/internal.jl
+++ b/src/internal.jl
@@ -1,5 +1,6 @@
 module Internal
 
+using Functors: fmap
 using Preferences: load_preference
 using Random: AbstractRNG
 using UnrolledUtilities: unrolled_mapreduce
@@ -149,4 +150,16 @@ for op in (:get_device, :get_device_type)
     end
 end
 
+function unsafe_free_internal!(x::AbstractArray)
+    unsafe_free_internal!(MLDataDevices.get_device_type(x), x)
+    return
+end
+unsafe_free_internal!(::Type, x::AbstractArray) = nothing
+unsafe_free_internal!(_) = nothing
+
+function unsafe_free!(x)
+    fmap(unsafe_free_internal!, x)
+    return
+end
+
 end

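For illustration, a minimal usage sketch of what this first patch wires up (not part of the patch series; it assumes CUDA.jl is installed and a GPU is functional): `Internal.unsafe_free!` walks an arbitrarily nested structure with `Functors.fmap` and dispatches `unsafe_free_internal!` on each leaf, so device arrays are freed eagerly while CPU arrays and non-array leaves are left untouched.

```julia
# Hypothetical sketch: exercising the `unsafe_free!` helper added in src/internal.jl above.
# Assumes CUDA.jl and a working GPU; the CUDA extension provides the
# `unsafe_free_internal!(::Type{CUDADevice}, x)` method that ends up being called.
using MLDataDevices, CUDA

batch = (; x=CUDA.rand(Float32, 3, 13), y=CUDA.rand(Float32, 1, 13), n=42)

# Walks the named tuple via `fmap`: `CUDA.unsafe_free!` runs on `x` and `y`, while the
# non-array leaf `n` hits the `unsafe_free_internal!(_) = nothing` fallback.
MLDataDevices.Internal.unsafe_free!(batch)
```
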
From a0756e92986299dde6b02d8d59cb4574567cab77 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 28 Aug 2024 14:32:58 -0400
Subject: [PATCH 2/6] feat: add DeviceIterator (and support parallel Device DataLoader)

---
 Project.toml                   |  5 ++-
 ext/MLDataDevicesAMDGPUExt.jl  |  1 -
 ext/MLDataDevicesMLUtilsExt.jl | 60 +++++++++++++++++++++++++++++++++-
 ext/MLDataDevicesMetalExt.jl   |  1 -
 ext/MLDataDevicesoneAPIExt.jl  |  1 -
 src/MLDataDevices.jl           |  3 ++
 src/iterator.jl                | 35 ++++++++++++++++++++
 src/public.jl                  |  2 +-
 8 files changed, 102 insertions(+), 6 deletions(-)
 create mode 100644 src/iterator.jl

diff --git a/Project.toml b/Project.toml
index 9106f79..35da279 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal and contributors"]
-version = "1.0.3"
+version = "1.1.0"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -16,6 +16,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -30,6 +31,7 @@ MLDataDevicesAMDGPUExt = "AMDGPU"
 MLDataDevicesCUDAExt = "CUDA"
 MLDataDevicesFillArraysExt = "FillArrays"
 MLDataDevicesGPUArraysExt = "GPUArrays"
+MLDataDevicesMLUtilsExt = "MLUtils"
 MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
 MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
 MLDataDevicesReverseDiffExt = "ReverseDiff"
@@ -47,6 +49,7 @@ ChainRulesCore = "1.23"
 FillArrays = "1"
 Functors = "0.4.8"
 GPUArrays = "10"
+MLUtils = "0.4"
 Metal = "1"
 Preferences = "1.4"
 Random = "1.10"
diff --git a/ext/MLDataDevicesAMDGPUExt.jl b/ext/MLDataDevicesAMDGPUExt.jl
index 53bda67..4014b2e 100644
--- a/ext/MLDataDevicesAMDGPUExt.jl
+++ b/ext/MLDataDevicesAMDGPUExt.jl
@@ -71,7 +71,6 @@ function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray)
 end
 
 # Device Transfer
-## To GPU
 Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
 function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray)
     old_dev = AMDGPU.device() # remember the current device
diff --git a/ext/MLDataDevicesMLUtilsExt.jl b/ext/MLDataDevicesMLUtilsExt.jl
index a54da03..57db601 100644
--- a/ext/MLDataDevicesMLUtilsExt.jl
+++ b/ext/MLDataDevicesMLUtilsExt.jl
@@ -1,5 +1,63 @@
 module MLDataDevicesMLUtilsExt
 
-using MLUtils: DataLoader
+using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUDevice,
+                     CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, DeviceIterator,
+                     Internal
+using MLUtils: MLUtils, DataLoader
+
+for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
+    ldev = Symbol(dev, :Device)
+    @eval function (D::$(ldev))(dataloader::DataLoader)
+        if dataloader.parallel
+            if dataloader.buffer
+                @warn "Using `buffer=true` for parallel DataLoader with automatic device \
+                       transfer is currently not implemented. Ignoring `buffer=true`."
+            end
+            return ParallelDeviceDataLoader(D, dataloader)
+        end
+        return DeviceIterator(D, dataloader)
+    end
+end
+
+# Parallel DataLoader that does the device transfer in the same task
+struct ParallelDeviceDataLoader{D <: AbstractDevice, DL <: DataLoader} <:
+       AbstractDeviceIterator{D, DL}
+    dev::D
+    iterator::DL
+end
+
+# Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
+function Base.iterate(c::ParallelDeviceDataLoader)
+    data = MLUtils.ObsView(c.iterator.data)
+
+    data = c.iterator.shuffle ? MLUtils.shuffleobs(c.iterator.rng, data) : data
+    data = if c.iterator.batchsize > 0
+        MLUtils.BatchView(
+            data; c.iterator.batchsize, c.iterator.partial, c.iterator.collate)
+    else
+        data
+    end
+
+    iter = eachobsparallel(c.dev, data)
+    item = iterate(iter)
+    item === nothing && return nothing
+    dev_batch, next_state = item
+    return dev_batch, ((iter, next_state), dev_batch)
+end
+
+function Base.iterate(::ParallelDeviceDataLoader, ((iter, state), prev_batch))
+    item = iterate(iter, state)
+    item === nothing && return nothing
+    dev_batch, next_state = item
+    Internal.unsafe_free!(prev_batch) # free the previous batch
+    return dev_batch, ((iter, next_state), dev_batch)
+end
+
+function eachobsparallel(dev::AbstractDevice, data)
+    return MLUtils.Loader(1:MLUtils.numobs(data)) do ch, i
+        obs = MLUtils.getobs(data, i)
+        put!(ch, dev(obs))
+    end
+end
 
 end
diff --git a/ext/MLDataDevicesMetalExt.jl b/ext/MLDataDevicesMetalExt.jl
index ffc4bc9..e5eb16d 100644
--- a/ext/MLDataDevicesMetalExt.jl
+++ b/ext/MLDataDevicesMetalExt.jl
@@ -25,7 +25,6 @@ function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray)
 end
 
 # Device Transfer
-## To GPU
 Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x)
 
 end
diff --git a/ext/MLDataDevicesoneAPIExt.jl b/ext/MLDataDevicesoneAPIExt.jl
index 130bad2..75fc2f0 100644
--- a/ext/MLDataDevicesoneAPIExt.jl
+++ b/ext/MLDataDevicesoneAPIExt.jl
@@ -36,7 +36,6 @@ function Internal.unsafe_free_internal!(::Type{oneAPIDevice}, x::AbstractArray)
 end
 
 # Device Transfer
-## To GPU
 for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
     @eval function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray{$(T1)})
         if !SUPPORTS_FP64[oneAPI.device()]
diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl
index b7636db..574fea4 100644
--- a/src/MLDataDevices.jl
+++ b/src/MLDataDevices.jl
@@ -12,6 +12,7 @@ abstract type AbstractDevice <: Function end
 abstract type AbstractGPUDevice <: AbstractDevice end
 
 include("public.jl")
+include("iterator.jl")
 include("internal.jl")
 
 export gpu_backend!, supported_gpu_backends, reset_gpu_device!
@@ -21,4 +22,6 @@ export gpu_device, cpu_device
 export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
 export get_device, get_device_type
 
+export DeviceIterator
+
 end
diff --git a/src/iterator.jl b/src/iterator.jl
new file mode 100644
index 0000000..47969be
--- /dev/null
+++ b/src/iterator.jl
@@ -0,0 +1,35 @@
+abstract type AbstractDeviceIterator{D <: AbstractDevice, I} end
+
+function Base.IteratorSize(::Type{AbstractDeviceIterator{D, I}}) where {D, I}
+    return Base.IteratorSize(I)
+end
+Base.length(c::AbstractDeviceIterator) = length(c.iterator)
+Base.axes(c::AbstractDeviceIterator) = axes(c.iterator)
+
+function Base.IteratorEltype(::Type{AbstractDeviceIterator{D, I}}) where {D, I}
+    return Base.IteratorEltype(I)
+end
+Base.eltype(c::AbstractDeviceIterator) = eltype(c.iterator)
+
+# This is based on CuIterator but generalized to work with any device
+struct DeviceIterator{D, I} <: AbstractDeviceIterator{D, I}
+    dev::D
+    iterator::I
+end
+
+function Base.iterate(c::DeviceIterator)
+    item = iterate(c.iterator)
+    item === nothing && return nothing
+    batch, next_state = item
+    dev_batch = c.dev(batch)
+    return dev_batch, (next_state, dev_batch)
+end
+
+function Base.iterate(c::DeviceIterator, (state, prev_batch))
+    item = iterate(c.iterator, state)
+    item === nothing && return nothing
+    batch, next_state = item
+    Internal.unsafe_free!(prev_batch) # free the previous batch
+    dev_batch = c.dev(batch)
+    return dev_batch, (next_state, dev_batch)
+end
diff --git a/src/public.jl b/src/public.jl
index ac53ee5..d7a7d27 100644
--- a/src/public.jl
+++ b/src/public.jl
@@ -293,7 +293,7 @@ end
 # For all other types we rely on fmap which means we lose type stability.
 # For Lux, typically models only has these 3 datastructures so we should be mostly fine.
 for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
-    ldev = Symbol("$(dev)Device")
+    ldev = Symbol(dev, :Device)
     @eval begin
         function (D::$(ldev))(x::AbstractArray{T}) where {T}
             return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) :

From 619036028fdf3d0c580534771959ba8c1e5143b6 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 28 Aug 2024 16:17:55 -0400
Subject: [PATCH 3/6] test: basic tests for free-ing data

---
 Project.toml                   |  2 +-
 ext/MLDataDevicesMLUtilsExt.jl |  5 ++--
 test/Project.toml              |  2 ++
 test/iterator_tests.jl         | 53 ++++++++++++++++++++++++++++++++++
 test/qa_tests.jl               |  5 ++--
 test/runtests.jl               |  2 +-
 6 files changed, 62 insertions(+), 7 deletions(-)
 create mode 100644 test/iterator_tests.jl

diff --git a/Project.toml b/Project.toml
index 35da279..0602650 100644
--- a/Project.toml
+++ b/Project.toml
@@ -49,7 +49,7 @@ ChainRulesCore = "1.23"
 FillArrays = "1"
 Functors = "0.4.8"
 GPUArrays = "10"
-MLUtils = "0.4"
+MLUtils = "0.4.4"
 Metal = "1"
 Preferences = "1.4"
 Random = "1.10"
diff --git a/ext/MLDataDevicesMLUtilsExt.jl b/ext/MLDataDevicesMLUtilsExt.jl
index 57db601..a3c083e 100644
--- a/ext/MLDataDevicesMLUtilsExt.jl
+++ b/ext/MLDataDevicesMLUtilsExt.jl
@@ -5,9 +5,8 @@ using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUD
                      Internal
 using MLUtils: MLUtils, DataLoader
 
-for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
-    ldev = Symbol(dev, :Device)
-    @eval function (D::$(ldev))(dataloader::DataLoader)
+for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
+    @eval function (D::$(dev))(dataloader::DataLoader)
         if dataloader.parallel
             if dataloader.buffer
                 @warn "Using `buffer=true` for parallel DataLoader with automatic device \
diff --git a/test/Project.toml b/test/Project.toml
index f770c7a..9914e0f 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -8,6 +8,7 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -28,6 +29,7 @@ ExplicitImports = "1.9.0"
 FillArrays = "1"
 ForwardDiff = "0.10.36"
 Functors = "0.4.8"
+MLUtils = "0.4"
 Pkg = "1.10"
 Random = "1.10"
 RecursiveArrayTools = "3.8"
diff --git a/test/iterator_tests.jl b/test/iterator_tests.jl
new file mode 100644
index 0000000..78d4601
--- /dev/null
+++ b/test/iterator_tests.jl
@@ -0,0 +1,53 @@
+using MLDataDevices, MLUtils
+
+const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none"))
+
+if BACKEND_GROUP == "cuda" || BACKEND_GROUP == "all"
+    using LuxCUDA
+end
+
+if BACKEND_GROUP == "amdgpu" || BACKEND_GROUP == "all"
+    using AMDGPU
+end
+
+if BACKEND_GROUP == "metal" || BACKEND_GROUP == "all"
+    using Metal
+end
+
+if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all"
+    using oneAPI
+end
+
+DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice]
+
+freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x)
+freed_if_can_be_freed(::Type{CPUDevice}, x) = true
+function freed_if_can_be_freed(::Type, x)
+    try
+        Array(x)
+        return false
+    catch err
+        err isa ArgumentError && return true
+        rethrow()
+    end
+end
+
+@testset "Device Iterator: $(dev_type)" for dev_type in DEVICES
+    dev = dev_type()
+
+    !MLDataDevices.functional(dev) && continue
+
+    @info "Testing Device Iterator for $(dev)..."
+
+    @testset "Basic Device Iterator" begin
+        datalist = [rand(10) for _ in 1:10]
+
+        prev_batch = nothing
+        for data in DeviceIterator(dev, datalist)
+            prev_batch === nothing || @test freed_if_can_be_freed(prev_batch)
+            prev_batch = data
+            @test size(data) == (10,)
+            @test get_device_type(data) == dev_type
+        end
+    end
+end
diff --git a/test/qa_tests.jl b/test/qa_tests.jl
index 965e818..938908a 100644
--- a/test/qa_tests.jl
+++ b/test/qa_tests.jl
@@ -12,6 +12,7 @@ import FillArrays, RecursiveArrayTools, SparseArrays, Zygote
     @test check_no_self_qualified_accesses(MLDataDevices) === nothing
    @test check_all_explicit_imports_via_owners(MLDataDevices) === nothing
     @test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing
-    @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing # mostly upstream problems
-    @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing # mostly upstream problem
+    # mostly upstream problems
+    @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing
+    @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index b9fb136..65cc190 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -28,7 +28,7 @@ end
         Test.@test true
     end
 
+    @safetestset "Iterator Tests" include("iterator_tests.jl")
     @safetestset "Misc Tests" include("misc_tests.jl")
-    @safetestset "QA Tests" include("qa_tests.jl")
 
 end

From 4dfcfe333118f4e85e5043c001ca9185cf27f37f Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 28 Aug 2024 17:21:11 -0400
Subject: [PATCH 4/6] refactor: simplify parallel dataloader

---
 ext/MLDataDevicesMLUtilsExt.jl | 52 +++++++++-------------------------
 src/iterator.jl                | 21 +++++---------
 test/qa_tests.jl               |  3 +-
 3 files changed, 23 insertions(+), 53 deletions(-)

diff --git a/ext/MLDataDevicesMLUtilsExt.jl b/ext/MLDataDevicesMLUtilsExt.jl
index a3c083e..693e661 100644
--- a/ext/MLDataDevicesMLUtilsExt.jl
+++ b/ext/MLDataDevicesMLUtilsExt.jl
@@ -1,8 +1,7 @@
 module MLDataDevicesMLUtilsExt
 
-using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUDevice,
-                     CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, DeviceIterator,
-                     Internal
+using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
+                     MetalDevice, oneAPIDevice, DeviceIterator
 using MLUtils: MLUtils, DataLoader
 
 for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
@@ -12,44 +11,21 @@ for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
                 @warn "Using `buffer=true` for parallel DataLoader with automatic device \
                        transfer is currently not implemented. Ignoring `buffer=true`."
             end
-            return ParallelDeviceDataLoader(D, dataloader)
-        end
-        return DeviceIterator(D, dataloader)
-    end
-end
-
-# Parallel DataLoader that does the device transfer in the same task
-struct ParallelDeviceDataLoader{D <: AbstractDevice, DL <: DataLoader} <:
-       AbstractDeviceIterator{D, DL}
-    dev::D
-    iterator::DL
-end
-# Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
-function Base.iterate(c::ParallelDeviceDataLoader)
-    data = MLUtils.ObsView(c.iterator.data)
+            # Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
+            data = MLUtils.ObsView(dataloader.data)
+            data = dataloader.shuffle ? MLUtils.shuffleobs(data) : data
+            data = if dataloader.batchsize > 0
+                MLUtils.BatchView(
+                    data; dataloader.batchsize, dataloader.partial, dataloader.collate)
+            else
+                data
+            end
-    data = c.iterator.shuffle ? MLUtils.shuffleobs(c.iterator.rng, data) : data
-    data = if c.iterator.batchsize > 0
-        MLUtils.BatchView(
-            data; c.iterator.batchsize, c.iterator.partial, c.iterator.collate)
-    else
-        data
+            return DeviceIterator(D, eachobsparallel(D, data))
+        end
+        return DeviceIterator(D, dataloader)
     end
-
-    iter = eachobsparallel(c.dev, data)
-    item = iterate(iter)
-    item === nothing && return nothing
-    dev_batch, next_state = item
-    return dev_batch, ((iter, next_state), dev_batch)
-end
-
-function Base.iterate(::ParallelDeviceDataLoader, ((iter, state), prev_batch))
-    item = iterate(iter, state)
-    item === nothing && return nothing
-    dev_batch, next_state = item
-    Internal.unsafe_free!(prev_batch) # free the previous batch
-    return dev_batch, ((iter, next_state), dev_batch)
     end
 end
 
 function eachobsparallel(dev::AbstractDevice, data)
     return MLUtils.Loader(1:MLUtils.numobs(data)) do ch, i
         obs = MLUtils.getobs(data, i)
         put!(ch, dev(obs))
     end
 end
diff --git a/src/iterator.jl b/src/iterator.jl
index 47969be..3b4345e 100644
--- a/src/iterator.jl
+++ b/src/iterator.jl
@@ -1,18 +1,5 @@
-abstract type AbstractDeviceIterator{D <: AbstractDevice, I} end
-
-function Base.IteratorSize(::Type{AbstractDeviceIterator{D, I}}) where {D, I}
-    return Base.IteratorSize(I)
-end
-Base.length(c::AbstractDeviceIterator) = length(c.iterator)
-Base.axes(c::AbstractDeviceIterator) = axes(c.iterator)
-
-function Base.IteratorEltype(::Type{AbstractDeviceIterator{D, I}}) where {D, I}
-    return Base.IteratorEltype(I)
-end
-Base.eltype(c::AbstractDeviceIterator) = eltype(c.iterator)
-
 # This is based on CuIterator but generalized to work with any device
-struct DeviceIterator{D, I} <: AbstractDeviceIterator{D, I}
+struct DeviceIterator{D <: AbstractDevice, I}
     dev::D
     iterator::I
 end
@@ -33,3 +20,9 @@ function Base.iterate(c::DeviceIterator, (state, prev_batch))
     dev_batch = c.dev(batch)
     return dev_batch, (next_state, dev_batch)
 end
+
+Base.IteratorSize(::Type{DeviceIterator{D, I}}) where {D, I} = Base.IteratorSize(I)
+Base.length(c::DeviceIterator) = length(c.iterator)
+Base.axes(c::DeviceIterator) = axes(c.iterator)
+
+Base.IteratorEltype(::Type{DeviceIterator{D, I}}) where {D, I} = Base.EltypeUnknown()
diff --git a/test/qa_tests.jl b/test/qa_tests.jl
index 938908a..b5e4cb6 100644
--- a/test/qa_tests.jl
+++ b/test/qa_tests.jl
@@ -11,7 +11,8 @@ import FillArrays, RecursiveArrayTools, SparseArrays, Zygote
     @test check_no_stale_explicit_imports(MLDataDevices) === nothing
     @test check_no_self_qualified_accesses(MLDataDevices) === nothing
     @test check_all_explicit_imports_via_owners(MLDataDevices) === nothing
-    @test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing
+    @test check_all_qualified_accesses_via_owners(
+        MLDataDevices; ignore=(:SparseArrays,)) === nothing
     # mostly upstream problems
     @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing
     @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing

From 401b3fc41b05f674cb6acdace92c42cf4df65180 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 28 Aug 2024 17:52:26 -0400
Subject: [PATCH 5/6] test: DataLoader aggressive freeing

---
 test/iterator_tests.jl | 53 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 53 insertions(+)

diff --git a/test/iterator_tests.jl b/test/iterator_tests.jl
index 78d4601..dbb4d7a 100644
--- a/test/iterator_tests.jl
+++ b/test/iterator_tests.jl
@@ -50,4 +50,57 @@ end
             @test get_device_type(data) == dev_type
         end
     end
+
+    @testset "DataLoader: parallel=$parallel" for parallel in (true, false)
+        X = rand(Float64, 3, 33)
+        pre = DataLoader(dev(X); batchsize=13, shuffle=false)
+        post = DataLoader(X; batchsize=13, shuffle=false) |> dev
+
+        for epoch in 1:2
+            prev_pre, prev_post = nothing, nothing
+            for (p, q) in zip(pre, post)
+                @test get_device_type(p) == dev_type
+                @test get_device_type(q) == dev_type
+                @test p ≈ q
+
+                dev_type === CPUDevice && continue
+
+                prev_pre === nothing || @test !freed_if_can_be_freed(prev_pre)
+                prev_pre = p
+
+                prev_post === nothing || @test freed_if_can_be_freed(prev_post)
+                prev_post = q
+            end
+        end
+
+        Y = rand(Float64, 1, 33)
+        pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false)
+        post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false) |> dev
+
+        for epoch in 1:2
+            prev_pre, prev_post = nothing, nothing
+            for (p, q) in zip(pre, post)
+                @test get_device_type(p.x) == dev_type
+                @test get_device_type(p.y) == dev_type
+                @test get_device_type(q.x) == dev_type
+                @test get_device_type(q.y) == dev_type
+                @test p.x ≈ q.x
+                @test p.y ≈ q.y
+
+                dev_type === CPUDevice && continue
+
+                if prev_pre !== nothing
+                    @test !freed_if_can_be_freed(prev_pre.x)
+                    @test !freed_if_can_be_freed(prev_pre.y)
+                end
+                prev_pre = p
+
+                if prev_post !== nothing
+                    @test freed_if_can_be_freed(prev_post.x)
+                    @test freed_if_can_be_freed(prev_post.y)
+                end
+                prev_post = q
+            end
+        end
+    end
 end

From 7904ff513d86084735f899f294d954b7140ebc65 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 28 Aug 2024 18:02:56 -0400
Subject: [PATCH 6/6] docs: add docstrings for `DeviceIterator`

---
 src/iterator.jl | 47 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 46 insertions(+), 1 deletion(-)

diff --git a/src/iterator.jl b/src/iterator.jl
index 3b4345e..e0b686e 100644
--- a/src/iterator.jl
+++ b/src/iterator.jl
@@ -1,4 +1,49 @@
-# This is based on CuIterator but generalized to work with any device
+"""
+    DeviceIterator(dev::AbstractDevice, iterator)
+
+Create a `DeviceIterator` that iterates through the provided `iterator` via `iterate`. Upon
+each iteration, the current batch is copied to the device `dev`, and the previous iteration
+is marked as freeable from GPU memory via `unsafe_free!` (a no-op for a CPU device).
+
+The conversion follows the same semantics as calling `dev(x)` directly.
+
+!!! tip "Similarity to `CUDA.CuIterator`"
+
+    The design was inspired by `CUDA.CuIterator` and generalized to work with other
+    backends and more complex iterators (using `Functors`).
+
+!!! tip "`MLUtils.DataLoader`"
+
+    Calling `dev(::MLUtils.DataLoader)` will automatically convert the dataloader to use the
+    same semantics as `DeviceIterator`. This is generally preferred over looping over the
+    dataloader directly and transferring the data to the device.
+
+## Examples
+
+The following was run on a computer with an NVIDIA GPU.
+
+```julia-repl
+julia> using MLDataDevices, MLUtils
+
+julia> X = rand(Float64, 3, 33);
+
+julia> dataloader = DataLoader(X; batchsize=13, shuffle=false);
+
+julia> for (i, x) in enumerate(dataloader)
+           @show i, summary(x)
+       end
+(i, summary(x)) = (1, "3×13 Matrix{Float64}")
+(i, summary(x)) = (2, "3×13 Matrix{Float64}")
+(i, summary(x)) = (3, "3×7 Matrix{Float64}")
+
+julia> for (i, x) in enumerate(CUDADevice()(dataloader))
+           @show i, summary(x)
+       end
+(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
+(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
+(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}")
+```
+"""
 struct DeviceIterator{D <: AbstractDevice, I}
     dev::D
     iterator::I