diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml
index e04a022566..e018a66db6 100644
--- a/lib/MLDataDevices/Project.toml
+++ b/lib/MLDataDevices/Project.toml
@@ -1,7 +1,7 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal and contributors"]
-version = "1.7.0"
+version = "1.8.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl
index 900e233411..2c6e25344b 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl
@@ -23,7 +23,10 @@ function ChainRulesCore.rrule(
         return y, ∇adapt_storage_unknown
     else
         ∇adapt_storage = let dev = dev, x = x
-            Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(dev(Δ)))
+            Δ -> begin
+                ∂x = ChainRulesCore.@thunk ProjectTo(x)(dev(ChainRulesCore.unthunk(Δ)))
+                return NoTangent(), NoTangent(), ∂x
+            end
         end
         return Adapt.adapt_storage(to, x), ∇adapt_storage
     end
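For context on the `rrule` change above: the pullback now follows the standard ChainRulesCore `@thunk`/`unthunk` idiom, so the device transfer and projection only run if the cotangent is actually materialized by the AD engine. A minimal sketch of that idiom on a toy `double` function (illustrative only, not part of the patch):

```julia
using ChainRulesCore

double(x) = 2 .* x

function ChainRulesCore.rrule(::typeof(double), x)
    project_x = ProjectTo(x)
    # `@thunk` defers the work; `unthunk` accepts lazy incoming cotangents.
    ∇double(Δ) = (NoTangent(), @thunk(project_x(2 .* unthunk(Δ))))
    return double(x), ∇double
end
```
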
diff --git a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
index 3a4ba7f787..f81c457cb9 100644
--- a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
+++ b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl
@@ -3,30 +3,43 @@ module MLDataDevicesReactantExt
 using Adapt: Adapt
 using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice
 using Random: Random
-using Reactant: Reactant, XLA, ConcreteRArray, ConcreteRNumber, TracedRArray,
-    TracedRNumber
+using Reactant: Reactant, ConcreteRArray, ConcreteRNumber, TracedRArray, TracedRNumber

 MLDataDevices.loaded(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true
 MLDataDevices.functional(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true

 # Default RNG
-function MLDataDevices.default_device_rng(::ReactantDevice)
-    return Reactant.TracedRandom.default_rng()
-end
+MLDataDevices.default_device_rng(::ReactantDevice) = Reactant.TracedRandom.default_rng()

 # Query Device from Array
-function Internal.get_device(x::Union{ConcreteRNumber, ConcreteRArray})
-    return ReactantDevice(XLA.client(x), XLA.device(x))
+@static if isdefined(Reactant, :ConcreteIFRTArray)
+    const AllConcreteTypes = Union{
+        Reactant.ConcreteIFRTNumber, Reactant.ConcreteIFRTArray,
+        Reactant.ConcretePJRTNumber, Reactant.ConcretePJRTArray
+    }
+elseif isdefined(Reactant, :ConcretePJRTArray)
+    const AllConcreteTypes = Union{
+        Reactant.ConcretePJRTNumber, Reactant.ConcretePJRTArray
+    }
+else
+    const AllConcreteTypes = Union{ConcreteRNumber, ConcreteRArray}
 end
-function Internal.get_device(::Union{TracedRArray, TracedRNumber})
-    error("`get_device` isn't meant to be called inside `Reactant.@compile` context.")
+function Internal.get_device(x::AllConcreteTypes)
+    return ReactantDevice(
+        Reactant.XLA.client(x),
+        Reactant.XLA.device(x),
+        IdDict{AllConcreteTypes, Reactant.Sharding.AbstractSharding}(
+            x => x.sharding.sharding
+        )
+    )
 end
+Internal.get_device_type(::AllConcreteTypes) = ReactantDevice

-function Internal.get_device_type(
-    ::Union{TracedRArray, TracedRNumber, ConcreteRArray, ConcreteRNumber})
-    return ReactantDevice
+function Internal.get_device(::Union{TracedRArray, TracedRNumber})
+    error("`get_device` isn't meant to be called inside `Reactant.@compile` context.")
 end
+Internal.get_device_type(::Union{TracedRArray, TracedRNumber}) = ReactantDevice

 # unsafe_free!
 Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothing

@@ -36,6 +49,19 @@ function Adapt.adapt_storage(dev::ReactantDevice, x::AbstractArray)
     kwargs = (;)
     dev.client === missing || (kwargs = (; kwargs..., client=dev.client))
     dev.device === missing || (kwargs = (; kwargs..., device=dev.device))
+    if dev.sharding !== missing
+        if dev.sharding isa IdDict
+            sharding = dev.sharding[x]
+            @assert sharding isa Reactant.Sharding.AbstractSharding
+            kwargs = (; kwargs..., sharding)
+        elseif dev.sharding isa Reactant.Sharding.AbstractSharding
+            kwargs = (; kwargs..., dev.sharding)
+        else
+            throw(ArgumentError("`sharding` must be an `IdDict` or a \
+                                 `Reactant.Sharding.AbstractSharding` but got \
+                                 $(typeof(dev.sharding))."))
+        end
+    end
     return ConcreteRArray(x; kwargs...)
 end

diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl
index b95a67ea46..4a26442bff 100644
--- a/lib/MLDataDevices/src/internal.jl
+++ b/lib/MLDataDevices/src/internal.jl
@@ -129,7 +129,14 @@ end
 combine_devices(dev::ReactantDevice, ::AbstractDevice) = dev
 combine_devices(::AbstractDevice, dev::ReactantDevice) = dev
 function combine_devices(dev1::ReactantDevice, dev2::ReactantDevice)
-    dev1 == dev2 && return dev1
+    if dev1 == dev2
+        # `merge(...)` of `IdDict` constructs a `Dict`
+        sharding = dev1.sharding
+        for (k, v) in dev2.sharding
+            sharding[k] = v
+        end
+        return ReactantDevice(dev1.client, dev1.device, sharding)
+    end
     throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2)."))
 end
 combine_devices(::Type{ReactantDevice}, ::Type{UnknownDevice}) = ReactantDevice
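The comment in `combine_devices` is doing real work: per the patch, `merge` on two `IdDict`s constructs a plain `Dict`, which keys by `isequal` instead of object identity and would silently break per-leaf sharding lookups. A small sketch of the pitfall, assuming stock `Base` semantics and nothing package-specific:

```julia
a = IdDict{Any, Int}([1.0] => 1)
b = IdDict{Any, Int}([1.0] => 2)

merged = merge(a, b)
typeof(merged)   # Dict{Any, Int64}: the IdDict type is not preserved
length(merged)   # 1: the two distinct arrays are isequal, so they collapse

# The manual loop used in the patch keeps identity keying intact:
for (k, v) in b
    a[k] = v
end
length(a)        # 2: both array objects remain separate keys
```
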
diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl
index 73e99ae634..31151a9d44 100644
--- a/lib/MLDataDevices/src/public.jl
+++ b/lib/MLDataDevices/src/public.jl
@@ -9,10 +9,10 @@ end
 struct MetalDevice <: AbstractGPUDevice end
 struct oneAPIDevice <: AbstractGPUDevice end

-# TODO: We need to rethink how to store the client and device. what about sharding??
-@kwdef struct ReactantDevice{C, D} <: AbstractAcceleratorDevice
+@kwdef struct ReactantDevice{C, D, S} <: AbstractAcceleratorDevice
     client::C = missing
     device::D = missing
+    sharding::S = missing
 end

 function Base.:(==)(x::ReactantDevice, y::ReactantDevice)
@@ -212,24 +212,26 @@ cpu_device() = CPUDevice()

 """
     reactant_device(;
-        force::Bool=false, client=missing, device=missing
+        force::Bool=false, client=missing, device=missing, sharding=missing
     ) -> Union{ReactantDevice, CPUDevice}

-Return a `ReactantDevice` object if functional. Otherwise, throw an error if `force` is `true`.
-Falls back to `CPUDevice` if `force` is `false`.
+Return a `ReactantDevice` object if functional. Otherwise, throw an error if `force` is
+`true`. Falls back to `CPUDevice` if `force` is `false`.

 `client` and `device` are used to specify the client and particular device to use. If not
 specified, then the default client and index are used.

-!!! danger
-
-    This is an experimental feature and might change without deprecations
+`sharding` is used to specify the sharding strategy. If a
+`Reactant.Sharding.AbstractSharding` is specified, then we use it to shard all abstract
+arrays. Alternatively, pass in an `IdDict` to specify the sharding for specific leaves.
 """
-function reactant_device(; force::Bool=false, client=missing, device=missing)
+function reactant_device(;
+    force::Bool=false, client=missing, device=missing, sharding=missing
+)
     msg = "`ReactantDevice` is not loaded or not functional. Load `Reactant.jl` before \
           calling this function. Defaulting to CPU."
     if loaded(ReactantDevice)
-        functional(ReactantDevice) && return ReactantDevice(client, device)
+        functional(ReactantDevice) && return ReactantDevice(client, device, sharding)
         msg = "`ReactantDevice` is loaded but not functional. Defaulting to CPU."
     end
     force && throw(Internal.DeviceSelectionException("XLA"))
diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml
index ce6ba79947..5728c04995 100644
--- a/lib/WeightInitializers/test/Project.toml
+++ b/lib/WeightInitializers/test/Project.toml
@@ -19,7 +19,7 @@ Aqua = "0.8.7"
 Documenter = "1.5.0"
 ExplicitImports = "1.9.0"
 GPUArrays = "10.2"
-GPUArraysCore = "0.1.6"
+GPUArraysCore = "0.1.6, 0.2"
 Hwloc = "3.3"
 InteractiveUtils = "<0.0.1, 1"
 LinearAlgebra = "1.10"
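Finally, a hedged end-to-end sketch of the new `sharding` keyword. The mesh and `NamedSharding` construction below follows Reactant's documented sharding API, but the device count and axis names are illustrative assumptions, not something this patch pins down:

```julia
using Reactant, MLDataDevices

# Lay out 4 XLA devices as a 2×2 mesh with named axes (assumes ≥ 4 devices).
mesh = Reactant.Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y))
sharding = Reactant.Sharding.NamedSharding(mesh, (:x, :y))

# Shard every array moved by this device the same way...
dev = reactant_device(; sharding)
x = dev(rand(Float32, 8, 8))

# ...or shard only specific leaves, keyed by object identity via an IdDict.
w = rand(Float32, 8, 8)
dev_leaf = reactant_device(; sharding=IdDict(w => sharding))
w_sharded = dev_leaf(w)
```

Note that the `IdDict` path keys on the identity of the arrays being transferred (`Adapt.adapt_storage` looks up `dev.sharding[x]`), so the dictionary must hold the exact objects you later pass through the device.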