diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 9fccfa7d6..f087936bd 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.6.0" +version = "1.6.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl index d4d106457..2c5b9ced7 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl @@ -1,11 +1,14 @@ module MLDataDevicesMLUtilsExt +using Adapt: Adapt using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, DeviceIterator using MLUtils: MLUtils, DataLoader +MLDataDevices.isleaf(::DataLoader) = true + for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice) - @eval function (D::$(dev))(dataloader::DataLoader) + @eval function Adapt.adapt_structure(D::$(dev), dataloader::DataLoader) if dataloader.parallel if dataloader.buffer @warn "Using `buffer=true` for parallel DataLoader with automatic device \