feat: allow specifying sharding in ReactantDevice (#1254)
* feat: allow specifying sharding in ReactantDevice

* fix: unwrap shardinfo

* chore: allow new GPUArraysCore

* fix: unthunk

* fix: simplify the logic
avik-pal authored Mar 6, 2025
1 parent 1af6555 commit e3a72b5
Showing 6 changed files with 64 additions and 26 deletions.
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.7.0"
version = "1.8.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
5 changes: 4 additions & 1 deletion lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl
@@ -23,7 +23,10 @@ function ChainRulesCore.rrule(
return y, ∇adapt_storage_unknown
else
∇adapt_storage = let dev = dev, x = x
-            Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(dev(Δ)))
+            Δ -> begin
+                ∂x = ChainRulesCore.@thunk ProjectTo(x)(dev(ChainRulesCore.unthunk(Δ)))
+                return NoTangent(), NoTangent(), ∂x
+            end
end
return Adapt.adapt_storage(to, x), ∇adapt_storage
end
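Note: the rewritten pullback above follows the standard ChainRulesCore thunking pattern, unthunking the incoming cotangent and wrapping the projected gradient in `@thunk` so it is only materialized if actually consumed. A minimal, self-contained sketch of the same pattern, with a hypothetical `double` function standing in for the device transfer:

using ChainRulesCore

double(x) = 2 .* x

function ChainRulesCore.rrule(::typeof(double), x)
    y = double(x)
    ∇double = Δ -> begin
        # Unthunk the incoming cotangent, then delay the projection behind
        # @thunk so it only runs if the caller actually uses this gradient.
        ∂x = ChainRulesCore.@thunk ProjectTo(x)(2 .* ChainRulesCore.unthunk(Δ))
        return NoTangent(), ∂x
    end
    return y, ∇double
end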
50 changes: 38 additions & 12 deletions lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
@@ -3,30 +3,43 @@ module MLDataDevicesReactantExt
using Adapt: Adapt
using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice
using Random: Random
-using Reactant: Reactant, XLA, ConcreteRArray, ConcreteRNumber, TracedRArray,
-                TracedRNumber
+using Reactant: Reactant, ConcreteRArray, ConcreteRNumber, TracedRArray, TracedRNumber

MLDataDevices.loaded(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true
MLDataDevices.functional(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true

# Default RNG
-function MLDataDevices.default_device_rng(::ReactantDevice)
-    return Reactant.TracedRandom.default_rng()
-end
+MLDataDevices.default_device_rng(::ReactantDevice) = Reactant.TracedRandom.default_rng()

# Query Device from Array
-function Internal.get_device(x::Union{ConcreteRNumber, ConcreteRArray})
-    return ReactantDevice(XLA.client(x), XLA.device(x))
+@static if isdefined(Reactant, :ConcreteIFRTArray)
+    const AllConcreteTypes = Union{
+        Reactant.ConcreteIFRTNumber, Reactant.ConcreteIFRTArray,
+        Reactant.ConcretePJRTNumber, Reactant.ConcretePJRTArray
+    }
+elseif isdefined(Reactant, :ConcretePJRTArray)
+    const AllConcreteTypes = Union{
+        Reactant.ConcretePJRTNumber, Reactant.ConcretePJRTArray
+    }
+else
+    const AllConcreteTypes = Union{ConcreteRNumber, ConcreteRArray}
end

-function Internal.get_device(::Union{TracedRArray, TracedRNumber})
-    error("`get_device` isn't meant to be called inside `Reactant.@compile` context.")
+function Internal.get_device(x::AllConcreteTypes)
+    return ReactantDevice(
+        Reactant.XLA.client(x),
+        Reactant.XLA.device(x),
+        IdDict{AllConcreteTypes, Reactant.Sharding.AbstractSharding}(
+            x => x.sharding.sharding
+        )
+    )
 end
+Internal.get_device_type(::AllConcreteTypes) = ReactantDevice

-function Internal.get_device_type(
-        ::Union{TracedRArray, TracedRNumber, ConcreteRArray, ConcreteRNumber})
-    return ReactantDevice
+function Internal.get_device(::Union{TracedRArray, TracedRNumber})
+    error("`get_device` isn't meant to be called inside `Reactant.@compile` context.")
 end
+Internal.get_device_type(::Union{TracedRArray, TracedRNumber}) = ReactantDevice

# unsafe_free!
Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothing
@@ -36,6 +36,19 @@ function Adapt.adapt_storage(dev::ReactantDevice, x::AbstractArray)
kwargs = (;)
dev.client === missing || (kwargs = (; kwargs..., client=dev.client))
dev.device === missing || (kwargs = (; kwargs..., device=dev.device))
+    if dev.sharding !== missing
+        if dev.sharding isa IdDict
+            sharding = dev.sharding[x]
+            @assert sharding isa Reactant.Sharding.AbstractSharding
+            kwargs = (; kwargs..., sharding)
+        elseif dev.sharding isa Reactant.Sharding.AbstractSharding
+            kwargs = (; kwargs..., dev.sharding)
+        else
+            throw(ArgumentError("`sharding` must be an `IdDict` or a \
+                                 `Reactant.Sharding.AbstractSharding` but got \
+                                 $(typeof(dev.sharding))."))
+        end
+    end
return ConcreteRArray(x; kwargs...)
end

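Note: the branching above is the core of the feature: a single `AbstractSharding` applies to every array, an `IdDict` maps individual leaves to their own sharding, and anything else is rejected. A standalone sketch of that selection logic (plain Julia, no Reactant required; the symbols stand in for `AbstractSharding` values):

# Resolve the `sharding` keyword for one array, mirroring the dispatch above.
function sharding_kwargs(sharding, x)
    sharding === missing && return (;)                       # no sharding requested
    sharding isa IdDict && return (; sharding=sharding[x])   # per-leaf lookup
    return (; sharding)                                      # one strategy for everything
end

a, b = rand(2), rand(3)
per_leaf = IdDict{Any, Symbol}(a => :replicated, b => :row_sharded)
@assert sharding_kwargs(per_leaf, a) == (; sharding=:replicated)
@assert sharding_kwargs(:row_sharded, b) == (; sharding=:row_sharded)
@assert sharding_kwargs(missing, a) == (;)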
9 changes: 8 additions & 1 deletion lib/MLDataDevices/src/internal.jl
@@ -129,7 +129,14 @@
combine_devices(dev::ReactantDevice, ::AbstractDevice) = dev
combine_devices(::AbstractDevice, dev::ReactantDevice) = dev
function combine_devices(dev1::ReactantDevice, dev2::ReactantDevice)
-    dev1 == dev2 && return dev1
+    if dev1 == dev2
+        # `merge(...)` of `IdDict` constructs a `Dict`
+        sharding = dev1.sharding
+        for (k, v) in dev2.sharding
+            sharding[k] = v
+        end
+        return ReactantDevice(dev1.client, dev1.device, sharding)
+    end
throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2)."))
end
combine_devices(::Type{ReactantDevice}, ::Type{UnknownDevice}) = ReactantDevice
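Note: the hand-rolled loop above exists because, as the comment says, Base's generic `merge` on `IdDict`s constructs a `Dict`, silently trading identity-based keys for value-based ones. A quick demonstration of the pitfall in plain Julia:

a, b = zeros(2), zeros(2)    # a == b, but a !== b
d1 = IdDict{Any, Symbol}(a => :left)
d2 = IdDict{Any, Symbol}(b => :right)

merged = merge(d1, d2)       # Dict{Any, Symbol}, not an IdDict
length(merged) == 1          # true: `a` and `b` collide because a == b

# The in-place loop keeps identity semantics and retains both entries:
for (k, v) in d2
    d1[k] = v
end
length(d1) == 2              # true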
22 changes: 12 additions & 10 deletions lib/MLDataDevices/src/public.jl
@@ -9,10 +9,10 @@ end
struct MetalDevice <: AbstractGPUDevice end
struct oneAPIDevice <: AbstractGPUDevice end

-# TODO: We need to rethink how to store the client and device. what about sharding??
-@kwdef struct ReactantDevice{C, D} <: AbstractAcceleratorDevice
+@kwdef struct ReactantDevice{C, D, S} <: AbstractAcceleratorDevice
client::C = missing
device::D = missing
+    sharding::S = missing
end

function Base.:(==)(x::ReactantDevice, y::ReactantDevice)
@@ -212,24 +212,26 @@ cpu_device() = CPUDevice()

"""
reactant_device(;
force::Bool=false, client=missing, device=missing
force::Bool=false, client=missing, device=missing, sharding=missing
) -> Union{ReactantDevice, CPUDevice}
Return a `ReactantDevice` object if functional. Otherwise, throw an error if `force` is `true`.
Falls back to `CPUDevice` if `force` is `false`.
Return a `ReactantDevice` object if functional. Otherwise, throw an error if `force` is
`true`. Falls back to `CPUDevice` if `force` is `false`.
`client` and `device` are used to specify the client and particular device to use. If not
specified, then the default client and index are used.
!!! danger
This is an experimental feature and might change without deprecations
`sharding` is used to specify the sharding strategy. If a
`Reactant.Sharding.AbstractSharding` is specified, then we use it to shard all abstract
arrays. Alternatively, pass in a `IdDict` to specify the sharding for specific leaves.
"""
function reactant_device(; force::Bool=false, client=missing, device=missing)
function reactant_device(;
force::Bool=false, client=missing, device=missing, sharding=missing
)
msg = "`ReactantDevice` is not loaded or not functional. Load `Reactant.jl` before \
calling this function. Defaulting to CPU."
if loaded(ReactantDevice)
functional(ReactantDevice) && return ReactantDevice(client, device)
functional(ReactantDevice) && return ReactantDevice(client, device, sharding)
msg = "`ReactantDevice` is loaded but not functional. Defaulting to CPU."
end
force && throw(Internal.DeviceSelectionException("XLA"))
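Note: a usage sketch of the new keyword (assumes Reactant is loaded with at least four visible XLA devices; the mesh shape, axis names, and partition specs below are illustrative):

using Reactant, MLDataDevices

# One strategy for every array: shard across a hypothetical 2x2 device mesh.
mesh = Reactant.Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y))
dev = reactant_device(; sharding=Reactant.Sharding.NamedSharding(mesh, (:x, :y)))
x = dev(rand(Float32, 8, 8))

# Per-leaf shardings via an IdDict keyed on array identity:
w, b = rand(Float32, 8, 8), rand(Float32, 8)
dev2 = reactant_device(; sharding=IdDict(
    w => Reactant.Sharding.NamedSharding(mesh, (:x, :y)),
    b => Reactant.Sharding.NamedSharding(mesh, (:x,)),
))
w_sharded, b_sharded = dev2(w), dev2(b)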
2 changes: 1 addition & 1 deletion lib/WeightInitializers/test/Project.toml
@@ -19,7 +19,7 @@ Aqua = "0.8.7"
Documenter = "1.5.0"
ExplicitImports = "1.9.0"
GPUArrays = "10.2"
GPUArraysCore = "0.1.6"
GPUArraysCore = "0.1.6, 0.2"
Hwloc = "3.3"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "1.10"

2 comments on commit e3a72b5

@avik-pal (Member, Author):

@JuliaRegistrator register subdir=lib/MLDataDevices

@JuliaRegistrator:

Registration pull request created: JuliaRegistries/General/126463

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text "Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the release that TagBot creates. For example:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a MLDataDevices-v1.8.0 -m "<description of version>" e3a72b5710c9844e45667fbba87ef82616464b47
git push origin MLDataDevices-v1.8.0
