fix: use the correct dispatches for device overloads (#1118)
* fix: use the correct dispatches for device overloads

* fix: correct dispatches for adapt_structure

* refactor: simplify get_device(_type) code

* fix: partial revert of the ancestor patch

* fix: restore unrolled_mapreduce for julia 1.10
avik-pal authored Dec 5, 2024
1 parent 5ad4fa9 commit fd7b740
Showing 13 changed files with 72 additions and 84 deletions.
6 changes: 4 additions & 2 deletions lib/LuxCore/Project.toml
@@ -1,14 +1,15 @@
 name = "LuxCore"
 uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.2.0"
+version = "1.2.1"

 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

 [weakdeps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
@@ -25,11 +26,12 @@ LuxCoreArrayInterfaceTrackerExt = ["ArrayInterface", "Tracker"]
 LuxCoreChainRulesCoreExt = "ChainRulesCore"
 LuxCoreEnzymeCoreExt = "EnzymeCore"
 LuxCoreFunctorsExt = "Functors"
-LuxCoreMLDataDevicesExt = "MLDataDevices"
+LuxCoreMLDataDevicesExt = ["Adapt", "MLDataDevices"]
 LuxCoreReactantExt = "Reactant"
 LuxCoreSetfieldExt = "Setfield"

 [compat]
+Adapt = "4.1"
 ArrayInterface = "7.9"
 ChainRulesCore = "1.24"
 Compat = "4.16"
20 changes: 10 additions & 10 deletions lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl
@@ -1,16 +1,16 @@
 module LuxCoreMLDataDevicesExt

-using LuxCore: LuxCore
-using MLDataDevices: MLDataDevices
+using Adapt: Adapt
+using LuxCore: LuxCore, AbstractLuxLayer
+using MLDataDevices: MLDataDevices, AbstractDevice

-for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
-    ldev = Symbol(dev, :Device)
-    @eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractLuxLayer)
-        @warn "Lux layers are stateless and hence don't participate in device transfers. \
-               Apply this function on the parameters and states generated using \
-               `LuxCore.setup`."
-        return NN
-    end
-end
+MLDataDevices.isleaf(::AbstractLuxLayer) = true
+
+function Adapt.adapt_storage(::AbstractDevice, x::AbstractLuxLayer)
+    @warn "Lux layers are stateless and hence don't participate in device transfers. \
+           Apply this function on the parameters and states generated using \
+           `LuxCore.setup`."
+    return x
+end

 end
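
For illustration, here is the user-visible effect of this change as a minimal sketch (the `MyLayer` type and its parameter shape are hypothetical, not from the commit). Since `isleaf(::AbstractLuxLayer) = true`, the new `adapt_storage` method fires once per layer even when the layer is nested inside a larger structure:

```julia
using LuxCore, MLDataDevices, Random

struct MyLayer <: LuxCore.AbstractLuxLayer end
LuxCore.initialparameters(::AbstractRNG, ::MyLayer) = (; weight=rand(Float32, 4, 4))

ps, st = LuxCore.setup(Random.default_rng(), MyLayer())

dev = cpu_device()
dev(MyLayer())      # warns: layers are stateless, returned unchanged
ps_on_dev = dev(ps) # the parameters are what actually move between devices
```
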
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
@@ -1,7 +1,7 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.6.2"
+version = "1.6.3"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3 changes: 2 additions & 1 deletion lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl
@@ -12,7 +12,8 @@ using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type,
 @non_differentiable cpu_device(::Any...)
 @non_differentiable reactant_device(::Any...)

-function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray)
+function ChainRulesCore.rrule(
+        ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
     dev = get_device(x)
     y = Adapt.adapt_storage(to, x)
     if dev === nothing || dev isa UnknownDevice
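
Because the device functor now moves leaves via `Adapt.adapt`, which bottoms out in `Adapt.adapt_storage`, the differentiation rule must target `adapt_storage` to fire on every actual transfer. A standalone sketch of the rule's shape follows; the pullback body is an assumption based on the visible lines (the real method also handles an unknown source device, per the `if` above), not a verbatim copy of the hidden ones:

```julia
using Adapt, ChainRulesCore
using MLDataDevices: AbstractDevice, get_device

# Standalone sketch, not the package's exact method.
function adapt_storage_rrule_sketch(to::AbstractDevice, x::AbstractArray)
    dev = get_device(x)              # remember where x came from
    y = Adapt.adapt_storage(to, x)   # forward: move x to `to`
    function pullback(Δ)
        # send the cotangent back to the source device
        return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(),
               dev(ChainRulesCore.unthunk(Δ))
    end
    return y, pullback
end
```
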
11 changes: 4 additions & 7 deletions lib/MLDataDevices/ext/MLDataDevicesChainRulesExt.jl
@@ -1,16 +1,13 @@
 module MLDataDevicesChainRulesExt

 using Adapt: Adapt
-using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
-                     ReactantDevice
+using MLDataDevices: CPUDevice, AbstractDevice, Internal
 using ChainRules: OneElement

 Adapt.adapt_structure(::CPUDevice, x::OneElement) = x
+Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x))

-for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice,
-    CUDADevice{Nothing}, AMDGPUDevice{Nothing})
-    # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
-    @eval Adapt.adapt_structure(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x))
-end
+Internal.get_device(::OneElement) = CPUDevice()
+Internal.get_device_type(::OneElement) = CPUDevice

 end
7 changes: 5 additions & 2 deletions lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl
@@ -1,10 +1,13 @@
 module MLDataDevicesFillArraysExt

 using Adapt: Adapt
-using FillArrays: FillArrays, AbstractFill
-using MLDataDevices: MLDataDevices, CPUDevice, AbstractDevice
+using FillArrays: AbstractFill
+using MLDataDevices: CPUDevice, AbstractDevice, Internal

 Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x
 Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x))
+
+Internal.get_device(::AbstractFill) = CPUDevice()
+Internal.get_device_type(::AbstractFill) = CPUDevice

 end
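
`AbstractFill` arrays are lazy and carry no real storage, so they live on the CPU by definition; the new `Internal` overloads let `get_device` report that directly instead of inspecting elements. A small usage sketch, assuming only FillArrays and MLDataDevices are loaded:

```julia
using FillArrays, MLDataDevices

x = Fill(1.0f0, 32, 32)  # lazy constant array, no real storage
get_device(x)            # CPUDevice(), via the new Internal overloads
cpu_device()(x) === x    # true: a CPU "transfer" is a no-op
# A GPU device would materialize first: Adapt.adapt(to, collect(x))
```
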
39 changes: 18 additions & 21 deletions lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl
@@ -1,34 +1,31 @@
 module MLDataDevicesMLUtilsExt

 using Adapt: Adapt
-using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
-                     MetalDevice, oneAPIDevice, ReactantDevice, DeviceIterator
+using MLDataDevices: MLDataDevices, AbstractDevice, DeviceIterator
 using MLUtils: MLUtils, DataLoader

 MLDataDevices.isleaf(::DataLoader) = true

-for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice)
-    @eval function Adapt.adapt_structure(D::$(dev), dataloader::DataLoader)
-        if dataloader.parallel
-            if dataloader.buffer
-                @warn "Using `buffer=true` for parallel DataLoader with automatic device \
-                       transfer is currently not implemented. Ignoring `buffer=true`."
-            end
-
-            # Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
-            data = MLUtils.ObsView(dataloader.data)
-            data = dataloader.shuffle ? MLUtils.shuffleobs(data) : data
-            data = if dataloader.batchsize > 0
-                MLUtils.BatchView(
-                    data; dataloader.batchsize, dataloader.partial, dataloader.collate)
-            else
-                data
-            end
-
-            return DeviceIterator(identity, eachobsparallel(D, data))
-        end
-        return DeviceIterator(D, dataloader)
-    end
-end
+function Adapt.adapt_structure(dev::AbstractDevice, dataloader::DataLoader)
+    if dataloader.parallel
+        if dataloader.buffer
+            @warn "Using `buffer=true` for parallel DataLoader with automatic device \
+                   transfer is currently not implemented. Ignoring `buffer=true`."
+        end
+
+        # Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
+        data = MLUtils.ObsView(dataloader.data)
+        data = dataloader.shuffle ? MLUtils.shuffleobs(data) : data
+        data = if dataloader.batchsize > 0
+            MLUtils.BatchView(
+                data; dataloader.batchsize, dataloader.partial, dataloader.collate)
+        else
+            data
+        end
+
+        return DeviceIterator(identity, eachobsparallel(dev, data))
+    end
+    return DeviceIterator(dev, dataloader)
+end

 function eachobsparallel(dev::AbstractDevice, data)
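
The collapsed method means any `AbstractDevice` can now wrap a `DataLoader`, not just the six devices listed in the old `@eval` loop. A usage sketch (a CPU device is used here so it runs anywhere):

```julia
using MLUtils, MLDataDevices

X, y = rand(Float32, 2, 100), rand(Float32, 100)
loader = DataLoader((X, y); batchsize=16)

dev = cpu_device()
for (xb, yb) in dev(loader)  # DeviceIterator moves each batch as it is produced
    # train on (xb, yb) ...
end
```
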
10 changes: 9 additions & 1 deletion lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
@@ -37,7 +37,7 @@ function Adapt.adapt_storage(
     dev::ReactantDevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
     @warn "ReactantDevice got an array on device: $(get_device_type(x)). We will have to \
            transfer this via CPU." maxlog=1
-    return Adapt.adapt_storage(dev, Adapt.adapt_storage(CPUDevice(), x))
+    return Adapt.adapt(dev, Adapt.adapt(CPUDevice(), x))
 end

 function Adapt.adapt_storage(dev::ReactantDevice, x::Array{<:Reactant.ReactantPrimitive})
@@ -46,4 +46,12 @@ function Adapt.adapt_storage(dev::ReactantDevice, x::Array{<:Reactant.ReactantPr
     return ConcreteRArray(x; client, device)
 end

+# XXX: Check for client and device and use faster implementation if possible
+function Adapt.adapt_storage(dev::ReactantDevice, x::ConcreteRArray)
+    dev.client === missing && dev.device === missing && return x
+    @warn "Fetching `client` and `device` from a ConcreteRArray hasn't been implemented \
+           yet. Using slow fallback path." maxlog=1
+    return Adapt.adapt(dev, Adapt.adapt(CPUDevice(), x))
+end
+
 end
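
The new method adds a fast path for arrays that are already `ConcreteRArray`s: when the `ReactantDevice` does not pin a specific `client`/`device`, the array is returned untouched; otherwise the transfer round-trips through the CPU. A sketch, assuming Reactant is installed and that `reactant_device()` leaves both fields as `missing` (its typical default):

```julia
using Reactant, MLDataDevices

rdev = reactant_device()
x = rdev(rand(Float32, 4))  # Array -> ConcreteRArray via the adapt_storage above
rdev(x) === x               # true when rdev.client and rdev.device are missing
```
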
13 changes: 5 additions & 8 deletions lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl
@@ -1,7 +1,7 @@
 module MLDataDevicesTrackerExt

 using Adapt: Adapt
-using MLDataDevices: Internal, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice
+using MLDataDevices: Internal, AbstractDevice
 using Tracker: Tracker

 for op in (:get_device, :get_device_type)
@@ -11,13 +11,10 @@ end

 Internal.special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true

-for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
-    CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
-    @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal})
-        @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \
-               to Tracker.TrackedArray." maxlog=1
-        return to(Tracker.collect(x))
-    end
+function Adapt.adapt_structure(to::AbstractDevice, x::AbstractArray{<:Tracker.TrackedReal})
+    @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting to \
+           Tracker.TrackedArray." maxlog=1
+    return Adapt.adapt(to, Tracker.collect(x))
 end

 end
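
An `Array` of `TrackedReal`s is an array-of-structs layout that no backend can store natively, so the extension now converts it to a `TrackedArray` via `Tracker.collect` before adapting, on any device rather than only on the GPU devices of the old loop. A sketch:

```julia
using Tracker, MLDataDevices

x = Tracker.param.(rand(Float32, 3))  # Vector of TrackedReal (array-of-structs)
y = cpu_device()(x)  # warns once, then adapts Tracker.collect(x) instead
```
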
11 changes: 4 additions & 7 deletions lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl
@@ -1,16 +1,13 @@
 module MLDataDevicesZygoteExt

 using Adapt: Adapt
-using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
-                     ReactantDevice
+using MLDataDevices: CPUDevice, AbstractDevice, Internal
 using Zygote: OneElement

 Adapt.adapt_structure(::CPUDevice, x::OneElement) = x
+Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x))

-for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice,
-    CUDADevice{Nothing}, AMDGPUDevice{Nothing})
-    # use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray)
-    @eval Adapt.adapt_structure(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x))
-end
+Internal.get_device(::OneElement) = CPUDevice()
+Internal.get_device_type(::OneElement) = CPUDevice

 end
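
This mirrors the ChainRules extension above: `OneElement` is the lazy one-hot array Zygote produces as the gradient of scalar indexing. It stays lazy on the CPU, is densified with `collect` before moving anywhere else, and now reports itself as CPU-resident. A sketch:

```julia
using Zygote, MLDataDevices

g = Zygote.gradient(x -> x[3], rand(Float32, 5))[1]  # a Zygote.OneElement
get_device(g)          # CPUDevice(), via the new Internal overloads
cpu_device()(g) === g  # true: stays lazy on the CPU
# Non-CPU devices densify first: Adapt.adapt(to, collect(g))
```
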
14 changes: 4 additions & 10 deletions lib/MLDataDevices/src/internal.jl
@@ -104,8 +104,6 @@ end

 special_aos(::AbstractArray) = false

-recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number)
-
 combine_devices(::Nothing, ::Nothing) = nothing
 combine_devices(::Nothing, dev::AbstractDevice) = dev
 combine_devices(dev::AbstractDevice, ::Nothing) = dev
@@ -137,7 +135,7 @@ for op in (:get_device, :get_device_type)

     @eval begin
         function $(op)(x::AbstractArray{T}) where {T}
-            if recursive_array_eltype(T)
+            if !isbitstype(T) && !(T <: Number)
                 is_assigned_idxs = findall(Base.Fix1(isassigned, x), eachindex(x))
                 if length(is_assigned_idxs) == 0
                     @warn $(all_not_assigned_msg)
@@ -158,15 +156,12 @@
         function $(op)(x::Union{Tuple, NamedTuple})
             length(x) == 0 && return $(op == :get_device ? nothing : Nothing)
+            # NOTE: We need unrolled_mapreduce for julia 1.10 to ensure type stability
             return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x))
         end

-        function $(op)(f::F) where {F <: Function}
-            Base.issingletontype(F) &&
-                return $(op == :get_device ? UnknownDevice() : UnknownDevice)
-            return unrolled_mapreduce(MLDataDevices.$(op), combine_devices,
-                map(Base.Fix1(getfield, f), fieldnames(F)))
-        end
+        # NOTE: Don't mark as fast_structure
+        $(op)(::Function) = $(op == :get_device ? UnknownDevice() : UnknownDevice)
     end

     for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@@ -182,7 +177,6 @@ fast_structure(::Union{Tuple, NamedTuple}) = true
 for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
     @eval fast_structure(::$(T)) = true
 end
-fast_structure(::Function) = true
 fast_structure(_) = false

 function unrolled_mapreduce(f::F, op::O, itr) where {F, O}
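
The restored comment is the key point: on Julia 1.10, `mapreduce` over a heterogeneous tuple may not infer a concrete return type, whereas a recursive, compile-time-unrolled reduction does. A simplified sketch of the idea (the package's actual `unrolled_mapreduce` is only partially visible above, so this is an illustration, not a copy):

```julia
# Recursion over the tuple's elements unrolls at compile time, so the result
# type of each step is known to the compiler even on Julia 1.10.
unrolled_mr(f, op, itr::Tuple) = _unrolled_mr(f, op, itr...)
_unrolled_mr(f, op, x) = f(x)
_unrolled_mr(f, op, x, xs...) = op(f(x), _unrolled_mr(f, op, xs...))

# e.g. get_device((; w=rand(2), b=rand(3))) effectively becomes
# combine_devices(get_device(w), get_device(b)), fully unrolled.
```
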
18 changes: 5 additions & 13 deletions lib/MLDataDevices/src/public.jl
@@ -363,9 +363,7 @@ function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDe
 end

 # Dispatches for Different Data Structures
-for dev in (CPUDevice, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice)
-    @eval (D::$(dev))(x) = Functors.fmap(Base.Fix1(Adapt.adapt, D), x; exclude=isleaf)
-end
+(D::AbstractDevice)(x) = Functors.fmap(Base.Fix1(Adapt.adapt, D), x; exclude=isleaf)

 for op in (:get_device, :get_device_type)
     @eval function $(op)(x)
@@ -377,15 +375,9 @@ for op in (:get_device, :get_device_type)
 end

 # Adapt Interface
-Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x)
-Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng
-
-for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice)
-    @eval begin
-        Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) = default_device_rng(to)
-        Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng
-    end
-end
+Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Array(x)
+Adapt.adapt_storage(to::AbstractDevice, ::Random.TaskLocalRNG) = default_device_rng(to)
+Adapt.adapt_storage(::AbstractDevice, rng::AbstractRNG) = rng

 """
     isleaf(x) -> Bool
@@ -398,7 +390,7 @@ when an object with nested structure containing the type is transferred to a dev

 `Adapt.adapt_structure(::AbstractDevice, x::T)` or
 `Adapt.adapt_structure(::AbstractDevice, x::T)` will be called during
-data movement if `isleaf(x::T) == true`.
+data movement if `isleaf(x::T)`.

 If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`.
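
This is the heart of the commit: one functor method on `AbstractDevice` replaces six nearly identical `@eval`-generated ones, and the RNG/array `adapt_storage` methods collapse the same way. Downstream, every device then behaves uniformly; a short usage sketch:

```julia
using MLDataDevices

ps = (; weight=rand(Float32, 4, 4), bias=zeros(Float32, 4))
dev = gpu_device()  # falls back to CPUDevice() with a warning if no GPU is set up
ps_dev = dev(ps)    # Functors.fmap walks the structure; Adapt.adapt moves leaves
get_device(ps_dev)  # recovers the device from the leaves
```
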
2 changes: 1 addition & 1 deletion lib/MLDataDevices/test/misc_tests.jl
@@ -52,7 +52,7 @@ end

 @testset "CRC Tests" begin
     dev = cpu_device() # Other devices don't work with FiniteDifferences.jl
-    test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=false)
+    test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=false)

     gdev = gpu_device()
     if !(gdev isa MetalDevice) # On intel devices causes problems
