Skip to content

Commit

Permalink
fix: use the correct dispatches for device overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 4, 2024
1 parent 921abf3 commit ab1b045
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
6 changes: 4 additions & 2 deletions lib/LuxCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "LuxCore"
uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.0"
version = "1.2.1"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Expand All @@ -25,11 +26,12 @@ LuxCoreArrayInterfaceTrackerExt = ["ArrayInterface", "Tracker"]
LuxCoreChainRulesCoreExt = "ChainRulesCore"
LuxCoreEnzymeCoreExt = "EnzymeCore"
LuxCoreFunctorsExt = "Functors"
LuxCoreMLDataDevicesExt = "MLDataDevices"
LuxCoreMLDataDevicesExt = ["Adapt", "MLDataDevices"]
LuxCoreReactantExt = "Reactant"
LuxCoreSetfieldExt = "Setfield"

[compat]
Adapt = "4.1"
ArrayInterface = "7.9"
ChainRulesCore = "1.24"
Compat = "4.16"
Expand Down
9 changes: 6 additions & 3 deletions lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
module LuxCoreMLDataDevicesExt

using Adapt: Adapt
using LuxCore: LuxCore
using MLDataDevices: MLDataDevices

for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
MLDataDevices.isleaf(::LuxCore.AbstractLuxLayer) = true

for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :Reactant)
ldev = Symbol(dev, :Device)
@eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractLuxLayer)
@eval function Adapt.adapt_storage(::MLDataDevices.$(ldev), x::LuxCore.AbstractLuxLayer)
@warn "Lux layers are stateless and hence don't participate in device transfers. \
Apply this function on the parameters and states generated using \
`LuxCore.setup`."
return NN
return x
end
end

Expand Down

0 comments on commit ab1b045

Please sign in to comment.