Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Merge pull request #11 from LuxDL/ap/ca_patch
Browse files Browse the repository at this point in the history
Add adapt_structure for CA
  • Loading branch information
avik-pal authored Aug 12, 2023
2 parents db95b0a + a79ca1d commit 07dc08a
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 8 deletions.
1 change: 0 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ always_use_return = true
margin = 92
indent = 4
format_docstrings = true
join_lines_based_on_source = false
separate_kwargs_with_semicolon = true
always_for_in = true
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.5"
version = "0.1.6"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -14,13 +14,15 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LuxDeviceUtilsComponentArraysExt = "ComponentArrays"
LuxDeviceUtilsFillArraysExt = "FillArrays"
LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
Expand All @@ -30,6 +32,7 @@ LuxDeviceUtilsZygoteExt = "Zygote"
[compat]
Adapt = "3"
ChainRulesCore = "1"
ComponentArrays = "0.13, 0.14"
FillArrays = "0.13, 1"
Functors = "0.2, 0.3, 0.4"
LuxAMDGPU = "0.1"
Expand All @@ -42,6 +45,7 @@ Zygote = "0.6"
julia = "1.6"

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Expand Down
10 changes: 10 additions & 0 deletions ext/LuxDeviceUtilsComponentArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module LuxDeviceUtilsComponentArraysExt

# FIXME: Needs upstreaming
using Adapt, ComponentArrays

function Adapt.adapt_structure(to, ca::ComponentArray)
return ComponentArray(adapt(to, getdata(ca)), getaxes(ca))
end

end
16 changes: 10 additions & 6 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ Return a tuple of supported GPU backends.
This is not the list of functional backends on the system, but rather backends which
`Lux.jl` supports.
!!! warning
`Metal.jl` support is **extremely** experimental and most things are not expected to
work.
"""
supported_gpu_backends() = map(_get_device_name, GPU_DEVICES)

Expand All @@ -87,8 +92,7 @@ Selects GPU device based on the following criteria:
"""
function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice
if GPU_DEVICE[] !== nothing
force_gpu_usage &&
!(GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
throw(LuxDeviceSelectionException())
return GPU_DEVICE[]
end
Expand Down Expand Up @@ -202,10 +206,10 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU.
"""
@inline cpu_device() = LuxCPUDevice()

(::LuxCPUDevice)(x) = fmap(x -> adapt(LuxCPUAdaptor(), x), x; exclude=_isleaf)
(::LuxCUDADevice)(x) = fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude=_isleaf)
(::LuxAMDGPUDevice)(x) = fmap(x -> adapt(LuxAMDGPUAdaptor(), x), x; exclude=_isleaf)
(::LuxMetalDevice)(x) = fmap(x -> adapt(LuxMetalAdaptor(), x), x; exclude=_isleaf)
(::LuxCPUDevice)(x) = fmap(Base.Fix1(adapt, LuxCPUAdaptor()), x; exclude=_isleaf)
(::LuxCUDADevice)(x) = fmap(Base.Fix1(adapt, LuxCUDAAdaptor()), x; exclude=_isleaf)
(::LuxAMDGPUDevice)(x) = fmap(Base.Fix1(adapt, LuxAMDGPUAdaptor()), x; exclude=_isleaf)
(::LuxMetalDevice)(x) = fmap(Base.Fix1(adapt, LuxMetalAdaptor()), x; exclude=_isleaf)

for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)
@eval begin
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
17 changes: 17 additions & 0 deletions test/component_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using LuxDeviceUtils, ComponentArrays, Random

@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin
dev = LuxCPUDevice()
ps = (; weight=randn(10, 1), bias=randn(1))

ps_ca = ps |> ComponentArray

ps_ca_dev = ps_ca |> dev

@test ps_ca_dev isa ComponentArray

@test ps_ca_dev.weight == ps.weight
@test ps_ca_dev.bias == ps.bias

@test ps_ca_dev == (ps |> dev |> ComponentArray)
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,10 @@ end
Aqua.test_all(LuxDeviceUtils; piracy=false)
end
end

@testset "Others" begin
@safetestset "Component Arrays" begin
include("component_arrays.jl")
end
end
end

2 comments on commit 07dc08a

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/89508

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.6 -m "<description of version>" 07dc08a148cf9ed8bb73524b62879f5b2a2dd682
git push origin v0.1.6

Please sign in to comment.