Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: support passing in device and client to XLA
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 1, 2024
1 parent 17419d2 commit 86add29
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 12 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.4.2"
version = "1.5.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -57,7 +57,7 @@ MLUtils = "0.4.4"
Metal = "1"
Preferences = "1.4"
Random = "1.10"
Reactant = "0.2"
Reactant = "0.2.3"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Expand Down
31 changes: 26 additions & 5 deletions ext/MLDataDevicesReactantExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module MLDataDevicesReactantExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice
using Reactant: Reactant, RArray
using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice, get_device_type
using Reactant: Reactant, XLA, RArray, ConcreteRArray, TracedRArray, ConcreteRNumber,
TracedRNumber

# Within this extension `XLADevice` always reports itself as loaded and functional,
# whether queried with an instance or with the type.
function MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}})
    return true
end

function MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}})
    return true
end
Expand All @@ -13,14 +14,34 @@ function MLDataDevices.default_device_rng(::XLADevice)
end

# Query Device from Array
Internal.get_device(::RArray) = XLADevice()
# Build an `XLADevice` from the client and device reported by the value's underlying
# XLA buffer (`x.data`).
function Internal.get_device(x::Union{ConcreteRArray, ConcreteRNumber})
    return XLADevice(XLA.client(x.data), XLA.device(x.data))
end

# Traced values always raise: querying their device is rejected with an error.
function Internal.get_device(::Union{TracedRArray, TracedRNumber})
    throw(ErrorException(
        "`get_device` isn't meant to be called inside `Reactant.@compile` context."))
end

Internal.get_device_type(::RArray) = XLADevice
# Every Reactant array or number (traced or concrete) maps to the `XLADevice` type.
function Internal.get_device_type(::Union{RArray, TracedRNumber, ConcreteRNumber})
    return XLADevice
end

# unsafe_free!
# Freeing is a no-op for `XLADevice`: the array is left untouched and `nothing` is
# returned.
function Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray)
    return nothing
end

# Device Transfer
Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x)
# Generic fallback for arrays that are not a plain `Array`: warn once, copy the data
# to the CPU, then dispatch again on the CPU copy.
function Adapt.adapt_storage(dev::XLADevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
    @warn "XLADevice got an array on device: $(get_device_type(x)). We will have to \
           transfer this via CPU." maxlog=1
    host_copy = Adapt.adapt_storage(CPUDevice(), x)
    return Adapt.adapt_storage(dev, host_copy)
end

# Upload a host `Array` to XLA and wrap the resulting buffer in a `ConcreteRArray`.
# Whichever of `dev.client` / `dev.device` is `missing` is replaced by Reactant's
# default backend, or by the device at the default index on the chosen client.
function Adapt.adapt_storage(dev::XLADevice, x::Array{<:Reactant.ReactantPrimitive})
    client = if dev.client === missing
        XLA.default_backend[]
    else
        dev.client
    end
    device = if dev.device === missing
        XLA.ClientGetDevice(client, XLA.default_device_idx[])
    else
        dev.device
    end
    host_buffer = XLA.ArrayFromHostBuffer(client, x, device)
    async_buffer = XLA.AsyncBuffer(host_buffer, nothing)
    return ConcreteRArray{eltype(x), ndims(x)}(async_buffer, size(x))
end

end
17 changes: 12 additions & 5 deletions src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ end
struct MetalDevice <: AbstractGPUDevice end
struct oneAPIDevice <: AbstractGPUDevice end

# TODO: Later we might want to add the client field here?
struct XLADevice <: AbstractAcceleratorDevice end
# Accelerator device backed by XLA. `@kwdef` provides a keyword constructor, so
# `XLADevice()` leaves both fields at `missing`.
# - `client`: the XLA client to use; `missing` means no client was specified.
# - `device`: the XLA device to use; `missing` means no device was specified.
@kwdef struct XLADevice{C, D} <: AbstractAcceleratorDevice
client::C = missing
device::D = missing
end

# Fallback for when we don't know the device type
struct UnknownDevice <: AbstractDevice end
Expand Down Expand Up @@ -189,20 +191,25 @@ Return a `CPUDevice` object which can be used to transfer data to CPU.
cpu_device() = CPUDevice()

"""
xla_device(; force::Bool=false) -> Union{XLADevice, CPUDevice}
xla_device(;
force::Bool=false, client=missing, device=missing
) -> Union{XLADevice, CPUDevice}
Return a `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`.
Falls back to `CPUDevice` if `force` is `false`.
`client` and `device` are used to specify the XLA client and the XLA device. If not
specified, then the default client and the device at the default index are used.
!!! danger
This is an experimental feature and might change without deprecations
"""
function xla_device(; force::Bool=false)
function xla_device(; force::Bool=false, client=missing, device=missing)
msg = "`XLADevice` is not loaded or not functional. Load `Reactant.jl` before calling \
this function. Defaulting to CPU."
if loaded(XLADevice)
functional(XLADevice) && return XLADevice()
functional(XLADevice) && return XLADevice(client, device)
msg = "`XLADevice` is loaded but not functional. Defaulting to CPU."
end
force && throw(Internal.DeviceSelectionException("XLA"))
Expand Down

0 comments on commit 86add29

Please sign in to comment.