From 7891f843ff2eb55ded01e5fdb8eb1c8a292f1532 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 18:42:57 -0500 Subject: [PATCH] refactor: rename XLADevice/xla_device to ReactantDevice/reactant_device --- Project.toml | 4 +- .../api/Accelerator_Support/MLDataDevices.md | 2 +- docs/src/index.md | 6 +- docs/src/manual/compiling_lux_models.md | 8 +-- lib/LuxLib/Project.toml | 4 +- lib/LuxLib/src/impl/Impl.jl | 2 +- lib/LuxLib/src/impl/conv.jl | 2 +- lib/MLDataDevices/README.md | 2 +- .../ext/MLDataDevicesMLUtilsExt.jl | 4 +- .../ext/MLDataDevicesReactantExt.jl | 20 +++---- lib/MLDataDevices/src/MLDataDevices.jl | 3 +- lib/MLDataDevices/src/internal.jl | 8 +-- lib/MLDataDevices/src/public.jl | 33 ++++++----- lib/MLDataDevices/test/iterator_tests.jl | 12 ++-- lib/MLDataDevices/test/xla_tests.jl | 56 +++++++++---------- src/helpers/training.jl | 6 +- test/reactant/training_tests.jl | 2 +- 17 files changed, 90 insertions(+), 84 deletions(-) diff --git a/Project.toml b/Project.toml index 3eaa8de65e..f584e52c1b 100644 --- a/Project.toml +++ b/Project.toml @@ -87,8 +87,8 @@ GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" LossFunctions = "0.11.1" LuxCore = "1" -LuxLib = "1.3.4" -MLDataDevices = "1.3" +LuxLib = "1.3.7" +MLDataDevices = "1.5" MLUtils = "0.4.4" MPI = "0.20.19" MacroTools = "0.5.13" diff --git a/docs/src/api/Accelerator_Support/MLDataDevices.md b/docs/src/api/Accelerator_Support/MLDataDevices.md index 5abc0dd6a6..6f351e5d1e 100644 --- a/docs/src/api/Accelerator_Support/MLDataDevices.md +++ b/docs/src/api/Accelerator_Support/MLDataDevices.md @@ -18,7 +18,7 @@ MLDataDevices.gpu_backend! ```@docs MLDataDevices.cpu_device MLDataDevices.gpu_device -MLDataDevices.xla_device +MLDataDevices.reactant_device ``` ## Miscellaneous diff --git a/docs/src/index.md b/docs/src/index.md index bd40ff104c..74b19619d7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -142,21 +142,21 @@ Run the following to access a device: using Reactant, Lux Reactant.set_default_backend("cpu") # default -const dev = xla_device() +const dev = reactant_device() ``` ```julia [GPU Backend] using Reactant, Lux Reactant.set_default_backend("gpu") -const dev = xla_device() +const dev = reactant_device() ``` ```julia [TPU Backend] using Reactant, Lux Reactant.set_default_backend("tpu") -const dev = xla_device() +const dev = reactant_device() ``` ::: diff --git a/docs/src/manual/compiling_lux_models.md b/docs/src/manual/compiling_lux_models.md index 65cdb79417..44e7fc27c3 100644 --- a/docs/src/manual/compiling_lux_models.md +++ b/docs/src/manual/compiling_lux_models.md @@ -45,11 +45,11 @@ x = randn(Float32, 2, 32) y = x .^ 2 ``` -We will use [`xla_device`](@ref) similar to [`gpu_device`](@ref) to move the arrays to +We will use [`reactant_device`](@ref) similar to [`gpu_device`](@ref) to move the arrays to `Reactant`. ```@example compile_lux_model -const xdev = xla_device() +const xdev = reactant_device() x_ra = x |> xdev y_ra = y |> xdev @@ -66,7 +66,7 @@ pred_lux, _ = model(x, ps, Lux.testmode(st)) To run it using `XLA` we need to compile the model. We can do this using the `Reactant.@compile` macro. Note that the inputs need to be moved to the device using -[`xla_device`](@ref) first. +[`reactant_device`](@ref) first. ```@example compile_lux_model model_compiled = @compile model(x_ra, ps_ra, Lux.testmode(st_ra)) @@ -122,7 +122,7 @@ fmap(Broadcast.BroadcastFunction(-), ∂ps_zyg, ∂ps_enzyme) Now that we saw the low-level API let's see how to train the model without any of this boilerplate. 
Simply follow the following steps:
 
