diff --git a/Project.toml b/Project.toml index 9106f79..0602650 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.0.3" +version = "1.1.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -16,6 +16,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -30,6 +31,7 @@ MLDataDevicesAMDGPUExt = "AMDGPU" MLDataDevicesCUDAExt = "CUDA" MLDataDevicesFillArraysExt = "FillArrays" MLDataDevicesGPUArraysExt = "GPUArrays" +MLDataDevicesMLUtilsExt = "MLUtils" MLDataDevicesMetalExt = ["GPUArrays", "Metal"] MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" MLDataDevicesReverseDiffExt = "ReverseDiff" @@ -47,6 +49,7 @@ ChainRulesCore = "1.23" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10" +MLUtils = "0.4.4" Metal = "1" Preferences = "1.4" Random = "1.10" diff --git a/ext/MLDataDevicesAMDGPUExt.jl b/ext/MLDataDevicesAMDGPUExt.jl index e539a15..4014b2e 100644 --- a/ext/MLDataDevicesAMDGPUExt.jl +++ b/ext/MLDataDevicesAMDGPUExt.jl @@ -64,8 +64,13 @@ function MLDataDevices.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Intege return MLDataDevices.set_device!(AMDGPUDevice, id) end +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray) + AMDGPU.unsafe_free!(x) + return +end + # Device Transfer -## To GPU Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device diff --git a/ext/MLDataDevicesCUDAExt.jl b/ext/MLDataDevicesCUDAExt.jl index cc4cde4..3492440 100644 --- a/ext/MLDataDevicesCUDAExt.jl +++ b/ext/MLDataDevicesCUDAExt.jl @@ -42,6 +42,12 @@ function MLDataDevices.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer) return MLDataDevices.set_device!(CUDADevice, id) end +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{CUDADevice}, x::AbstractArray) + CUDA.unsafe_free!(x) + return +end + # Device Transfer Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) diff --git a/ext/MLDataDevicesMLUtilsExt.jl b/ext/MLDataDevicesMLUtilsExt.jl new file mode 100644 index 0000000..693e661 --- /dev/null +++ b/ext/MLDataDevicesMLUtilsExt.jl @@ -0,0 +1,38 @@ +module MLDataDevicesMLUtilsExt + +using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, + MetalDevice, oneAPIDevice, DeviceIterator +using MLUtils: MLUtils, DataLoader + +for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) + @eval function (D::$(dev))(dataloader::DataLoader) + if dataloader.parallel + if dataloader.buffer + @warn "Using `buffer=true` for parallel DataLoader with automatic device \ + transfer is currently not implemented. Ignoring `buffer=true`." + end + + # Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl + data = MLUtils.ObsView(dataloader.data) + data = dataloader.shuffle ? 
MLUtils.shuffleobs(data) : data + data = if dataloader.batchsize > 0 + MLUtils.BatchView( + data; dataloader.batchsize, dataloader.partial, dataloader.collate) + else + data + end + + return DeviceIterator(D, eachobsparallel(D, data)) + end + return DeviceIterator(D, dataloader) + end +end + +function eachobsparallel(dev::AbstractDevice, data) + return MLUtils.Loader(1:MLUtils.numobs(data)) do ch, i + obs = MLUtils.getobs(data, i) + put!(ch, dev(obs)) + end +end + +end diff --git a/ext/MLDataDevicesMetalExt.jl b/ext/MLDataDevicesMetalExt.jl index 87d0b0e..e5eb16d 100644 --- a/ext/MLDataDevicesMetalExt.jl +++ b/ext/MLDataDevicesMetalExt.jl @@ -18,8 +18,13 @@ Internal.get_device(::MtlArray) = MetalDevice() Internal.get_device_type(::MtlArray) = MetalDevice +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray) + Metal.unsafe_free!(x) + return +end + # Device Transfer -## To GPU Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) end diff --git a/ext/MLDataDevicesoneAPIExt.jl b/ext/MLDataDevicesoneAPIExt.jl index 4bda871..75fc2f0 100644 --- a/ext/MLDataDevicesoneAPIExt.jl +++ b/ext/MLDataDevicesoneAPIExt.jl @@ -29,8 +29,13 @@ Internal.get_device(::oneArray) = oneAPIDevice() Internal.get_device_type(::oneArray) = oneAPIDevice +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{oneAPIDevice}, x::AbstractArray) + oneAPI.unsafe_free!(x) + return +end + # Device Transfer -## To GPU for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) @eval function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray{$(T1)}) if !SUPPORTS_FP64[oneAPI.device()] diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index b7636db..574fea4 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -12,6 +12,7 @@ abstract type AbstractDevice <: Function end abstract type AbstractGPUDevice <: AbstractDevice end include("public.jl") +include("iterator.jl") include("internal.jl") export gpu_backend!, supported_gpu_backends, reset_gpu_device! @@ -21,4 +22,6 @@ export gpu_device, cpu_device export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice export get_device, get_device_type +export DeviceIterator + end diff --git a/src/internal.jl b/src/internal.jl index e894649..f2c807e 100644 --- a/src/internal.jl +++ b/src/internal.jl @@ -1,5 +1,6 @@ module Internal +using Functors: fmap using Preferences: load_preference using Random: AbstractRNG using UnrolledUtilities: unrolled_mapreduce @@ -149,4 +150,16 @@ for op in (:get_device, :get_device_type) end end +function unsafe_free_internal!(x::AbstractArray) + unsafe_free_internal!(MLDataDevices.get_device_type(x), x) + return +end +unsafe_free_internal!(::Type, x::AbstractArray) = nothing +unsafe_free_internal!(_) = nothing + +function unsafe_free!(x) + fmap(unsafe_free_internal!, x) + return +end + end diff --git a/src/iterator.jl b/src/iterator.jl new file mode 100644 index 0000000..e0b686e --- /dev/null +++ b/src/iterator.jl @@ -0,0 +1,73 @@ +""" + DeviceIterator(dev::AbstractDevice, iterator) + +Create a `DeviceIterator` that iterates through the provided `iterator` via `iterate`. Upon +each iteration, the current batch is copied to the device `dev`, and the previous iteration +is marked as freeable from GPU memory (via `unsafe_free!`) (no-op for a CPU device). + +The conversion follows the same semantics as `dev()`. + +!!! 
tip "Similarity to `CUDA.CuIterator`" + + The design inspiration was taken from `CUDA.CuIterator` and was generalized to work with + other backends and more complex iterators (using `Functors`). + +!!! tip "`MLUtils.DataLoader`" + + Calling `dev(::MLUtils.DataLoader)` will automatically convert the dataloader to use the + same semantics as `DeviceIterator`. This is generally preferred over looping over the + dataloader directly and transferring the data to the device. + +## Examples + +The following was run on a computer with an NVIDIA GPU. + +```julia-repl +julia> using MLDataDevices, MLUtils + +julia> X = rand(Float64, 3, 33); + +julia> dataloader = DataLoader(X; batchsize=13, shuffle=false); + +julia> for (i, x) in enumerate(dataloader) + @show i, summary(x) + end +(i, summary(x)) = (1, "3×13 Matrix{Float64}") +(i, summary(x)) = (2, "3×13 Matrix{Float64}") +(i, summary(x)) = (3, "3×7 Matrix{Float64}") + +julia> for (i, x) in enumerate(CUDADevice()(dataloader)) + @show i, summary(x) + end +(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}") +(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}") +(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}") +``` +""" +struct DeviceIterator{D <: AbstractDevice, I} + dev::D + iterator::I +end + +function Base.iterate(c::DeviceIterator) + item = iterate(c.iterator) + item === nothing && return nothing + batch, next_state = item + dev_batch = c.dev(batch) + return dev_batch, (next_state, dev_batch) +end + +function Base.iterate(c::DeviceIterator, (state, prev_batch)) + item = iterate(c.iterator, state) + item === nothing && return nothing + batch, next_state = item + Internal.unsafe_free!(prev_batch) # free the previous batch + dev_batch = c.dev(batch) + return dev_batch, (next_state, dev_batch) +end + +Base.IteratorSize(::Type{DeviceIterator{D, I}}) where {D, I} = Base.IteratorSize(I) +Base.length(c::DeviceIterator) = length(c.iterator) +Base.axes(c::DeviceIterator) = axes(c.iterator) + +Base.IteratorEltype(::Type{DeviceIterator{D, I}}) where {D, I} = Base.EltypeUnknown() diff --git a/src/public.jl b/src/public.jl index ac53ee5..d7a7d27 100644 --- a/src/public.jl +++ b/src/public.jl @@ -293,7 +293,7 @@ end # For all other types we rely on fmap which means we lose type stability. # For Lux, typically models only has these 3 datastructures so we should be mostly fine. for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - ldev = Symbol("$(dev)Device") + ldev = Symbol(dev, :Device) @eval begin function (D::$(ldev))(x::AbstractArray{T}) where {T} return (isbitstype(T) || Internal.special_aos(x)) ? 
Adapt.adapt(D, x) : diff --git a/test/Project.toml b/test/Project.toml index f770c7a..9914e0f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -28,6 +29,7 @@ ExplicitImports = "1.9.0" FillArrays = "1" ForwardDiff = "0.10.36" Functors = "0.4.8" +MLUtils = "0.4" Pkg = "1.10" Random = "1.10" RecursiveArrayTools = "3.8" diff --git a/test/iterator_tests.jl b/test/iterator_tests.jl new file mode 100644 index 0000000..dbb4d7a --- /dev/null +++ b/test/iterator_tests.jl @@ -0,0 +1,106 @@ +using MLDataDevices, MLUtils + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) + +if BACKEND_GROUP == "cuda" || BACKEND_GROUP == "all" + using LuxCUDA +end + +if BACKEND_GROUP == "amdgpu" || BACKEND_GROUP == "all" + using AMDGPU +end + +if BACKEND_GROUP == "metal" || BACKEND_GROUP == "all" + using Metal +end + +if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all" + using oneAPI +end + +DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice] + +freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x) +freed_if_can_be_freed(::Type{CPUDevice}, x) = true +function freed_if_can_be_freed(::Type, x) + try + Array(x) + return false + catch err + err isa ArgumentError && return true + rethrow() + end +end + +@testset "Device Iterator: $(dev_type)" for dev_type in DEVICES + dev = dev_type() + + !MLDataDevices.functional(dev) && continue + + @info "Testing Device Iterator for $(dev)..." 
+ + @testset "Basic Device Iterator" begin + datalist = [rand(10) for _ in 1:10] + + prev_batch = nothing + for data in DeviceIterator(dev, datalist) + prev_batch === nothing || @test freed_if_can_be_freed(prev_batch) + prev_batch = data + @test size(data) == (10,) + @test get_device_type(data) == dev_type + end + end + + @testset "DataLoader: parallel=$parallel" for parallel in (true, false) + X = rand(Float64, 3, 33) + pre = DataLoader(dev(X); batchsize=13, shuffle=false) + post = DataLoader(X; batchsize=13, shuffle=false) |> dev + + for epoch in 1:2 + prev_pre, prev_post = nothing, nothing + for (p, q) in zip(pre, post) + @test get_device_type(p) == dev_type + @test get_device_type(q) == dev_type + @test p ≈ q + + dev_type === CPUDevice && continue + + prev_pre === nothing || @test !freed_if_can_be_freed(prev_pre) + prev_pre = p + + prev_post === nothing || @test freed_if_can_be_freed(prev_post) + prev_post = q + end + end + + Y = rand(Float64, 1, 33) + pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false) + post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false) |> dev + + for epoch in 1:2 + prev_pre, prev_post = nothing, nothing + for (p, q) in zip(pre, post) + @test get_device_type(p.x) == dev_type + @test get_device_type(p.y) == dev_type + @test get_device_type(q.x) == dev_type + @test get_device_type(q.y) == dev_type + @test p.x ≈ q.x + @test p.y ≈ q.y + + dev_type === CPUDevice && continue + + if prev_pre !== nothing + @test !freed_if_can_be_freed(prev_pre.x) + @test !freed_if_can_be_freed(prev_pre.y) + end + prev_pre = p + + if prev_post !== nothing + @test freed_if_can_be_freed(prev_post.x) + @test freed_if_can_be_freed(prev_post.y) + end + prev_post = q + end + end + end +end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 965e818..b5e4cb6 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -11,7 +11,9 @@ import FillArrays, RecursiveArrayTools, SparseArrays, Zygote @test check_no_stale_explicit_imports(MLDataDevices) === nothing @test check_no_self_qualified_accesses(MLDataDevices) === nothing @test check_all_explicit_imports_via_owners(MLDataDevices) === nothing - @test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing - @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing # mostly upstream problems - @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing # mostly upstream problem + @test check_all_qualified_accesses_via_owners( + MLDataDevices; ignore=(:SparseArrays,)) === nothing + # mostly upstream problems + @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing + @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing end diff --git a/test/runtests.jl b/test/runtests.jl index b9fb136..65cc190 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,7 +28,7 @@ end Test.@test true end + @safetestset "Iterator Tests" include("iterator_tests.jl") @safetestset "Misc Tests" include("misc_tests.jl") - @safetestset "QA Tests" include("qa_tests.jl") end
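
Usage sketch (illustrative, not part of the patch above): the snippet below walks the intended flow through the new `dev(::DataLoader)` method and `DeviceIterator`, assuming MLDataDevices 1.1.0 with MLUtils loaded. The dataset shapes, the named-tuple dataset, and the `gpu_device()` fall-back-to-CPU behaviour are assumptions of this example, not changes introduced by the diff.

```julia
using MLDataDevices, MLUtils

# Batches are copied to the device lazily, one iteration at a time; the previous
# batch is marked freeable via `Internal.unsafe_free!` (a no-op on the CPU).
X, Y = rand(Float32, 4, 256), rand(Float32, 1, 256)
loader = DataLoader((; x=X, y=Y); batchsize=32, shuffle=true)

dev = gpu_device()        # functional GPU device, or CPUDevice() as a fallback
dev_loader = dev(loader)  # DeviceIterator wrapping the DataLoader

for epoch in 1:2
    for batch in dev_loader
        # `batch.x` / `batch.y` already live on `dev` at this point
        @assert get_device_type(batch.x) == get_device_type(batch.y)
    end
end

# Any iterator of (Functors-compatible) batches can also be wrapped directly:
manual = DeviceIterator(dev, [rand(Float32, 4, 32) for _ in 1:8])
summary(first(manual))
```

With `parallel=true` on the `DataLoader`, the wrapped iterator goes through the extension's `eachobsparallel` path instead, so device transfer happens on the loader workers rather than in the consuming loop.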