From 736184509a40d8a87cd9a33a56f87f39397e3616 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Sep 2024 11:47:14 -0400 Subject: [PATCH] fix: improve load times by moving CRC to ext --- Project.toml | 5 +++-- ext/MLDataDevicesChainRulesCoreExt.jl | 19 +++++++++++++++++++ src/MLDataDevices.jl | 3 --- src/public.jl | 18 +++--------------- 4 files changed, 25 insertions(+), 20 deletions(-) create mode 100644 ext/MLDataDevicesChainRulesCoreExt.jl diff --git a/Project.toml b/Project.toml index 0602650..eedc493 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,10 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.1.0" +version = "1.1.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -14,6 +13,7 @@ UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" @@ -29,6 +29,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] MLDataDevicesAMDGPUExt = "AMDGPU" MLDataDevicesCUDAExt = "CUDA" +MLDataDevicesChainRulesCoreExt = "ChainRulesCore" MLDataDevicesFillArraysExt = "FillArrays" MLDataDevicesGPUArraysExt = "GPUArrays" MLDataDevicesMLUtilsExt = "MLUtils" diff --git a/ext/MLDataDevicesChainRulesCoreExt.jl b/ext/MLDataDevicesChainRulesCoreExt.jl new file mode 100644 index 0000000..c6b9560 --- /dev/null +++ b/ext/MLDataDevicesChainRulesCoreExt.jl @@ -0,0 +1,19 @@ +module MLDataDevicesChainRulesCoreExt + +using Adapt: Adapt +using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable + +using MLDataDevices: AbstractDevice, get_device, get_device_type + +@non_differentiable get_device(::Any) +@non_differentiable get_device_type(::Any) + +function ChainRulesCore.rrule( + ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) + ∇adapt_storage = let x = x + Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + end + return Adapt.adapt_storage(to, x), ∇adapt_storage +end + +end diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index 574fea4..d7e98b4 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -1,13 +1,10 @@ module MLDataDevices using Adapt: Adapt -using ChainRulesCore: ChainRulesCore, NoTangent using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random -const CRC = ChainRulesCore - abstract type AbstractDevice <: Function end abstract type AbstractGPUDevice <: AbstractDevice end diff --git a/src/public.jl b/src/public.jl index d7a7d27..593ba01 100644 --- a/src/public.jl +++ b/src/public.jl @@ -308,13 +308,9 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) end for op in (:get_device, :get_device_type) - @eval begin - function $(op)(x) - hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x) - return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) - end - - CRC.@non_differentiable $op(::Any) + @eval function $(op)(x) + hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x) + return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) end end @@ -337,11 +333,3 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, CUDADevice{Nothing}, MetalDevice, oneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end - -# Chain Rules Core -function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) - ∇adapt_storage = let x = x - Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) - end - return Adapt.adapt_storage(to, x), ∇adapt_storage -end