refactor: rename XLADevice/xla_device to ReactantDevice/reactant_device
avik-pal committed Nov 3, 2024
1 parent 402cc4b commit 7891f84
Showing 17 changed files with 90 additions and 84 deletions.
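
For downstream users, the rename amounts to the following (a minimal sketch; per the changes below, the old names remain available behind a deprecation):

```julia
using Lux, Reactant, MLDataDevices

# Before this commit:
# dev = xla_device()        # returned an XLADevice

# After this commit:
dev = reactant_device()     # returns a ReactantDevice (or CPUDevice if Reactant is not functional)
```
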
4 changes: 2 additions & 2 deletions Project.toml
@@ -87,8 +87,8 @@ GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
LossFunctions = "0.11.1"
LuxCore = "1"
LuxLib = "1.3.4"
MLDataDevices = "1.3"
LuxLib = "1.3.7"
MLDataDevices = "1.5"
MLUtils = "0.4.4"
MPI = "0.20.19"
MacroTools = "0.5.13"
2 changes: 1 addition & 1 deletion docs/src/api/Accelerator_Support/MLDataDevices.md
@@ -18,7 +18,7 @@ MLDataDevices.gpu_backend!
```@docs
MLDataDevices.cpu_device
MLDataDevices.gpu_device
MLDataDevices.xla_device
MLDataDevices.reactant_device
```

## Miscellaneous
6 changes: 3 additions & 3 deletions docs/src/index.md
@@ -142,21 +142,21 @@ Run the following to access a device:
using Reactant, Lux
Reactant.set_default_backend("cpu") # default

const dev = xla_device()
const dev = reactant_device()
```

```julia [GPU Backend]
using Reactant, Lux
Reactant.set_default_backend("gpu")

const dev = xla_device()
const dev = reactant_device()
```

```julia [TPU Backend]
using Reactant, Lux
Reactant.set_default_backend("tpu")

const dev = xla_device()
const dev = reactant_device()
```

:::
8 changes: 4 additions & 4 deletions docs/src/manual/compiling_lux_models.md
@@ -45,11 +45,11 @@ x = randn(Float32, 2, 32)
y = x .^ 2
```

We will use [`xla_device`](@ref) similar to [`gpu_device`](@ref) to move the arrays to
We will use [`reactant_device`](@ref) similar to [`gpu_device`](@ref) to move the arrays to
`Reactant`.

```@example compile_lux_model
const xdev = xla_device()
const xdev = reactant_device()
x_ra = x |> xdev
y_ra = y |> xdev
@@ -66,7 +66,7 @@ pred_lux, _ = model(x, ps, Lux.testmode(st))

To run it using `XLA` we need to compile the model. We can do this using the
`Reactant.@compile` macro. Note that the inputs need to be moved to the device using
[`xla_device`](@ref) first.
[`reactant_device`](@ref) first.

```@example compile_lux_model
model_compiled = @compile model(x_ra, ps_ra, Lux.testmode(st_ra))
@@ -122,7 +122,7 @@ fmap(Broadcast.BroadcastFunction(-), ∂ps_zyg, ∂ps_enzyme)
Now that we have seen the low-level API, let's see how to train the model without any of this
boilerplate. Simply follow these steps (a short sketch of the first two appears after this excerpt):

1. Create a device using `xla_device`. Remember to load `Reactant.jl` before doing this.
1. Create a device using `reactant_device`. Remember to load `Reactant.jl` before doing this.
2. Similar to other device functions move the model, parameters, states and data to the
device. Note that you might want to use [`DeviceIterator`](@ref) to move the data
loader to the device with an iterator.
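
For orientation, a minimal sketch of the first two steps above (the model, shapes, and batch size are placeholders; the remaining steps are not shown in this excerpt):

```julia
using Lux, Reactant, MLUtils, Random

const rdev = reactant_device()   # step 1: falls back to CPUDevice if Reactant is unavailable

model = Chain(Dense(2 => 16, tanh), Dense(16 => 1))
ps, st = Lux.setup(Random.default_rng(), model) |> rdev   # step 2: move parameters and states

x, y = rand(Float32, 2, 128), rand(Float32, 1, 128)
dataloader = DataLoader((x, y); batchsize=32) |> rdev     # yields a DeviceIterator over device-resident batches
```
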
4 changes: 2 additions & 2 deletions lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.3.6"
version = "1.3.7"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -75,7 +75,7 @@ LinearAlgebra = "1.10"
LoopVectorization = "0.12.171"
LuxCore = "1"
MKL = "0.7"
MLDataDevices = "1.2"
MLDataDevices = "1.5"
Markdown = "1.10"
NNlib = "0.9.24"
Octavian = "0.3.28"
2 changes: 1 addition & 1 deletion lib/LuxLib/src/impl/Impl.jl
@@ -19,7 +19,7 @@ using Random: Random, AbstractRNG, rand!
using Statistics: Statistics, mean, var

using LuxCore: LuxCore
using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, XLADevice,
using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, ReactantDevice,
AbstractGPUDevice, AbstractDevice
using NNlib: NNlib, ConvDims

2 changes: 1 addition & 1 deletion lib/LuxLib/src/impl/conv.jl
@@ -74,7 +74,7 @@ end

conv(x, weight, cdims::ConvDims) = conv(get_device_type((x, weight)), x, weight, cdims)

function conv(::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice, XLADevice}},
function conv(::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice, ReactantDevice}},
x′, weight′, cdims::ConvDims)
x, weight = get_conv_input_weight(x′, weight′)
return NNlib.conv(x, weight, cdims)
2 changes: 1 addition & 1 deletion lib/MLDataDevices/README.md
@@ -10,7 +10,7 @@ Currently we provide support for the following backends:
3. `AMDGPUDevice`: `AMDGPU.jl` for AMD ROCM GPUs.
4. `MetalDevice`: `Metal.jl` for Apple Metal GPUs. **(Experimental)**
5. `oneAPIDevice`: `oneAPI.jl` for Intel GPUs. **(Experimental)**
6. `XLADevice`: `Reactant.jl` for XLA Support. **(Experimental)**
6. `ReactantDevice`: `Reactant.jl` for XLA Support. **(Experimental)**

## Updating to v1.0

4 changes: 2 additions & 2 deletions lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl
@@ -1,10 +1,10 @@
module MLDataDevicesMLUtilsExt

using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, XLADevice, DeviceIterator
MetalDevice, oneAPIDevice, ReactantDevice, DeviceIterator
using MLUtils: MLUtils, DataLoader

for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice)
for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice)
@eval function (D::$(dev))(dataloader::DataLoader)
if dataloader.parallel
if dataloader.buffer
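
As a hedged illustration of what this extension enables (shapes and batch size are placeholders):

```julia
using MLDataDevices, MLUtils, Reactant

dev = reactant_device()
dl = DataLoader((rand(Float32, 2, 64), rand(Float32, 1, 64)); batchsize=16)

for (x, y) in dev(dl)   # the wrapped loader moves each batch to the Reactant device as it is consumed
    # consume the device-resident batch here
end
```
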
20 changes: 10 additions & 10 deletions lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
@@ -1,7 +1,7 @@
module MLDataDevicesReactantExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice, get_device_type
using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice, get_device_type
using Reactant: Reactant, XLA, RArray, ConcreteRArray, TracedRArray, TracedRNumber

@static if isdefined(Reactant, :ConcreteRNumber)
@@ -14,38 +14,38 @@ else
const ReactantType = Union{RArray, TracedRArray, TracedRNumber}
end

MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true
MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true
MLDataDevices.loaded(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true
MLDataDevices.functional(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true

# Default RNG: Forward to CPU, we will compile it
function MLDataDevices.default_device_rng(::XLADevice)
function MLDataDevices.default_device_rng(::ReactantDevice)
return MLDataDevices.default_device_rng(CPUDevice())
end

# Query Device from Array
function Internal.get_device(x::ConcreteRType)
client = XLA.client(x.data)
device = XLA.device(x.data)
return XLADevice(client, device)
return ReactantDevice(client, device)
end

function Internal.get_device(::Union{TracedRArray, TracedRNumber})
error("`get_device` isn't meant to be called inside `Reactant.@compile` context.")
end

Internal.get_device_type(::ReactantType) = XLADevice
Internal.get_device_type(::ReactantType) = ReactantDevice

# unsafe_free!
Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing
Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothing

# Device Transfer
function Adapt.adapt_storage(dev::XLADevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
@warn "XLADevice got an array on device: $(get_device_type(x)). We will have to \
function Adapt.adapt_storage(dev::ReactantDevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
@warn "ReactantDevice got an array on device: $(get_device_type(x)). We will have to \
transfer this via CPU." maxlog=1
return Adapt.adapt_storage(dev, Adapt.adapt_storage(CPUDevice(), x))
end

function Adapt.adapt_storage(dev::XLADevice, x::Array{<:Reactant.ReactantPrimitive})
function Adapt.adapt_storage(dev::ReactantDevice, x::Array{<:Reactant.ReactantPrimitive})
client = dev.client === missing ? XLA.default_backend[] : dev.client
device = dev.device === missing ?
XLA.ClientGetDevice(client, XLA.default_device_idx[]) : dev.device
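
A small sketch of the transfer and device-query behavior implemented here (assuming Reactant.jl is loaded and functional):

```julia
using Reactant, MLDataDevices

dev = reactant_device()
x_ra = dev(rand(Float32, 4, 4))   # placed on the default client/device as a ConcreteRArray

get_device_type(x_ra)             # ReactantDevice
get_device(x_ra)                  # a ReactantDevice carrying the owning client and device
```
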
3 changes: 2 additions & 1 deletion lib/MLDataDevices/src/MLDataDevices.jl
@@ -17,7 +17,8 @@ include("internal.jl")

export gpu_backend!, supported_gpu_backends, reset_gpu_device!
export default_device_rng
export gpu_device, cpu_device, xla_device
export gpu_device, cpu_device
export xla_device, reactant_device

export CPUDevice
export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
8 changes: 4 additions & 4 deletions lib/MLDataDevices/src/internal.jl
@@ -5,7 +5,7 @@ using Preferences: load_preference
using Random: AbstractRNG

using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, XLADevice, UnknownDevice,
MetalDevice, oneAPIDevice, ReactantDevice, UnknownDevice,
supported_gpu_backends, GPU_DEVICES, loaded, functional

for dev in (CPUDevice, MetalDevice, oneAPIDevice)
@@ -27,11 +27,11 @@ for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg)
end
end
get_device_name(::XLADevice) = "XLA"
get_triggerpkg_name(::XLADevice) = "Reactant"
get_device_name(::ReactantDevice) = "XLA"
get_triggerpkg_name(::ReactantDevice) = "Reactant"

for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing},
MetalDevice, oneAPIDevice, XLADevice)
MetalDevice, oneAPIDevice, ReactantDevice)
@eval get_device_id(::$(T)) = nothing
end

33 changes: 19 additions & 14 deletions lib/MLDataDevices/src/public.jl
@@ -9,12 +9,12 @@ end
struct MetalDevice <: AbstractGPUDevice end
struct oneAPIDevice <: AbstractGPUDevice end

@kwdef struct XLADevice{C, D} <: AbstractAcceleratorDevice
@kwdef struct ReactantDevice{C, D} <: AbstractAcceleratorDevice
client::C = missing
device::D = missing
end

function Base.:(==)(x::XLADevice, y::XLADevice)
function Base.:(==)(x::ReactantDevice, y::ReactantDevice)
if x.client !== missing
y.client === missing && return false
x.client.client != y.client.client && return false
@@ -30,6 +30,9 @@ function Base.:(==)(x::XLADevice, y::XLADevice)
return true
end

# XXX: Deprecate in v2
const XLADevice = ReactantDevice

# Fallback for when we don't know the device type
struct UnknownDevice <: AbstractDevice end

@@ -207,11 +210,11 @@ Return a `CPUDevice` object which can be used to transfer data to CPU.
cpu_device() = CPUDevice()

"""
xla_device(;
reactant_device(;
force::Bool=false, client=missing, device=missing
) -> Union{XLADevice, CPUDevice}
) -> Union{ReactantDevice, CPUDevice}
Return a `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`.
Return a `ReactantDevice` object if functional. Otherwise, throw an error if `force` is `true`.
Falls back to `CPUDevice` if `force` is `false`.
`client` and `device` are used to specify the client and particular device to use. If not
@@ -221,18 +224,20 @@ specified, then the default client and index are used.
This is an experimental feature and might change without deprecations
"""
function xla_device(; force::Bool=false, client=missing, device=missing)
msg = "`XLADevice` is not loaded or not functional. Load `Reactant.jl` before calling \
this function. Defaulting to CPU."
if loaded(XLADevice)
functional(XLADevice) && return XLADevice(client, device)
msg = "`XLADevice` is loaded but not functional. Defaulting to CPU."
function reactant_device(; force::Bool=false, client=missing, device=missing)
msg = "`ReactantDevice` is not loaded or not functional. Load `Reactant.jl` before \
calling this function. Defaulting to CPU."
if loaded(ReactantDevice)
functional(ReactantDevice) && return ReactantDevice(client, device)
msg = "`ReactantDevice` is loaded but not functional. Defaulting to CPU."
end
force && throw(Internal.DeviceSelectionException("XLA"))
@warn msg maxlog=1
return cpu_device()
end

Base.@deprecate xla_device(; kwargs...) reactant_device(; kwargs...)

"""
default_device_rng(::AbstractDevice)
@@ -335,8 +340,8 @@ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice}
@warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting."
T === CPUDevice &&
@warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting."
T === XLADevice &&
@warn "Setting device for `XLADevice` hasn't been implemented yet. Ignoring the device setting."
T === ReactantDevice &&
@warn "Setting device for `ReactantDevice` hasn't been implemented yet. Ignoring the device setting."
return
end

@@ -389,7 +394,7 @@ end
Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x)
Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng

for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice)
for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice)
@eval begin
function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG)
return default_device_rng(to)
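
A hedged usage sketch of the renamed entry point, its CPU fallback, and the deprecation shim added above:

```julia
using MLDataDevices

dev = reactant_device()         # without Reactant loaded: warns once and returns CPUDevice()
# reactant_device(force=true)   # would throw instead of falling back

using Reactant

dev = reactant_device()         # now returns a ReactantDevice on the default client/device
old_dev = xla_device()          # deprecated alias; forwards to reactant_device()
```
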
12 changes: 6 additions & 6 deletions lib/MLDataDevices/test/iterator_tests.jl
@@ -25,11 +25,11 @@ if BACKEND_GROUP == "xla" || BACKEND_GROUP == "all"
end
end

DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice]
DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice]

freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x)
freed_if_can_be_freed(::Type{CPUDevice}, x) = true
freed_if_can_be_freed(::Type{XLADevice}, x) = true
freed_if_can_be_freed(::Type{ReactantDevice}, x) = true
function freed_if_can_be_freed(::Type, x)
try
Array(x)
@@ -63,7 +63,7 @@ end
@info "Testing DataLoader with parallel=$parallel"
X = rand(Float64, 3, 33)
post = DataLoader(X; batchsize=13, shuffle=false, parallel) |> dev
if dev_type === XLADevice
if dev_type === ReactantDevice
pre = post # XXX: deadlocks and other shenanigans
else
pre = DataLoader(dev(X); batchsize=13, shuffle=false, parallel)
@@ -77,7 +77,7 @@
# Ordering is not guaranteed in parallel
!parallel && @test p ≈ q

if dev_type === CPUDevice || dev_type === XLADevice
if dev_type === CPUDevice || dev_type === ReactantDevice
continue
end

@@ -91,7 +91,7 @@

Y = rand(Float64, 1, 33)
post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false, parallel) |> dev
if dev_type === XLADevice
if dev_type === ReactantDevice
pre = post # XXX: deadlocks and other shenanigans
else
pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false, parallel)
Expand All @@ -108,7 +108,7 @@ end
!parallel && @test p.x ≈ q.x
!parallel && @test p.y ≈ q.y

if dev_type === CPUDevice || dev_type === XLADevice
if dev_type === CPUDevice || dev_type === ReactantDevice
continue
end

