Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Add fast and type stable paths for certain datastructures #15

Merged
merged 3 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.7"
version = "0.1.8"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -14,15 +14,13 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LuxDeviceUtilsComponentArraysExt = "ComponentArrays"
LuxDeviceUtilsFillArraysExt = "FillArrays"
LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
Expand All @@ -32,7 +30,6 @@ LuxDeviceUtilsZygoteExt = "Zygote"
[compat]
Adapt = "3"
ChainRulesCore = "1"
ComponentArrays = "0.13, 0.14"
FillArrays = "0.13, 1"
Functors = "0.2, 0.3, 0.4"
LuxAMDGPU = "0.1"
Expand All @@ -45,7 +42,6 @@ Zygote = "0.6"
julia = "1.6"

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Expand Down
10 changes: 0 additions & 10 deletions ext/LuxDeviceUtilsComponentArraysExt.jl

This file was deleted.

25 changes: 18 additions & 7 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,25 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU.
"""
@inline cpu_device() = LuxCPUDevice()

(::LuxCPUDevice)(x) = fmap(Base.Fix1(adapt, LuxCPUAdaptor()), x; exclude=_isleaf)
(::LuxCUDADevice)(x) = fmap(Base.Fix1(adapt, LuxCUDAAdaptor()), x; exclude=_isleaf)
(::LuxAMDGPUDevice)(x) = fmap(Base.Fix1(adapt, LuxAMDGPUAdaptor()), x; exclude=_isleaf)
(::LuxMetalDevice)(x) = fmap(Base.Fix1(adapt, LuxMetalAdaptor()), x; exclude=_isleaf)

for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)
# Dispatches for Different Data Structures
# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
# For all other types we rely on fmap which means we lose type stability.
# For Lux, typically models only has these 3 datastructures so we should be mostly fine.
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
ldev = Symbol("Lux$(dev)Device")
ladaptor = Symbol("Lux$(dev)Adaptor")
@eval begin
function (::$dev)(::LuxCore.AbstractExplicitLayer)
function (::$(ldev))(x::AbstractArray)
fn = Base.Fix1(adapt, $(ladaptor)())
return _isbitsarray(x) ? fn(x) : map(fn, x)
end
(::$(ldev))(x::Tuple) = map(Base.Fix1(adapt, $(ladaptor)()), x)
(dev::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(dev(values(x)))
function (::$(ldev))(x)
_isleaf(x) && return adapt($(ladaptor)(), x)
return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf)
end
function (::$(ldev))(::LuxCore.AbstractExplicitLayer)
throw(ArgumentError("Lux layers are stateless and hence don't participate in device transfers. Apply this function on the parameters and states generated using `Lux.setup`."))
end
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ComponentArrays = "0.14.1"
julia = "1.6"