Skip to content

Commit

Permalink
Allow cpu(::DataLoader) (#2388)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Mar 5, 2024
1 parent 7e7d4fc commit da11bf2
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
15 changes: 14 additions & 1 deletion src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,9 @@ function _metal end

"""
gpu(data::DataLoader)
cpu(data::DataLoader)
Transforms a given `DataLoader` to apply `gpu` to each batch of data,
Transforms a given `DataLoader` to apply `gpu` or `cpu` to each batch of data,
when iterated over. (If no GPU is available, this does nothing.)
# Example
Expand Down Expand Up @@ -456,6 +457,18 @@ function gpu(d::MLUtils.DataLoader)
)
end

function cpu(d::MLUtils.DataLoader)
MLUtils.DataLoader(MLUtils.mapobs(cpu, d.data),
d.batchsize,
d.buffer,
d.partial,
d.shuffle,
d.parallel,
d.collate,
d.rng,
)
end

# Defining device interfaces.
"""
Flux.AbstractDevice <: Function
Expand Down
6 changes: 6 additions & 0 deletions test/data.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Flux: DataLoader
using Random

@testset "DataLoader" begin
Expand All @@ -14,6 +15,11 @@ using Random
@test batches[2] == X[:,3:4]
@test batches[3] == X[:,5:5]

d_cpu = d |> cpu # does nothing but shouldn't error
@test d_cpu isa DataLoader
@test first(d_cpu) == X[:,1:2]
@test length(d_cpu) == 3

d = DataLoader(X, batchsize=2, partial=false)
# @inferred first(d)
batches = collect(d)
Expand Down
5 changes: 4 additions & 1 deletion test/ext_cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,14 @@ end
X = randn(Float64, 3, 33)
pre1 = Flux.DataLoader(X |> gpu; batchsize=13, shuffle=false)
post1 = Flux.DataLoader(X; batchsize=13, shuffle=false) |> gpu
rev1 = pre1 |> cpu # inverse operation
for epoch in 1:2
for (p, q) in zip(pre1, post1)
for (p, q, a) in zip(pre1, post1, rev1)
@test p isa CuArray{Float32}
@test q isa CuArray{Float32}
@test p q
@test a isa Array{Float32}
@test a Array(p)
end
end

Expand Down

0 comments on commit da11bf2

Please sign in to comment.