Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Add a get_device function #33

Merged
merged 1 commit into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.14"
version = "0.1.15"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
23 changes: 4 additions & 19 deletions ext/LuxDeviceUtilsLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module LuxDeviceUtilsLuxAMDGPUExt

using ChainRulesCore, LuxAMDGPU, LuxDeviceUtils, Random
using LuxAMDGPU, LuxDeviceUtils, Random
import Adapt: adapt_storage, adapt
import ChainRulesCore as CRC

# Module load hook: delegates to `reset_gpu_device!()` and returns its result.
function __init__()
    return reset_gpu_device!()
end

Expand All @@ -12,6 +11,9 @@ LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional()
# Default RNG
# The default RNG for the AMDGPU device is whatever `AMDGPU.rocrand_rng()` returns.
function LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice)
    return AMDGPU.rocrand_rng()
end

# Query Device from Array
# Any array matching `AMDGPU.AnyROCArray` maps to a `LuxAMDGPUDevice()`.
function LuxDeviceUtils.get_device(::AMDGPU.AnyROCArray)
    return LuxAMDGPUDevice()
end

# Device Transfer
## To GPU
# Generic method: hand `x` to `roc` and return the result.
function adapt_storage(::LuxAMDGPUAdaptor, x)
    return roc(x)
end
Expand All @@ -20,21 +22,4 @@ adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng

# Adapting a rocRAND RNG with the CPU adaptor yields `Random.default_rng()`;
# the incoming RNG object itself is not consulted.
function adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG)
    return Random.default_rng()
end

## Chain Rules
# Reverse rule for `Array(x)` with `x::ROCArray`: the primal is `Array(x)` and the
# pullback sends the cotangent through `roc`.
function CRC.rrule(::Type{Array}, x::ROCArray)
    y = Array(x)
    pullback = Δ -> (NoTangent(), roc(Δ))
    return y, pullback
end

# Reverse rule for `adapt_storage(::LuxCPUAdaptor, ::AnyROCArray)`.
# The pullback routes the cotangent through `adapt_storage(LuxAMDGPUAdaptor(), Δ)`;
# the function and the adaptor receive `NoTangent()`.
function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::AMDGPU.AnyROCArray)
    y = adapt_storage(to, x)
    pullback(Δ) = (NoTangent(), NoTangent(), adapt_storage(LuxAMDGPUAdaptor(), Δ))
    return y, pullback
end

# Reverse rule for `adapt_storage(::LuxAMDGPUAdaptor, ::Array)`.
# The pullback routes the cotangent through `adapt_storage(LuxCPUAdaptor(), Δ)`;
# the function and the adaptor receive `NoTangent()`.
function CRC.rrule(::typeof(adapt_storage), to::LuxAMDGPUAdaptor, x::Array)
    y = adapt_storage(to, x)
    pullback(Δ) = (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ))
    return y, pullback
end

end
23 changes: 4 additions & 19 deletions ext/LuxDeviceUtilsLuxCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module LuxDeviceUtilsLuxCUDAExt

using ChainRulesCore, LuxCUDA, LuxDeviceUtils, Random
using LuxCUDA, LuxDeviceUtils, Random
import Adapt: adapt_storage, adapt
import ChainRulesCore as CRC

# Module load hook: delegates to `reset_gpu_device!()` and returns its result.
function __init__()
    return reset_gpu_device!()
end

Expand All @@ -12,6 +11,9 @@ LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional()
# Default RNG
# The default RNG for the CUDA device is whatever `CUDA.default_rng()` returns.
function LuxDeviceUtils.default_device_rng(::LuxCUDADevice)
    return CUDA.default_rng()
end

# Query Device from Array
# Any array matching `CUDA.AnyCuArray` maps to a `LuxCUDADevice()`.
function LuxDeviceUtils.get_device(::CUDA.AnyCuArray)
    return LuxCUDADevice()
end

# Device Transfer
## To GPU
# Generic method: hand `x` to `cu` and return the result.
function adapt_storage(::LuxCUDAAdaptor, x)
    return cu(x)
end
Expand All @@ -23,21 +25,4 @@ adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng()
## To CPU
# CUSPARSE sparse matrices are adapted with `adapt(Array, x)`.
function adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix)
    return adapt(Array, x)
end

## Chain Rules
# Reverse rule for `Array(x)` with `x::CuArray`.
#
# The cotangent is moved back with `adapt(CuArray, Δ)` rather than `cu(Δ)`. `cu` is an
# opinionated adaptor that converts floating-point element types to `Float32`, so it
# would silently downcast the gradient of a `Float64` array. `adapt(CuArray, Δ)` keeps
# the element type of `Δ`.
function CRC.rrule(::Type{Array}, x::CuArray)
    ∇Array(Δ) = (NoTangent(), adapt(CuArray, Δ))
    return Array(x), ∇Array
end

# Reverse rule for `adapt_storage(::LuxCPUAdaptor, ::AnyCuArray)`.
# The pullback routes the cotangent through `adapt_storage(LuxCUDAAdaptor(), Δ)`;
# the function and the adaptor receive `NoTangent()`.
function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::CUDA.AnyCuArray)
    y = adapt_storage(to, x)
    pullback(Δ) = (NoTangent(), NoTangent(), adapt_storage(LuxCUDAAdaptor(), Δ))
    return y, pullback
end

# Reverse rule for `adapt_storage(::LuxCUDAAdaptor, ::Array)`.
# The pullback routes the cotangent through `adapt_storage(LuxCPUAdaptor(), Δ)`;
# the function and the adaptor receive `NoTangent()`.
function CRC.rrule(::typeof(adapt_storage), to::LuxCUDAAdaptor, x::Array)
    y = adapt_storage(to, x)
    pullback(Δ) = (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ))
    return y, pullback