-1. Create a device using `xla_device`. Remember to load `Reactant.jl` before doing this.
+1. Create a device using `reactant_device`. Remember to load `Reactant.jl` before doing this.
 2. Similar to other device functions move the model, parameters, states and data to the
    device. Note that you might want to use [`DeviceIterator`](@ref) to move the data
    loader to the device with an iterator.
diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml
index a053be0706..037e5c65be 100644
--- a/lib/LuxLib/Project.toml
+++ b/lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal and contributors"]
-version = "1.3.6"
+version = "1.3.7"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -75,7 +75,7 @@ LinearAlgebra = "1.10"
 LoopVectorization = "0.12.171"
 LuxCore = "1"
 MKL = "0.7"
-MLDataDevices = "1.2"
+MLDataDevices = "1.5"
 Markdown = "1.10"
 NNlib = "0.9.24"
 Octavian = "0.3.28"
diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl
index 3bd59797d0..df6ad01d5e 100644
--- a/lib/LuxLib/src/impl/Impl.jl
+++ b/lib/LuxLib/src/impl/Impl.jl
@@ -19,7 +19,7 @@ using Random: Random, AbstractRNG, rand!
 using Statistics: Statistics, mean, var
 
 using LuxCore: LuxCore
-using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, XLADevice,
+using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, ReactantDevice,
                      AbstractGPUDevice, AbstractDevice
 
 using NNlib: NNlib, ConvDims
diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl
index 3a3d22ee3d..eda9eda134 100644
--- a/lib/LuxLib/src/impl/conv.jl
+++ b/lib/LuxLib/src/impl/conv.jl
@@ -74,7 +74,7 @@ end
 
 conv(x, weight, cdims::ConvDims) = conv(get_device_type((x, weight)), x, weight, cdims)
 
-function conv(::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice, XLADevice}},
+function conv(::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice, ReactantDevice}},
         x′, weight′, cdims::ConvDims)
     x, weight = get_conv_input_weight(x′, weight′)
     return NNlib.conv(x, weight, cdims)
diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md
index 2fda26602f..34a2d0ef76 100644
--- a/lib/MLDataDevices/README.md
+++ b/lib/MLDataDevices/README.md
@@ -10,7 +10,7 @@ Currently we provide support for the following backends:
 3. `AMDGPUDevice`: `AMDGPU.jl` for AMD ROCM GPUs.
 4. `MetalDevice`: `Metal.jl` for Apple Metal GPUs. **(Experimental)**
 5. `oneAPIDevice`: `oneAPI.jl` for Intel GPUs. **(Experimental)**
