diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml
index 93e1a0eff..acb9f2ec1 100644
--- a/lib/LuxCore/Project.toml
+++ b/lib/LuxCore/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxCore"
 uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
 authors = ["Avik Pal and contributors"]
-version = "1.2.0"
+version = "1.2.1"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -9,6 +9,7 @@ DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [weakdeps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
@@ -25,11 +26,12 @@ LuxCoreArrayInterfaceTrackerExt = ["ArrayInterface", "Tracker"]
 LuxCoreChainRulesCoreExt = "ChainRulesCore"
 LuxCoreEnzymeCoreExt = "EnzymeCore"
 LuxCoreFunctorsExt = "Functors"
-LuxCoreMLDataDevicesExt = "MLDataDevices"
+LuxCoreMLDataDevicesExt = ["Adapt", "MLDataDevices"]
 LuxCoreReactantExt = "Reactant"
 LuxCoreSetfieldExt = "Setfield"
 
 [compat]
+Adapt = "4.1"
 ArrayInterface = "7.9"
 ChainRulesCore = "1.24"
 Compat = "4.16"
diff --git a/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl b/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl
index 1a2dbbd69..14c842eb1 100644
--- a/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl
+++ b/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl
@@ -1,16 +1,16 @@
 module LuxCoreMLDataDevicesExt
 
-using LuxCore: LuxCore
-using MLDataDevices: MLDataDevices
+using Adapt: Adapt
+using LuxCore: LuxCore, AbstractLuxLayer
+using MLDataDevices: MLDataDevices, AbstractDevice
 
-for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
-    ldev = Symbol(dev, :Device)
-    @eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractLuxLayer)
-        @warn "Lux layers are stateless and hence don't participate in device transfers. \
-               Apply this function on the parameters and states generated using \
-               `LuxCore.setup`."
-        return NN
-    end
+MLDataDevices.isleaf(::AbstractLuxLayer) = true
+
+function Adapt.adapt_storage(::AbstractDevice, x::AbstractLuxLayer)
+    @warn "Lux layers are stateless and hence don't participate in device transfers. \
+           Apply this function on the parameters and states generated using \
+           `LuxCore.setup`."
+    return x
 end
 
 end
diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml
index b3a7ebf07..c72c2fe70 100644
--- a/lib/MLDataDevices/Project.toml
+++ b/lib/MLDataDevices/Project.toml
@@ -1,7 +1,7 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal and contributors"]
-version = "1.6.2"
+version = "1.6.3"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
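Note (illustrative, not part of the patch): the LuxCoreMLDataDevicesExt rewrite above replaces five per-backend callable overloads with one `isleaf` mark plus one `Adapt.adapt_storage` method. A minimal sketch of the intended behavior, assuming Lux is loaded for `Dense` and that `gpu_device()` resolves to a functional backend:

using Lux, LuxCore, MLDataDevices, Random

model = Dense(2 => 3)            # any AbstractLuxLayer
dev = gpu_device()

# Layers are stateless: `isleaf` stops fmap recursion at the layer, and the
# new `adapt_storage` method warns and returns the layer unchanged.
dev(model)                       # warning; `model` returned as-is

# Supported path: move the parameters/states from `LuxCore.setup` instead.
ps, st = LuxCore.setup(Random.default_rng(), model)
ps_gpu, st_gpu = dev(ps), dev(st)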
diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl
index 2b230cd87..dcfde2876 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl
@@ -12,7 +12,8 @@ using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type,
 @non_differentiable cpu_device(::Any...)
 @non_differentiable reactant_device(::Any...)
 
-function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray)
+function ChainRulesCore.rrule(
+        ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
     dev = get_device(x)
     y = Adapt.adapt_storage(to, x)
     if dev === nothing || dev isa UnknownDevice
diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
index 25b05c01d..551586440 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
@@ -1,16 +1,13 @@
 module MLDataDevicesChainRulesExt
 
 using Adapt: Adapt
-using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
-                     ReactantDevice
+using MLDataDevices: CPUDevice, AbstractDevice, Internal
 using ChainRules: OneElement
 
 Adapt.adapt_structure(::CPUDevice, x::OneElement) = x
+Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x))
 
-for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice,
-    CUDADevice{Nothing}, AMDGPUDevice{Nothing})
-    # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
-    @eval Adapt.adapt_structure(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x))
-end
+Internal.get_device(::OneElement) = CPUDevice()
+Internal.get_device_type(::OneElement) = CPUDevice
 
 end
diff --git a/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl
index 5a88241e6..2b9c1875c 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl
@@ -1,10 +1,13 @@
 module MLDataDevicesFillArraysExt
 
 using Adapt: Adapt
-using FillArrays: FillArrays, AbstractFill
-using MLDataDevices: MLDataDevices, CPUDevice, AbstractDevice
+using FillArrays: AbstractFill
+using MLDataDevices: CPUDevice, AbstractDevice, Internal
 
 Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x
 Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x))
 
+Internal.get_device(::AbstractFill) = CPUDevice()
+Internal.get_device_type(::AbstractFill) = CPUDevice
+
 end
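Note (illustrative, not part of the patch): the two extensions above give lazy structured arrays (`OneElement`, `AbstractFill`) an explicit `CPUDevice` identity and materialize them via `collect` only when they leave the CPU. A sketch, assuming only the FillArrays weak dependency is loaded:

using FillArrays, MLDataDevices

x = Fill(1.0f0, 4, 4)        # lazy: carries no real storage
get_device(x)                # CPUDevice(), via the new Internal.get_device method
cpu_device()(x)              # returned as-is, stays lazy
# A non-CPU device hits `Adapt.adapt(to, collect(x))`, i.e. the Fill is
# materialized into a dense Array before the transfer.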
diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl
index 2c5b9ced7..2c757b983 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl
@@ -1,34 +1,31 @@
 module MLDataDevicesMLUtilsExt
 
 using Adapt: Adapt
-using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
-                     MetalDevice, oneAPIDevice, ReactantDevice, DeviceIterator
+using MLDataDevices: MLDataDevices, AbstractDevice, DeviceIterator
 using MLUtils: MLUtils, DataLoader
 
 MLDataDevices.isleaf(::DataLoader) = true
 
-for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice)
-    @eval function Adapt.adapt_structure(D::$(dev), dataloader::DataLoader)
-        if dataloader.parallel
-            if dataloader.buffer
-                @warn "Using `buffer=true` for parallel DataLoader with automatic device \
-                       transfer is currently not implemented. Ignoring `buffer=true`."
-            end
-
-            # Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
-            data = MLUtils.ObsView(dataloader.data)
-            data = dataloader.shuffle ? MLUtils.shuffleobs(data) : data
-            data = if dataloader.batchsize > 0
-                MLUtils.BatchView(
-                    data; dataloader.batchsize, dataloader.partial, dataloader.collate)
-            else
-                data
-            end
+function Adapt.adapt_structure(dev::AbstractDevice, dataloader::DataLoader)
+    if dataloader.parallel
+        if dataloader.buffer
+            @warn "Using `buffer=true` for parallel DataLoader with automatic device \
+                   transfer is currently not implemented. Ignoring `buffer=true`."
+        end
 
-            return DeviceIterator(identity, eachobsparallel(D, data))
+        # Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
+        data = MLUtils.ObsView(dataloader.data)
+        data = dataloader.shuffle ? MLUtils.shuffleobs(data) : data
+        data = if dataloader.batchsize > 0
+            MLUtils.BatchView(
+                data; dataloader.batchsize, dataloader.partial, dataloader.collate)
+        else
+            data
         end
-        return DeviceIterator(D, dataloader)
+
+        return DeviceIterator(identity, eachobsparallel(dev, data))
     end
+    return DeviceIterator(dev, dataloader)
 end
 
 function eachobsparallel(dev::AbstractDevice, data)
diff --git a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
index aaf0a4952..4e55940d0 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
@@ -37,7 +37,7 @@ function Adapt.adapt_storage(
         dev::ReactantDevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
     @warn "ReactantDevice got an array on device: $(get_device_type(x)). We will have to \
            transfer this via CPU." maxlog=1
-    return Adapt.adapt_storage(dev, Adapt.adapt_storage(CPUDevice(), x))
+    return Adapt.adapt(dev, Adapt.adapt(CPUDevice(), x))
 end
 
 function Adapt.adapt_storage(dev::ReactantDevice, x::Array{<:Reactant.ReactantPrimitive})
@@ -46,4 +46,12 @@ function Adapt.adapt_storage(dev::ReactantDevice, x::Array{<:Reactant.ReactantPr
     return ConcreteRArray(x; client, device)
 end
 
+# XXX: Check for client and device and use faster implementation if possible
+function Adapt.adapt_storage(dev::ReactantDevice, x::ConcreteRArray)
+    dev.client === missing && dev.device === missing && return x
+    @warn "Fetching `client` and `device` from a ConcreteRArray hasn't been implemented \
+           yet. Using slow fallback path." maxlog=1
+    return Adapt.adapt(dev, Adapt.adapt(CPUDevice(), x))
+end
+
 end
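Note (illustrative, not part of the patch): after the MLUtilsExt change, a single `adapt_structure(::AbstractDevice, ::DataLoader)` method serves every device type. A sketch of the user-facing behavior, assuming MLUtils is loaded and `gpu_device()` resolves to a working backend:

using MLDataDevices, MLUtils

X = rand(Float32, 8, 128)
loader = DataLoader((X,); batchsize=16, shuffle=true)

dev = gpu_device()
for (x,) in dev(loader)      # wrapping the loader yields a DeviceIterator
    # each batch `x` is moved to `dev` lazily, as it is iterated
end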
diff --git a/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl b/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl
index f9b90d9cb..8ea347c26 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl
@@ -1,7 +1,7 @@
 module MLDataDevicesTrackerExt
 
 using Adapt: Adapt
-using MLDataDevices: Internal, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice
+using MLDataDevices: Internal, AbstractDevice
 using Tracker: Tracker
 
 for op in (:get_device, :get_device_type)
@@ -11,13 +11,10 @@ end
 
 Internal.special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true
 
-for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
-    CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
-    @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal})
-        @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \
-               to Tracker.TrackedArray." maxlog=1
-        return to(Tracker.collect(x))
-    end
+function Adapt.adapt_structure(to::AbstractDevice, x::AbstractArray{<:Tracker.TrackedReal})
+    @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting to \
+           Tracker.TrackedArray." maxlog=1
+    return Adapt.adapt(to, Tracker.collect(x))
 end
 
 end
diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
index 66a363d55..14f4de128 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
@@ -1,16 +1,13 @@
 module MLDataDevicesZygoteExt
 
 using Adapt: Adapt
-using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
-                     ReactantDevice
+using MLDataDevices: CPUDevice, AbstractDevice, Internal
 using Zygote: OneElement
 
 Adapt.adapt_structure(::CPUDevice, x::OneElement) = x
+Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x))
 
-for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice,
-    CUDADevice{Nothing}, AMDGPUDevice{Nothing})
-    # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
-    @eval Adapt.adapt_structure(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x))
-end
+Internal.get_device(::OneElement) = CPUDevice()
+Internal.get_device_type(::OneElement) = CPUDevice
 
 end
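Note (illustrative, not part of the patch): the TrackerExt method above now funnels every device through one `adapt_structure` definition, fusing an array of `TrackedReal` into a `TrackedArray` via `Tracker.collect` before the transfer. A rough sketch, assuming a working GPU backend:

using MLDataDevices, Tracker

x = Tracker.param.(rand(Float32, 4))   # Vector{<:Tracker.TrackedReal}
y = gpu_device()(x)                     # warns once (maxlog=1), then transfers
# `x` is first converted with `Tracker.collect` to a TrackedArray, so the move
# happens as one array transfer rather than element-by-element.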
diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl
index 8b9c770f7..2b939b153 100644
--- a/lib/MLDataDevices/src/internal.jl
+++ b/lib/MLDataDevices/src/internal.jl
@@ -104,8 +104,6 @@ end
 
 special_aos(::AbstractArray) = false
 
-recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number)
-
 combine_devices(::Nothing, ::Nothing) = nothing
 combine_devices(::Nothing, dev::AbstractDevice) = dev
 combine_devices(dev::AbstractDevice, ::Nothing) = dev
@@ -137,7 +135,7 @@ for op in (:get_device, :get_device_type)
 
     @eval begin
         function $(op)(x::AbstractArray{T}) where {T}
-            if recursive_array_eltype(T)
+            if !isbitstype(T) && !(T <: Number)
                 is_assigned_idxs = findall(Base.Fix1(isassigned, x), eachindex(x))
                 if length(is_assigned_idxs) == 0
                     @warn $(all_not_assigned_msg)
@@ -158,15 +156,12 @@ for op in (:get_device, :get_device_type)
 
         function $(op)(x::Union{Tuple, NamedTuple})
             length(x) == 0 && return $(op == :get_device ? nothing : Nothing)
+            # NOTE: We need unrolled_mapreduce for julia 1.10 to ensure type stability
            return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x))
         end
 
-        function $(op)(f::F) where {F <: Function}
-            Base.issingletontype(F) &&
-                return $(op == :get_device ? UnknownDevice() : UnknownDevice)
-            return unrolled_mapreduce(MLDataDevices.$(op), combine_devices,
-                map(Base.Fix1(getfield, f), fieldnames(F)))
-        end
+        # NOTE: Don't mark as fast_structure
+        $(op)(::Function) = $(op == :get_device ? UnknownDevice() : UnknownDevice)
     end
 
     for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@@ -182,7 +177,6 @@ fast_structure(::Union{Tuple, NamedTuple}) = true
 for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
     @eval fast_structure(::$(T)) = true
 end
-fast_structure(::Function) = true
 fast_structure(_) = false
 
 function unrolled_mapreduce(f::F, op::O, itr) where {F, O}
diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl
index 13dcaba29..6b87ef422 100644
--- a/lib/MLDataDevices/src/public.jl
+++ b/lib/MLDataDevices/src/public.jl
@@ -363,9 +363,7 @@ function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDe
 end
 
 # Dispatches for Different Data Structures
-for dev in (CPUDevice, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice)
-    @eval (D::$(dev))(x) = Functors.fmap(Base.Fix1(Adapt.adapt, D), x; exclude=isleaf)
-end
+(D::AbstractDevice)(x) = Functors.fmap(Base.Fix1(Adapt.adapt, D), x; exclude=isleaf)
 
 for op in (:get_device, :get_device_type)
     @eval function $(op)(x)
@@ -377,15 +375,9 @@ for op in (:get_device, :get_device_type)
 end
 
 # Adapt Interface
-Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x)
-Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng
-
-for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice)
-    @eval begin
-        Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) = default_device_rng(to)
-        Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng
-    end
-end
+Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Array(x)
+Adapt.adapt_storage(to::AbstractDevice, ::Random.TaskLocalRNG) = default_device_rng(to)
+Adapt.adapt_storage(::AbstractDevice, rng::AbstractRNG) = rng
 
 """
     isleaf(x) -> Bool
@@ -398,7 +390,7 @@ when an object with nested structure containing the type is transferred to a dev
 
 `Adapt.adapt_structure(::AbstractDevice, x::T)` or
 `Adapt.adapt_structure(::AbstractDevice, x::T)` will be called during
-data movement if `isleaf(x::T) == true`.
+data movement if `isleaf(x::T)`.
 
 If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to
 `Functors.isleaf(x)`.
diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl
index 65f63c9a9..148d758b9 100644
--- a/lib/MLDataDevices/test/misc_tests.jl
+++ b/lib/MLDataDevices/test/misc_tests.jl
@@ -52,7 +52,7 @@ end
 
 @testset "CRC Tests" begin
     dev = cpu_device() # Other devices don't work with FiniteDifferences.jl
-    test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=false)
+    test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=false)
 
     gdev = gpu_device()
     if !(gdev isa MetalDevice) # On intel devices causes problems
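Note (illustrative, not part of the patch): with the public.jl consolidation above, every `AbstractDevice` shares one callable-struct method and one set of `adapt_storage` rules, and `get_device` on functions now short-circuits to `UnknownDevice` instead of traversing captured fields. A sketch with the CPU device (other devices assumed analogous):

using MLDataDevices

dev = cpu_device()
data = (; weight=rand(Float32, 3, 3), activation=tanh)

moved = dev(data)               # single `(D::AbstractDevice)(x)` fmap-based method
get_device(moved.weight)        # CPUDevice()
get_device(tanh)                # UnknownDevice(): functions are no longer traversed
get_device_type(tanh)           # UnknownDevice (the type)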