Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Add adapt_structure for CA
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 12, 2023
1 parent db95b0a commit a79ca1d
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 8 deletions.
1 change: 0 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ always_use_return = true
margin = 92
indent = 4
format_docstrings = true
join_lines_based_on_source = false
separate_kwargs_with_semicolon = true
always_for_in = true
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.5"
version = "0.1.6"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -14,13 +14,15 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LuxDeviceUtilsComponentArraysExt = "ComponentArrays"
LuxDeviceUtilsFillArraysExt = "FillArrays"
LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
Expand All @@ -30,6 +32,7 @@ LuxDeviceUtilsZygoteExt = "Zygote"
[compat]
Adapt = "3"
ChainRulesCore = "1"
ComponentArrays = "0.13, 0.14"
FillArrays = "0.13, 1"
Functors = "0.2, 0.3, 0.4"
LuxAMDGPU = "0.1"
Expand All @@ -42,6 +45,7 @@ Zygote = "0.6"
julia = "1.6"

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Expand Down
10 changes: 10 additions & 0 deletions ext/LuxDeviceUtilsComponentArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module LuxDeviceUtilsComponentArraysExt

# FIXME: Needs upstreaming
using Adapt, ComponentArrays

function Adapt.adapt_structure(to, ca::ComponentArray)
return ComponentArray(adapt(to, getdata(ca)), getaxes(ca))
end

end
16 changes: 10 additions & 6 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ Return a tuple of supported GPU backends.
This is not the list of functional backends on the system, but rather backends which
`Lux.jl` supports.
!!! warning
`Metal.jl` support is **extremely** experimental and most things are not expected to
work.
"""
supported_gpu_backends() = map(_get_device_name, GPU_DEVICES)

Expand All @@ -87,8 +92,7 @@ Selects GPU device based on the following criteria:
"""
function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice
if GPU_DEVICE[] !== nothing
force_gpu_usage &&
!(GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
throw(LuxDeviceSelectionException())
return GPU_DEVICE[]
end
Expand Down Expand Up @@ -202,10 +206,10 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU.
"""
@inline cpu_device() = LuxCPUDevice()

(::LuxCPUDevice)(x) = fmap(x -> adapt(LuxCPUAdaptor(), x), x; exclude=_isleaf)
(::LuxCUDADevice)(x) = fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude=_isleaf)
(::LuxAMDGPUDevice)(x) = fmap(x -> adapt(LuxAMDGPUAdaptor(), x), x; exclude=_isleaf)
(::LuxMetalDevice)(x) = fmap(x -> adapt(LuxMetalAdaptor(), x), x; exclude=_isleaf)
(::LuxCPUDevice)(x) = fmap(Base.Fix1(adapt, LuxCPUAdaptor()), x; exclude=_isleaf)
(::LuxCUDADevice)(x) = fmap(Base.Fix1(adapt, LuxCUDAAdaptor()), x; exclude=_isleaf)
(::LuxAMDGPUDevice)(x) = fmap(Base.Fix1(adapt, LuxAMDGPUAdaptor()), x; exclude=_isleaf)
(::LuxMetalDevice)(x) = fmap(Base.Fix1(adapt, LuxMetalAdaptor()), x; exclude=_isleaf)

for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)
@eval begin
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
17 changes: 17 additions & 0 deletions test/component_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using LuxDeviceUtils, ComponentArrays, Random

@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin
dev = LuxCPUDevice()
ps = (; weight=randn(10, 1), bias=randn(1))

ps_ca = ps |> ComponentArray

ps_ca_dev = ps_ca |> dev

@test ps_ca_dev isa ComponentArray

@test ps_ca_dev.weight == ps.weight
@test ps_ca_dev.bias == ps.bias

@test ps_ca_dev == (ps |> dev |> ComponentArray)
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,10 @@ end
Aqua.test_all(LuxDeviceUtils; piracy=false)
end
end

@testset "Others" begin
@safetestset "Component Arrays" begin
include("component_arrays.jl")
end
end
end

0 comments on commit a79ca1d

Please sign in to comment.