From 38f1a738326298f6fdb5899be0c8c2f6c0075f48 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 18 Nov 2024 00:11:17 +0100 Subject: [PATCH] non_differentiable gpu_device and cpu_device (#1089) * non_differentiable gpu_device and cpu_device * Update lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl * fix: missing imports * chore: bump version for release --------- Co-authored-by: Avik Pal --- lib/MLDataDevices/Project.toml | 2 +- .../ext/MLDataDevicesChainRulesCoreExt.jl | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index f087936bd..b3a7ebf07 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.6.1" +version = "1.6.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl index 518ff205d..2b230cd87 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl @@ -3,10 +3,14 @@ module MLDataDevicesChainRulesCoreExt using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, @non_differentiable -using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type +using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type, + reactant_device, cpu_device, gpu_device -@non_differentiable get_device(::Any) -@non_differentiable get_device_type(::Any) +@non_differentiable get_device(::Any...) +@non_differentiable get_device_type(::Any...) +@non_differentiable gpu_device(::Any...) +@non_differentiable cpu_device(::Any...) +@non_differentiable reactant_device(::Any...) function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray) dev = get_device(x)