diff --git a/Project.toml b/Project.toml
index 68d4325..7d5e4a8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal and contributors"]
-version = "1.4.2"
+version = "1.5.0"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -57,7 +57,7 @@ MLUtils = "0.4.4"
 Metal = "1"
 Preferences = "1.4"
 Random = "1.10"
-Reactant = "0.2"
+Reactant = "0.2.3"
 RecursiveArrayTools = "3.8"
 ReverseDiff = "1.15"
 SparseArrays = "1.10"
diff --git a/ext/MLDataDevicesReactantExt.jl b/ext/MLDataDevicesReactantExt.jl
index 3abc8fc..0ad0a0d 100644
--- a/ext/MLDataDevicesReactantExt.jl
+++ b/ext/MLDataDevicesReactantExt.jl
@@ -1,8 +1,9 @@
 module MLDataDevicesReactantExt
 
 using Adapt: Adapt
-using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice
-using Reactant: Reactant, RArray
+using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice, get_device_type
+using Reactant: Reactant, XLA, RArray, ConcreteRArray, TracedRArray, ConcreteRNumber,
+                TracedRNumber
 
 MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true
 MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true
@@ -13,14 +14,36 @@ function MLDataDevices.default_device_rng(::XLADevice)
 end
 
 # Query Device from Array
-Internal.get_device(::RArray) = XLADevice()
+function Internal.get_device(x::Union{ConcreteRArray, ConcreteRNumber})
+    client = XLA.client(x.data)
+    device = XLA.device(x.data)
+    return XLADevice(client, device)
+end
+
+function Internal.get_device(::Union{TracedRArray, TracedRNumber})
+    error("`get_device` isn't meant to be called inside `Reactant.@compile` context.")
+end
 
-Internal.get_device_type(::RArray) = XLADevice
+Internal.get_device_type(::Union{RArray, TracedRNumber, ConcreteRNumber}) = XLADevice
 
 # unsafe_free!
 Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing
 
 # Device Transfer
-Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x)
+# Non-`Array` inputs are staged through the CPU before being uploaded to XLA.
+function Adapt.adapt_storage(dev::XLADevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
+    @warn "XLADevice got an array on device: $(get_device_type(x)). We will have to \
+           transfer this via CPU." maxlog=1
+    return Adapt.adapt_storage(dev, Adapt.adapt_storage(CPUDevice(), x))
+end
+
+# Plain `Array`s are uploaded directly; `missing` fields select Reactant's XLA defaults.
+function Adapt.adapt_storage(dev::XLADevice, x::Array{<:Reactant.ReactantPrimitive})
+    client = dev.client === missing ? XLA.default_backend[] : dev.client
+    device = dev.device === missing ?
+             XLA.ClientGetDevice(client, XLA.default_device_idx[]) : dev.device
+    return ConcreteRArray{eltype(x), ndims(x)}(
+        XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, x, device), nothing), size(x))
+end
 
 end
diff --git a/src/public.jl b/src/public.jl
index 6440ddb..794f11c 100644
--- a/src/public.jl
+++ b/src/public.jl
@@ -9,8 +9,10 @@ end
 struct MetalDevice <: AbstractGPUDevice end
 struct oneAPIDevice <: AbstractGPUDevice end
 
-# TODO: Later we might want to add the client field here?
-struct XLADevice <: AbstractAcceleratorDevice end
+@kwdef struct XLADevice{C, D} <: AbstractAcceleratorDevice
+    client::C = missing
+    device::D = missing
+end
 
 # Fallback for when we don't know the device type
 struct UnknownDevice <: AbstractDevice end
@@ -189,20 +191,25 @@ Return a `CPUDevice` object which can be used to transfer data to CPU.
 cpu_device() = CPUDevice()
 
 """
-    xla_device(; force::Bool=false) -> Union{XLADevice, CPUDevice}
+    xla_device(;
+        force::Bool=false, client=missing, device=missing
+    ) -> Union{XLADevice, CPUDevice}
 
 Return a `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`.
 Falls back to `CPUDevice` if `force` is `false`.
 
+`client` and `device` are used to specify the XLA client and the particular device to use.
+If not specified, then the default client and device are used.
+
 !!! danger
 
     This is an experimental feature and might change without deprecations
 """
-function xla_device(; force::Bool=false)
+function xla_device(; force::Bool=false, client=missing, device=missing)
     msg = "`XLADevice` is not loaded or not functional. Load `Reactant.jl` before calling \
            this function. Defaulting to CPU."
     if loaded(XLADevice)
-        functional(XLADevice) && return XLADevice()
+        functional(XLADevice) && return XLADevice(client, device)
         msg = "`XLADevice` is loaded but not functional. Defaulting to CPU."
     end
     force && throw(Internal.DeviceSelectionException("XLA"))