-6. `XLADevice`: `Reactant.jl` for XLA Support. **(Experimental)**
+6. `ReactantDevice`: `Reactant.jl` for XLA Support. 
**(Experimental)** ## Updating to v1.0 diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl index be3d285b07..d4d1064579 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl @@ -1,10 +1,10 @@ module MLDataDevicesMLUtilsExt using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, - MetalDevice, oneAPIDevice, XLADevice, DeviceIterator + MetalDevice, oneAPIDevice, ReactantDevice, DeviceIterator using MLUtils: MLUtils, DataLoader -for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice) +for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice) @eval function (D::$(dev))(dataloader::DataLoader) if dataloader.parallel if dataloader.buffer diff --git a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl index a62f87aa10..2e53362c67 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl @@ -1,7 +1,7 @@ module MLDataDevicesReactantExt using Adapt: Adapt -using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice, get_device_type +using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice, get_device_type using Reactant: Reactant, XLA, RArray, ConcreteRArray, TracedRArray, TracedRNumber @static if isdefined(Reactant, :ConcreteRNumber) @@ -14,11 +14,11 @@ else const ReactantType = Union{RArray, TracedRArray, TracedRNumber} end -MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true -MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true +MLDataDevices.loaded(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true +MLDataDevices.functional(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true # Default RNG: Forward to CPU, we will compile it -function MLDataDevices.default_device_rng(::XLADevice) +function MLDataDevices.default_device_rng(::ReactantDevice) return MLDataDevices.default_device_rng(CPUDevice()) end @@ -26,26 +26,26 @@ end function Internal.get_device(x::ConcreteRType) client = XLA.client(x.data) device = XLA.device(x.data) - return XLADevice(client, device) + return ReactantDevice(client, device) end function Internal.get_device(::Union{TracedRArray, TracedRNumber}) error("`get_device` isn't meant to be called inside `Reactant.@compile` context.") end -Internal.get_device_type(::ReactantType) = XLADevice +Internal.get_device_type(::ReactantType) = ReactantDevice # unsafe_free! -Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing +Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothing # Device Transfer -function Adapt.adapt_storage(dev::XLADevice, x::AbstractArray{<:Reactant.ReactantPrimitive}) - @warn "XLADevice got an array on device: $(get_device_type(x)). We will have to \ +function Adapt.adapt_storage(dev::ReactantDevice, x::AbstractArray{<:Reactant.ReactantPrimitive}) + @warn "ReactantDevice got an array on device: $(get_device_type(x)). We will have to \ transfer this via CPU." maxlog=1 return Adapt.adapt_storage(dev, Adapt.adapt_storage(CPUDevice(), x)) end -function Adapt.adapt_storage(dev::XLADevice, x::Array{<:Reactant.ReactantPrimitive}) +function Adapt.adapt_storage(dev::ReactantDevice, x::Array{<:Reactant.ReactantPrimitive}) client = dev.client === missing ? XLA.default_backend[] : dev.client device = dev.device === missing ? 
XLA.ClientGetDevice(client, XLA.default_device_idx[]) : dev.device diff --git a/lib/MLDataDevices/src/MLDataDevices.jl b/lib/MLDataDevices/src/MLDataDevices.jl index af528d6d3c..04d9111c03 100644 --- a/lib/MLDataDevices/src/MLDataDevices.jl +++ b/lib/MLDataDevices/src/MLDataDevices.jl @@ -17,7 +17,8 @@ include("internal.jl") export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng -export gpu_device, cpu_device, xla_device +export gpu_device, cpu_device +export xla_device, reactant_device export CPUDevice export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index 5da37ac20b..45e0a066ce 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -5,7 +5,7 @@ using Preferences: load_preference using Random: AbstractRNG using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, - MetalDevice, oneAPIDevice, XLADevice, UnknownDevice, + MetalDevice, oneAPIDevice, ReactantDevice, UnknownDevice, supported_gpu_backends, GPU_DEVICES, loaded, functional for dev in (CPUDevice, MetalDevice, oneAPIDevice) @@ -27,11 +27,11 @@ for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) end end -get_device_name(::XLADevice) = "XLA" -get_triggerpkg_name(::XLADevice) = "Reactant" +get_device_name(::ReactantDevice) = "XLA" +get_triggerpkg_name(::ReactantDevice) = "Reactant" for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, - MetalDevice, oneAPIDevice, XLADevice) + MetalDevice, oneAPIDevice, ReactantDevice) @eval get_device_id(::$(T)) = nothing end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 50016b431b..d8c193ae36 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -9,12 +9,12 @@ end struct MetalDevice <: AbstractGPUDevice end struct oneAPIDevice <: AbstractGPUDevice end -@kwdef struct XLADevice{C, D} <: AbstractAcceleratorDevice +@kwdef struct ReactantDevice{C, D} <: AbstractAcceleratorDevice client::C = missing device::D = missing end -function Base.:(==)(x::XLADevice, y::XLADevice) +function Base.:(==)(x::ReactantDevice, y::ReactantDevice) if x.client !== missing y.client === missing && return false x.client.client != y.client.client && return false @@ -30,6 +30,9 @@ function Base.:(==)(x::XLADevice, y::XLADevice) return true end +# XXX: Deprecate in v2 +const XLADevice = ReactantDevice + # Fallback for when we don't know the device type struct UnknownDevice <: AbstractDevice end @@ -207,11 +210,11 @@ Return a `CPUDevice` object which can be used to transfer data to CPU. cpu_device() = CPUDevice() """ - xla_device(; + reactant_device(; force::Bool=false, client=missing, device=missing - ) -> Union{XLADevice, CPUDevice} + ) -> Union{ReactantDevice, CPUDevice} -Return a `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`. +Return a `ReactantDevice` object if functional. Otherwise, throw an error if `force` is `true`. Falls back to `CPUDevice` if `force` is `false`. `client` and `device` are used to specify the client and particular device to use. If not @@ -221,18 +224,20 @@ specified, then the default client and index are used. This is an experimental feature and might change without deprecations """ -function xla_device(; force::Bool=false, client=missing, device=missing) - msg = "`XLADevice` is not loaded or not functional. 
Load `Reactant.jl` before calling \ - this function. Defaulting to CPU." - if loaded(XLADevice) - functional(XLADevice) && return XLADevice(client, device) - msg = "`XLADevice` is loaded but not functional. Defaulting to CPU." +function reactant_device(; force::Bool=false, client=missing, device=missing) + msg = "`ReactantDevice` is not loaded or not functional. Load `Reactant.jl` before \ + calling this function. Defaulting to CPU." + if loaded(ReactantDevice) + functional(ReactantDevice) && return ReactantDevice(client, device) + msg = "`ReactantDevice` is loaded but not functional. Defaulting to CPU." end force && throw(Internal.DeviceSelectionException("XLA")) @warn msg maxlog=1 return cpu_device() end +Base.@deprecate xla_device(; kwargs...) reactant_device(; kwargs...) + """ default_device_rng(::AbstractDevice) @@ -335,8 +340,8 @@ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice} @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." T === CPUDevice && @warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting." - T === XLADevice && - @warn "Setting device for `XLADevice` hasn't been implemented yet. Ignoring the device setting." + T === ReactantDevice && + @warn "Setting device for `ReactantDevice` hasn't been implemented yet. Ignoring the device setting." return end @@ -389,7 +394,7 @@ end Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng -for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice) +for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, ReactantDevice) @eval begin function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) return default_device_rng(to) diff --git a/lib/MLDataDevices/test/iterator_tests.jl b/lib/MLDataDevices/test/iterator_tests.jl index 132acd7deb..24267a7ba1 100644 --- a/lib/MLDataDevices/test/iterator_tests.jl +++ b/lib/MLDataDevices/test/iterator_tests.jl @@ -25,11 +25,11 @@ if BACKEND_GROUP == "xla" || BACKEND_GROUP == "all" end end -DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice] +DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice] freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x) freed_if_can_be_freed(::Type{CPUDevice}, x) = true -freed_if_can_be_freed(::Type{XLADevice}, x) = true +freed_if_can_be_freed(::Type{ReactantDevice}, x) = true function freed_if_can_be_freed(::Type, x) try Array(x) @@ -63,7 +63,7 @@ end @info "Testing DataLoader with parallel=$parallel" X = rand(Float64, 3, 33) post = DataLoader(X; batchsize=13, shuffle=false, parallel) |> dev - if dev_type === XLADevice + if dev_type === ReactantDevice pre = post # XXX: deadlocks and other shenanigans else pre = DataLoader(dev(X); batchsize=13, shuffle=false, parallel) @@ -77,7 +77,7 @@ end # Ordering is not guaranteed in parallel !parallel && @test p ≈ q - if dev_type === CPUDevice || dev_type === XLADevice + if dev_type === CPUDevice || dev_type === ReactantDevice continue end @@ -91,7 +91,7 @@ end Y = rand(Float64, 1, 33) post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false, parallel) |> dev - if dev_type === XLADevice + if dev_type === ReactantDevice pre = post # XXX: deadlocks and other shenanigans else pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false, parallel) @@ -108,7 +108,7 @@ end !parallel && @test p.x ≈ q.x !parallel && @test p.y ≈ q.y - if dev_type === 
CPUDevice || dev_type === XLADevice + if dev_type === CPUDevice || dev_type === ReactantDevice continue end diff --git a/lib/MLDataDevices/test/xla_tests.jl b/lib/MLDataDevices/test/xla_tests.jl index 21466bd1d2..ee741411f3 100644 --- a/lib/MLDataDevices/test/xla_tests.jl +++ b/lib/MLDataDevices/test/xla_tests.jl @@ -2,11 +2,11 @@ using MLDataDevices, Random, Test using ArrayInterface: parameterless_type @testset "CPU Fallback" begin - @test !MLDataDevices.functional(XLADevice) + @test !MLDataDevices.functional(ReactantDevice) @test cpu_device() isa CPUDevice - @test xla_device() isa CPUDevice - @test_throws MLDataDevices.Internal.DeviceSelectionException xla_device(; force=true) - @test_throws Exception default_device_rng(XLADevice()) + @test reactant_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException reactant_device(; force=true) + @test_throws Exception default_device_rng(ReactantDevice()) end using Reactant @@ -15,14 +15,14 @@ if "gpu" in keys(Reactant.XLA.backends) end @testset "Loaded Trigger Package" begin - if MLDataDevices.functional(XLADevice) + if MLDataDevices.functional(ReactantDevice) @info "Reactant is functional" - @test xla_device() isa XLADevice - @test xla_device(; force=true) isa XLADevice + @test reactant_device() isa ReactantDevice + @test reactant_device(; force=true) isa ReactantDevice else @info "Reactant is NOT functional" - @test xla_device() isa CPUDevice - @test_throws MLDataDevices.Internal.DeviceSelectionException xla_device(; + @test reactant_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException reactant_device(; force=true) end end @@ -36,13 +36,13 @@ using FillArrays, Zygote # Extensions rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) - device = xla_device() - aType = MLDataDevices.functional(XLADevice) ? Reactant.ConcreteRArray : Array + device = reactant_device() + aType = MLDataDevices.functional(ReactantDevice) ? 
Reactant.ConcreteRArray : Array rngType = Random.AbstractRNG ps_xpu = ps |> device - @test get_device(ps_xpu) isa XLADevice - @test get_device_type(ps_xpu) <: XLADevice + @test get_device(ps_xpu) isa ReactantDevice + @test get_device_type(ps_xpu) <: ReactantDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -60,7 +60,7 @@ using FillArrays, Zygote # Extensions @test get_device(ps_xpu.rng) === nothing @test get_device_type(ps_xpu.rng) <: Nothing - if MLDataDevices.functional(XLADevice) + if MLDataDevices.functional(ReactantDevice) @test ps_xpu.one_elem isa Reactant.RArray @test ps_xpu.farray isa Reactant.RArray else @@ -90,7 +90,7 @@ using FillArrays, Zygote # Extensions @test get_device(ps_cpu.rng) === nothing @test get_device_type(ps_cpu.rng) <: Nothing - if MLDataDevices.functional(XLADevice) + if MLDataDevices.functional(ReactantDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @@ -115,7 +115,7 @@ using FillArrays, Zygote # Extensions end @testset "Functions" begin - if MLDataDevices.functional(XLADevice) + if MLDataDevices.functional(ReactantDevice) @test get_device(tanh) isa MLDataDevices.UnknownDevice @test get_device_type(tanh) <: MLDataDevices.UnknownDevice @@ -125,9 +125,9 @@ end @test get_device(ff) isa CPUDevice @test get_device_type(ff) <: CPUDevice - ff_xpu = ff |> XLADevice() - @test get_device(ff_xpu) isa XLADevice - @test get_device_type(ff_xpu) <: XLADevice + ff_xpu = ff |> ReactantDevice() + @test get_device(ff_xpu) isa ReactantDevice + @test get_device_type(ff_xpu) <: ReactantDevice ff_cpu = ff_xpu |> cpu_device() @test get_device(ff_cpu) isa CPUDevice @@ -136,20 +136,20 @@ end end @testset "Wrapped Arrays" begin - if MLDataDevices.functional(XLADevice) - x = rand(10, 10) |> XLADevice() - @test get_device(x) isa XLADevice - @test get_device_type(x) <: XLADevice + if MLDataDevices.functional(ReactantDevice) + x = rand(10, 10) |> ReactantDevice() + @test get_device(x) isa ReactantDevice + @test get_device_type(x) <: ReactantDevice x_view = view(x, 1:5, 1:5) - @test get_device(x_view) isa XLADevice - @test get_device_type(x_view) <: XLADevice + @test get_device(x_view) isa ReactantDevice + @test get_device_type(x_view) <: ReactantDevice end end @testset "setdevice!" begin - if MLDataDevices.functional(XLADevice) + if MLDataDevices.functional(ReactantDevice) @test_logs (:warn, - "Setting device for `XLADevice` hasn't been implemented yet. Ignoring the device setting.") MLDataDevices.set_device!( - XLADevice, nothing, 1) + "Setting device for `ReactantDevice` hasn't been implemented yet. Ignoring the device setting.") MLDataDevices.set_device!( + ReactantDevice, nothing, 1) end end diff --git a/src/helpers/training.jl b/src/helpers/training.jl index c0e6644ffe..79ec6eaf24 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -10,7 +10,7 @@ using Static: StaticBool, Static, False, True using ..Lux: Lux using LuxCore: LuxCore, AbstractLuxLayer -using MLDataDevices: XLADevice, get_device_type, get_device, cpu_device +using MLDataDevices: ReactantDevice, get_device_type, get_device, cpu_device """ TrainState @@ -63,7 +63,7 @@ Constructor for [`TrainState`](@ref). 
""" function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule) dev = get_device(ps) - st_opt = if dev isa XLADevice + st_opt = if dev isa ReactantDevice ps_cpu = ps |> cpu_device() Optimisers.setup(optimizer, ps_cpu) |> dev else @@ -199,7 +199,7 @@ end maybe_wrap_adtype(backend::ReactantBackend, _) = backend maybe_wrap_adtype(ad::AbstractADType, _) = ad -function maybe_wrap_adtype(ad::AbstractADType, ::Type{XLADevice}) +function maybe_wrap_adtype(ad::AbstractADType, ::Type{ReactantDevice}) ad isa AutoEnzyme && return ReactantBackend() throw(ArgumentError("Computing gradients for models on XLA is supported only with \ Enzyme.jl (`AutoEnzyme`).")) diff --git a/test/reactant/training_tests.jl b/test/reactant/training_tests.jl index cc9d0f9b07..73926a7eed 100644 --- a/test/reactant/training_tests.jl +++ b/test/reactant/training_tests.jl @@ -13,7 +13,7 @@ Reactant.set_default_backend("cpu") end - xdev = xla_device(; force=true) + xdev = reactant_device(; force=true) @testset "MLP Training: $(version)" for version in (:iip, :oop) model = Chain(