From 9a28dc2b06efa4654810dcb94cb31f832e14aa5a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Feb 2024 18:10:19 -0500 Subject: [PATCH] Add a get_device function --- Project.toml | 2 +- ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 23 ++++------------------- ext/LuxDeviceUtilsLuxCUDAExt.jl | 23 ++++------------------- ext/LuxDeviceUtilsMetalGPUArraysExt.jl | 23 ++++------------------- src/LuxDeviceUtils.jl | 20 ++++++++++++++++++++ 5 files changed, 33 insertions(+), 58 deletions(-) diff --git a/Project.toml b/Project.toml index de99863..da0cab4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.14" +version = "0.1.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 7a7fbbc..ac951f1 100644 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -1,8 +1,7 @@ module LuxDeviceUtilsLuxAMDGPUExt -using ChainRulesCore, LuxAMDGPU, LuxDeviceUtils, Random +using LuxAMDGPU, LuxDeviceUtils, Random import Adapt: adapt_storage, adapt -import ChainRulesCore as CRC __init__() = reset_gpu_device!() @@ -12,6 +11,9 @@ LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() # Default RNG LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() +# Query Device from Array +LuxDeviceUtils.get_device(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice() + # Device Transfer ## To GPU adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x) @@ -20,21 +22,4 @@ adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() -## Chain Rules -CRC.rrule(::Type{Array}, x::ROCArray) = Array(x), Δ -> (NoTangent(), roc(Δ)) - -function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::AMDGPU.AnyROCArray) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxAMDGPUAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - -function CRC.rrule(::typeof(adapt_storage), to::LuxAMDGPUAdaptor, x::Array) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - end diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index 5ed4850..4edf554 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -1,8 +1,7 @@ module LuxDeviceUtilsLuxCUDAExt -using ChainRulesCore, LuxCUDA, LuxDeviceUtils, Random +using LuxCUDA, LuxDeviceUtils, Random import Adapt: adapt_storage, adapt -import ChainRulesCore as CRC __init__() = reset_gpu_device!() @@ -12,6 +11,9 @@ LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() # Default RNG LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() +# Query Device from Array +LuxDeviceUtils.get_device(::CUDA.AnyCuArray) = LuxCUDADevice() + # Device Transfer ## To GPU adapt_storage(::LuxCUDAAdaptor, x) = cu(x) @@ -23,21 +25,4 @@ adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() ## To CPU adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) = adapt(Array, x) -## Chain Rules -CRC.rrule(::Type{Array}, x::CuArray) = Array(x), Δ -> (NoTangent(), cu(Δ)) - -function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::CUDA.AnyCuArray) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxCUDAAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - -function CRC.rrule(::typeof(adapt_storage), to::LuxCUDAAdaptor, x::Array) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - end diff --git a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/ext/LuxDeviceUtilsMetalGPUArraysExt.jl index 8e8ffe8..836ab07 100644 --- a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl +++ b/ext/LuxDeviceUtilsMetalGPUArraysExt.jl @@ -1,8 +1,7 @@ module LuxDeviceUtilsMetalGPUArraysExt -using ChainRulesCore, GPUArrays, LuxDeviceUtils, Metal, Random +using GPUArrays, LuxDeviceUtils, Metal, Random import Adapt: adapt_storage, adapt -import ChainRulesCore as CRC __init__() = reset_gpu_device!() @@ -12,27 +11,13 @@ LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional() # Default RNG LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) +# Query Device from Array +LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() + # Device Transfer ## To GPU adapt_storage(::LuxMetalAdaptor, x) = mtl(x) adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = GPUArrays.default_rng(MtlArray) -## Chain Rules -CRC.rrule(::Type{Array}, x::MtlArray) = Array(x), Δ -> (NoTangent(), MtlArray(Δ)) - -function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::MtlArray) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxMetalAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - -function CRC.rrule(::typeof(adapt_storage), to::LuxMetalAdaptor, x::Array) - function ∇adapt_storage(Δ) - return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ)) - end - return adapt_storage(to, x), ∇adapt_storage -end - end diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 24ab500..04347dc 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -5,12 +5,14 @@ import PrecompileTools: @recompile_invalidations @recompile_invalidations begin using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays import Adapt: adapt, adapt_storage + import ChainRulesCore as CRC end export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor +export get_device abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end @@ -255,6 +257,15 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) end end +# Query Device from Array +""" + get_device(x::AbstractArray) -> AbstractLuxDevice + +Returns the device of the array `x`. Trigger Packages must be loaded for this to return the +correct device. +""" +get_device(x::AbstractArray) = LuxCPUDevice() + # Adapt Interface abstract type AbstractLuxDeviceAdaptor end @@ -277,4 +288,13 @@ _isbitsarray(x) = false _isleaf(::AbstractRNG) = true _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) +# Chain Rules Core +function CRC.rrule(::typeof(adapt_storage), to::AbstractLuxDeviceAdaptor, x::AbstractArray) + function ∇adapt_storage(Δ) + dev = get_device(x) + return (NoTangent(), NoTangent(), dev(Δ)) + end + return adapt_storage(to, x), ∇adapt_storage +end + end