Skip to content

Commit

Permalink
fix: dataloaders use adapt_structure
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 15, 2024
1 parent bbf5033 commit edf52e9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.6.0"
version = "1.6.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
5 changes: 4 additions & 1 deletion lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
module MLDataDevicesMLUtilsExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, ReactantDevice, DeviceIterator
using MLUtils: MLUtils, DataLoader

MLDataDevices.isleaf(::DataLoader) = true

for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice)
@eval function (D::$(dev))(dataloader::DataLoader)
@eval function Adapt.adapt_structure(D::$(dev), dataloader::DataLoader)
if dataloader.parallel
if dataloader.buffer
@warn "Using `buffer=true` for parallel DataLoader with automatic device \
Expand Down

0 comments on commit edf52e9

Please sign in to comment.