feat: more extensive testing of XLA backend
avik-pal committed Oct 3, 2024
1 parent e78b99f · commit 575e0ac
Showing 13 changed files with 216 additions and 56 deletions.
7 changes: 5 additions & 2 deletions .buildkite/testing.yml
@@ -1,7 +1,7 @@
steps:
- group: ":julia: CUDA GPU"
steps:
- label: ":julia: Julia {{matrix.julia}} + CUDA GPU"
- label: ":julia: Julia {{matrix.julia}} + CUDA GPU (Backend Group: {{matrix.group}})"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
@@ -16,13 +16,16 @@ steps:
queue: "juliagpu"
cuda: "*"
env:
BACKEND_GROUP: "CUDA"
BACKEND_GROUP: "{{matrix.group}}"
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/
timeout_in_minutes: 60
matrix:
setup:
julia:
- "1"
+ group:
+ - CUDA
+ - XLA

- group: ":telescope: Downstream CUDA"
steps:
13 changes: 9 additions & 4 deletions .github/workflows/CI.yml
@@ -21,7 +21,7 @@ concurrency:

jobs:
ci:
- name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }}
+ name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.group }} - ${{ github.event_name }}
if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }}
runs-on: ${{ matrix.os }}
strategy:
@@ -33,6 +33,12 @@ jobs:
- ubuntu-latest
- macos-latest
- windows-latest
+ group:
+ - CPU
+ - XLA
+ exclude:
+ - os: windows-latest
+ group: XLA
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
@@ -50,6 +56,8 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
+ env:
+ GROUP: ${{ matrix.group }}
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
@@ -171,6 +179,3 @@ jobs:
- name: Check if the PR does increase number of invalidations
if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total
run: exit 1

- env:
- BACKEND_GROUP: "CPU"
4 changes: 2 additions & 2 deletions ext/MLDataDevicesMLUtilsExt.jl
@@ -1,10 +1,10 @@
module MLDataDevicesMLUtilsExt

using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
- MetalDevice, oneAPIDevice, DeviceIterator
+ MetalDevice, oneAPIDevice, XLADevice, DeviceIterator
using MLUtils: MLUtils, DataLoader

- for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
+ for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice)
@eval function (D::$(dev))(dataloader::DataLoader)
if dataloader.parallel
if dataloader.buffer
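
With `XLADevice` added to the `@eval` loop, a `DataLoader` can be wrapped by the XLA device just like the other backends. A hedged usage sketch — it assumes MLUtils.jl and Reactant.jl are installed and that an XLA runtime is functional (otherwise `xla_device()` falls back to the CPU, per the `src/public.jl` change below):

```julia
using MLDataDevices, MLUtils
using Reactant  # loads the Reactant and MLUtils extensions

dev = xla_device()  # returns cpu_device() with a warning if XLA is unavailable

# Applying the device to a DataLoader (here via |>) yields an iterator that
# transfers each batch to the device as the batch is produced.
dl = DataLoader(rand(Float32, 3, 33); batchsize=13, shuffle=false) |> dev
for x in dl
    println(size(x), " on ", get_device(x))
end
```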
4 changes: 2 additions & 2 deletions ext/MLDataDevicesReactantExt.jl
@@ -2,7 +2,7 @@ module MLDataDevicesReactantExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice
- using Reactant: Reactant, RArray, ConcreteRArray
+ using Reactant: Reactant, RArray

MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true
MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true
@@ -21,6 +21,6 @@ Internal.get_device_type(::RArray) = XLADevice
Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing

# Device Transfer
- Adapt.adapt_storage(::XLADevice, x::AbstractArray) = ConcreteRArray(x)
+ Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x)

end
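
`Reactant.to_rarray` is Reactant's general conversion entry point, so the extension no longer needs to import or call the `ConcreteRArray` constructor directly. A minimal sketch of the resulting transfer path, assuming a functional XLA setup:

```julia
using MLDataDevices, Reactant

dev = XLADevice()
x = rand(Float32, 4, 4)

x_dev = dev(x)          # Adapt.adapt_storage(::XLADevice, ...) calls Reactant.to_rarray(x)
get_device_type(x_dev)  # XLADevice
Array(x_dev) ≈ x        # copy back to the host; true
```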
14 changes: 8 additions & 6 deletions src/internal.jl
@@ -35,13 +35,15 @@ for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing},
@eval get_device_id(::$(T)) = nothing
end

- struct DeviceSelectionException <: Exception end
+ struct DeviceSelectionException <: Exception
+ dev::String
+ end

- function Base.showerror(io::IO, ::DeviceSelectionException)
- return print(io, "DeviceSelectionException(No functional GPU device found!!)")
+ function Base.showerror(io::IO, d::DeviceSelectionException)
+ return print(io, "DeviceSelectionException: No functional $(d.dev) device found!")
end

- function get_gpu_device(; force_gpu_usage::Bool)
+ function get_gpu_device(; force::Bool)
backend = load_preference(MLDataDevices, "gpu_backend", nothing)

# If backend set with preferences, use it
Expand Down Expand Up @@ -88,7 +90,7 @@ function get_gpu_device(; force_gpu_usage::Bool)
end
end

- force_gpu_usage && throw(DeviceSelectionException())
+ force && throw(DeviceSelectionException("GPU"))
@warn """No functional GPU backend found! Defaulting to CPU.
1. If no GPU is available, nothing needs to be done.
@@ -147,7 +149,7 @@ for op in (:get_device, :get_device_type)
end
end

- for T in (Number, AbstractRNG, Val, Symbol, String, Nothing)
+ for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
@eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing)
end
end
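
Carrying the backend name on the exception lets the same type serve both the GPU and XLA selection paths, and the message now names the device family that failed. For example (stacktrace omitted):

```julia
julia> throw(MLDataDevices.Internal.DeviceSelectionException("XLA"))
ERROR: DeviceSelectionException: No functional XLA device found!
```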
41 changes: 27 additions & 14 deletions src/public.jl
@@ -66,7 +66,7 @@ supported_gpu_backends() = map(Internal.get_device_name, GPU_DEVICES)

"""
gpu_device(device_id::Union{Nothing, Integer}=nothing;
- force_gpu_usage::Bool=false) -> AbstractDevice()
+ force::Bool=false) -> AbstractDevice
Selects GPU device based on the following criteria:
@@ -75,7 +75,7 @@ Selects GPU device based on the following criteria:
2. Otherwise, an automatic selection algorithm is used. We go over possible device
backends in the order specified by `supported_gpu_backends()` and select the first
functional backend.
- 3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is
+ 3. If no GPU device is functional and `force` is `false`, then `cpu_device()` is
invoked.
4. If nothing works, an error is thrown.
@@ -102,17 +102,24 @@ Selects GPU device based on the following criteria:
## Keyword Arguments
- - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU
+ - `force::Bool`: If `true`, then an error is thrown if no functional GPU
device is found.
"""
- function gpu_device(device_id::Union{Nothing, <:Integer}=nothing;
- force_gpu_usage::Bool=false)::AbstractDevice
+ function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; force::Bool=false,
+ force_gpu_usage::Union{Missing, Bool}=missing)::AbstractDevice
+ if force_gpu_usage !== missing
+ Base.depwarn(
+ "`force_gpu_usage` is deprecated and will be removed in v2. Use \
+ `force` instead.", :gpu_device)
+ force = force_gpu_usage
+ end

device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed."))

if GPU_DEVICE[] !== nothing
dev = GPU_DEVICE[]
if device_id === nothing
- force_gpu_usage &&
+ force &&
!(dev isa AbstractGPUDevice) &&
throw(Internal.DeviceSelectionException())
return dev
@@ -122,7 +129,7 @@ function gpu_device(device_id::Union{Nothing, <:Integer}=nothing;
end
end

- device_type = Internal.get_gpu_device(; force_gpu_usage)
+ device_type = Internal.get_gpu_device(; force)
device = Internal.with_device(device_type, device_id)
GPU_DEVICE[] = device

@@ -179,19 +186,25 @@ Return a `CPUDevice` object which can be used to transfer data to CPU.
cpu_device() = CPUDevice()

"""
- xla_device() -> XLADevice()
+ xla_device(; force::Bool=false) -> Union{XLADevice, CPUDevice}
- Return a `XLADevice` object.
+ Return a `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`.
+ Falls back to `CPUDevice` if `force` is `false`.
!!! danger
This is an experimental feature and might change without deprecations
"""
- function xla_device()
- @assert loaded(XLADevice)&&functional(XLADevice) "`XLADevice` is not loaded or not \
- functional. Load `Reactant.jl` \
- before calling this function."
- return XLADevice()
+ function xla_device(; force::Bool=false)
+ msg = "`XLADevice` is not loaded or not functional. Load `Reactant.jl` before calling \
+ this function. Defaulting to CPU."
+ if loaded(XLADevice)
+ functional(XLADevice) && return XLADevice()
+ msg = "`XLADevice` is loaded but not functional. Defaulting to CPU."
+ end
+ force && throw(Internal.DeviceSelectionException("XLA"))
+ @warn msg maxlog=1
+ return cpu_device()
end

"""
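
Taken together, the public API behaves as sketched below in a session with no functional GPU or XLA backend — expected behavior inferred from the diff, not output captured from this commit:

```julia
using MLDataDevices

gpu_device()               # warns "No functional GPU backend found! ..." and
                           # returns CPUDevice()
gpu_device(; force=true)   # throws DeviceSelectionException: No functional
                           # GPU device found!

# Deprecated spelling: emits a depwarn ("will be removed in v2") and
# forwards the value to `force`.
gpu_device(; force_gpu_usage=true)

xla_device()               # warns once (maxlog=1) and returns cpu_device()
xla_device(; force=true)   # throws DeviceSelectionException: No functional
                           # XLA device found!
```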
7 changes: 3 additions & 4 deletions test/amdgpu_tests.jl
@@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type
@test !MLDataDevices.functional(AMDGPUDevice)
@test cpu_device() isa CPUDevice
@test gpu_device() isa CPUDevice
- @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(;
- force_gpu_usage=true)
+ @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true)
@test_throws Exception default_device_rng(AMDGPUDevice(nothing))
@test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") MLDataDevices.set_device!(
AMDGPUDevice, nothing, 1)
@@ -20,12 +19,12 @@ using AMDGPU
if MLDataDevices.functional(AMDGPUDevice)
@info "AMDGPU is functional"
@test gpu_device() isa AMDGPUDevice
- @test gpu_device(; force_gpu_usage=true) isa AMDGPUDevice
+ @test gpu_device(; force=true) isa AMDGPUDevice
else
@info "AMDGPU is NOT functional"
@test gpu_device() isa CPUDevice
@test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(;
- force_gpu_usage=true)
+ force=true)
end
@test MLDataDevices.GPU_DEVICE[] !== nothing
end
7 changes: 3 additions & 4 deletions test/cuda_tests.jl
@@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type
@test !MLDataDevices.functional(CUDADevice)
@test cpu_device() isa CPUDevice
@test gpu_device() isa CPUDevice
- @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(;
- force_gpu_usage=true)
+ @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true)
@test_throws Exception default_device_rng(CUDADevice(nothing))
@test_logs (:warn, "`CUDA.jl` hasn't been loaded. Ignoring the device setting.") MLDataDevices.set_device!(
CUDADevice, nothing, 1)
@@ -20,12 +19,12 @@ using LuxCUDA
if MLDataDevices.functional(CUDADevice)
@info "LuxCUDA is functional"
@test gpu_device() isa CUDADevice
- @test gpu_device(; force_gpu_usage=true) isa CUDADevice
+ @test gpu_device(; force=true) isa CUDADevice
else
@info "LuxCUDA is NOT functional"
@test gpu_device() isa CPUDevice
@test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(;
- force_gpu_usage=true)
+ force=true)
end
@test MLDataDevices.GPU_DEVICE[] !== nothing
end
34 changes: 24 additions & 10 deletions test/iterator_tests.jl
@@ -18,10 +18,18 @@ if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all"
using oneAPI
end

- DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice]
+ if BACKEND_GROUP == "xla" || BACKEND_GROUP == "all"
+ using Reactant
+ if "gpu" in keys(Reactant.XLA.backends)
+ Reactant.set_default_backend("gpu")
+ end
+ end

+ DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice]

freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x)
freed_if_can_be_freed(::Type{CPUDevice}, x) = true
+ freed_if_can_be_freed(::Type{XLADevice}, x) = true
function freed_if_can_be_freed(::Type, x)
try
Array(x)
@@ -53,17 +61,20 @@ end

@testset "DataLoader: parallel=$parallel" for parallel in (true, false)
X = rand(Float64, 3, 33)
- pre = DataLoader(dev(X); batchsize=13, shuffle=false)
- post = DataLoader(X; batchsize=13, shuffle=false) |> dev
+ pre = DataLoader(dev(X); batchsize=13, shuffle=false, parallel)
+ post = DataLoader(X; batchsize=13, shuffle=false, parallel) |> dev

for epoch in 1:2
prev_pre, prev_post = nothing, nothing
for (p, q) in zip(pre, post)
@test get_device_type(p) == dev_type
@test get_device_type(q) == dev_type
- @test p ≈ q
+ # Ordering is not guaranteed in parallel
+ !parallel && @test p ≈ q

- dev_type === CPUDevice && continue
+ if dev_type === CPUDevice || dev_type === XLADevice
+ continue
+ end

prev_pre === nothing || @test !freed_if_can_be_freed(prev_pre)
prev_pre = p
@@ -74,8 +85,8 @@ end
end

Y = rand(Float64, 1, 33)
- pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false)
- post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false) |> dev
+ pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false, parallel)
+ post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false, parallel) |> dev

for epoch in 1:2
prev_pre, prev_post = nothing, nothing
Expand All @@ -84,10 +95,13 @@ end
@test get_device_type(p.y) == dev_type
@test get_device_type(q.x) == dev_type
@test get_device_type(q.y) == dev_type
- @test p.x ≈ q.x
- @test p.y ≈ q.y
+ # Ordering is not guaranteed in parallel
+ !parallel && @test p.x ≈ q.x
+ !parallel && @test p.y ≈ q.y

- dev_type === CPUDevice && continue
+ if dev_type === CPUDevice || dev_type === XLADevice
+ continue
+ end

if prev_pre !== nothing
@test !freed_if_can_be_freed(prev_pre.x)
7 changes: 3 additions & 4 deletions test/metal_tests.jl
@@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type
@test !MLDataDevices.functional(MetalDevice)
@test cpu_device() isa CPUDevice
@test gpu_device() isa CPUDevice
- @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(;
- force_gpu_usage=true)
+ @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true)
@test_throws Exception default_device_rng(MetalDevice())
end

@@ -18,12 +17,12 @@ using Metal
if MLDataDevices.functional(MetalDevice)
@info "Metal is functional"
@test gpu_device() isa MetalDevice
- @test gpu_device(; force_gpu_usage=true) isa MetalDevice
+ @test gpu_device(; force=true) isa MetalDevice
else
@info "Metal is NOT functional"
@test gpu_device() isa MetalDevice
@test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(;
- force_gpu_usage=true)
+ force=true)
end
@test MLDataDevices.GPU_DEVICE[] !== nothing
end
7 changes: 3 additions & 4 deletions test/oneapi_tests.jl
@@ -5,8 +5,7 @@ using ArrayInterface: parameterless_type
@test !MLDataDevices.functional(oneAPIDevice)
@test cpu_device() isa CPUDevice
@test gpu_device() isa CPUDevice
- @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(;
- force_gpu_usage=true)
+ @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true)
@test_throws Exception default_device_rng(oneAPIDevice())
end

@@ -18,12 +17,12 @@ using oneAPI
if MLDataDevices.functional(oneAPIDevice)
@info "oneAPI is functional"
@test gpu_device() isa oneAPIDevice
- @test gpu_device(; force_gpu_usage=true) isa oneAPIDevice
+ @test gpu_device(; force=true) isa oneAPIDevice
else
@info "oneAPI is NOT functional"
@test gpu_device() isa oneAPIDevice
@test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(;
- force_gpu_usage=true)
+ force=true)
end
@test MLDataDevices.GPU_DEVICE[] !== nothing
end