end

end
23 changes: 4 additions & 19 deletions ext/LuxDeviceUtilsMetalGPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module LuxDeviceUtilsMetalGPUArraysExt

using ChainRulesCore, GPUArrays, LuxDeviceUtils, Metal, Random
using GPUArrays, LuxDeviceUtils, Metal, Random
import Adapt: adapt_storage, adapt
import ChainRulesCore as CRC

# Module load hook: delegates to `reset_gpu_device!()` and returns its result.
function __init__()
    return reset_gpu_device!()
end

Expand All @@ -12,27 +11,13 @@ LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional()
# Default RNG
# The default RNG for the Metal device is `GPUArrays.default_rng(MtlArray)`.
function LuxDeviceUtils.default_device_rng(::LuxMetalDevice)
    return GPUArrays.default_rng(MtlArray)
end

# Query Device from Array
# Any `MtlArray` maps to a `LuxMetalDevice()`.
function LuxDeviceUtils.get_device(::MtlArray)
    return LuxMetalDevice()
end

# Device Transfer
## To GPU
# Generic method: hand `x` to `mtl` and return the result.
function adapt_storage(::LuxMetalAdaptor, x)
    return mtl(x)
end

# RNGs are returned unchanged ...
function adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG)
    return rng
end

# ... except the task-local RNG, which is replaced by `GPUArrays.default_rng(MtlArray)`.
function adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG)
    return GPUArrays.default_rng(MtlArray)
end

## Chain Rules
# Reverse rule for `Array(x)` with `x::MtlArray`: the primal is `Array(x)` and the
# pullback wraps the cotangent with the `MtlArray` constructor.
function CRC.rrule(::Type{Array}, x::MtlArray)
    y = Array(x)
    pullback = Δ -> (NoTangent(), MtlArray(Δ))
    return y, pullback
end

# Reverse rule for `adapt_storage(::LuxCPUAdaptor, ::MtlArray)`.
# The pullback routes the cotangent through `adapt_storage(LuxMetalAdaptor(), Δ)`;
# the function and the adaptor receive `NoTangent()`.
function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::MtlArray)
    y = adapt_storage(to, x)
    pullback(Δ) = (NoTangent(), NoTangent(), adapt_storage(LuxMetalAdaptor(), Δ))
    return y, pullback
end

# Reverse rule for `adapt_storage(::LuxMetalAdaptor, ::Array)`.
# The pullback routes the cotangent through `adapt_storage(LuxCPUAdaptor(), Δ)`;
# the function and the adaptor receive `NoTangent()`.
function CRC.rrule(::typeof(adapt_storage), to::LuxMetalAdaptor, x::Array)
    y = adapt_storage(to, x)
    pullback(Δ) = (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ))
    return y, pullback
end

end
20 changes: 20 additions & 0 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import PrecompileTools: @recompile_invalidations
@recompile_invalidations begin
using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
import Adapt: adapt, adapt_storage
import ChainRulesCore as CRC
end

export gpu_backend!, supported_gpu_backends, reset_gpu_device!
export default_device_rng
export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice
export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor
export get_device

# Root of the device type hierarchy. It subtypes `Function`, so every device is a
# function-like object (the `adapt_storage` reverse rule in this file calls a device
# directly on a cotangent).
abstract type AbstractLuxDevice <: Function end
# Supertype for GPU devices -- presumably the CUDA / AMDGPU / Metal device types;
# their definitions are not visible here, confirm against the rest of the file.
abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end
Expand Down Expand Up @@ -255,6 +257,15 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
end
end

# Query Device from Array
"""
    get_device(x::AbstractArray) -> AbstractLuxDevice

Return the device on which the array `x` lives.

This generic method answers `LuxCPUDevice()` for every `AbstractArray`. The GPU backends
add more specific methods for their own array types from their package extensions, so the
corresponding trigger packages must be loaded for a GPU array to be reported on the
correct device.
"""
function get_device(x::AbstractArray)
    return LuxCPUDevice()
end

# Adapt Interface
# Common supertype for the device adaptor types; the `adapt_storage` reverse rule in
# this file dispatches on it.
abstract type AbstractLuxDeviceAdaptor end

Expand All @@ -277,4 +288,13 @@ _isbitsarray(x) = false
# `_isleaf`: an RNG is always a leaf.
function _isleaf(::AbstractRNG)
    return true
end

# Anything else is a leaf when `_isbitsarray` accepts it, or otherwise when
# `Functors.isleaf` says so.
function _isleaf(x)
    _isbitsarray(x) && return true
    return Functors.isleaf(x)
end

# Chain Rules Core
# Reverse rule for `adapt_storage(to, x)` with any Lux device adaptor.
#
# The primal result is `adapt_storage(to, x)`. The pullback moves the cotangent `Δ` back
# to the device that `x` came from by calling that device on `Δ`; the function and the
# adaptor receive `NoTangent()`.
#
# The device is looked up during the forward pass so that the pullback closure captures
# only `dev` and not the input array `x`. Capturing `x` would keep the whole primal
# array alive until the backward pass runs.
function CRC.rrule(::typeof(adapt_storage), to::AbstractLuxDeviceAdaptor, x::AbstractArray)
    dev = get_device(x)
    ∇adapt_storage(Δ) = (NoTangent(), NoTangent(), dev(Δ))
    return adapt_storage(to, x), ∇adapt_storage
end

end
Loading