Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: improve load times by moving CRC to ext
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 11, 2024
1 parent 7d2776e commit 7361845
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 20 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.1.0"
version = "1.1.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -14,6 +13,7 @@ UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Expand All @@ -29,6 +29,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
[extensions]
MLDataDevicesAMDGPUExt = "AMDGPU"
MLDataDevicesCUDAExt = "CUDA"
MLDataDevicesChainRulesCoreExt = "ChainRulesCore"
MLDataDevicesFillArraysExt = "FillArrays"
MLDataDevicesGPUArraysExt = "GPUArrays"
MLDataDevicesMLUtilsExt = "MLUtils"
Expand Down
19 changes: 19 additions & 0 deletions ext/MLDataDevicesChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module MLDataDevicesChainRulesCoreExt

using Adapt: Adapt
using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable

using MLDataDevices: AbstractDevice, get_device, get_device_type

@non_differentiable get_device(::Any)
@non_differentiable get_device_type(::Any)

function ChainRulesCore.rrule(
::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
∇adapt_storage = let x = x
Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ))
end
return Adapt.adapt_storage(to, x), ∇adapt_storage
end

end
3 changes: 0 additions & 3 deletions src/MLDataDevices.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
module MLDataDevices

using Adapt: Adapt
using ChainRulesCore: ChainRulesCore, NoTangent
using Functors: Functors, fleaves
using Preferences: @delete_preferences!, @load_preference, @set_preferences!
using Random: AbstractRNG, Random

const CRC = ChainRulesCore

abstract type AbstractDevice <: Function end
abstract type AbstractGPUDevice <: AbstractDevice end

Expand Down
18 changes: 3 additions & 15 deletions src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,9 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
end

for op in (:get_device, :get_device_type)
@eval begin
function $(op)(x)
hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x)
return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x))
end

CRC.@non_differentiable $op(::Any)
@eval function $(op)(x)
hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x)
return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x))

Check warning on line 313 in src/public.jl

View check run for this annotation

Codecov / codecov/patch

src/public.jl#L313

Added line #L313 was not covered by tests
end
end

Expand All @@ -337,11 +333,3 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end

# Chain Rules Core
function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
∇adapt_storage = let x = x
Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ))
end
return Adapt.adapt_storage(to, x), ∇adapt_storage
end

0 comments on commit 7361845

Please sign in to comment.