diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index d134ef2..dbc3116 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -4,6 +4,5 @@ always_use_return = true margin = 92 indent = 4 format_docstrings = true -join_lines_based_on_source = false separate_kwargs_with_semicolon = true always_for_in = true diff --git a/Project.toml b/Project.toml index b6b6eb6..714c201 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.5" +version = "0.1.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -14,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" @@ -21,6 +22,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +LuxDeviceUtilsComponentArraysExt = "ComponentArrays" LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" @@ -30,6 +32,7 @@ LuxDeviceUtilsZygoteExt = "Zygote" [compat] Adapt = "3" ChainRulesCore = "1" +ComponentArrays = "0.13, 0.14" FillArrays = "0.13, 1" Functors = "0.2, 0.3, 0.4" LuxAMDGPU = "0.1" @@ -42,6 +45,7 @@ Zygote = "0.6" julia = "1.6" [extras] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" diff --git a/ext/LuxDeviceUtilsComponentArraysExt.jl b/ext/LuxDeviceUtilsComponentArraysExt.jl new file mode 100644 index 0000000..eaf3ac7 --- /dev/null +++ b/ext/LuxDeviceUtilsComponentArraysExt.jl @@ -0,0 +1,10 @@ +module LuxDeviceUtilsComponentArraysExt + +# FIXME: Needs upstreaming +using Adapt, ComponentArrays + +function Adapt.adapt_structure(to, ca::ComponentArray) + return ComponentArray(adapt(to, getdata(ca)), getaxes(ca)) +end + +end diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index ca439dd..45cd396 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -68,6 +68,11 @@ Return a tuple of supported GPU backends. This is not the list of functional backends on the system, but rather backends which `Lux.jl` supports. + +!!! warning + + `Metal.jl` support is **extremely** experimental and most things are not expected to + work. """ supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) @@ -87,8 +92,7 @@ Selects GPU device based on the following criteria: """ function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice if GPU_DEVICE[] !== nothing - force_gpu_usage && - !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && + force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && throw(LuxDeviceSelectionException()) return GPU_DEVICE[] end @@ -202,10 +206,10 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU. """ @inline cpu_device() = LuxCPUDevice() -(::LuxCPUDevice)(x) = fmap(x -> adapt(LuxCPUAdaptor(), x), x; exclude=_isleaf) -(::LuxCUDADevice)(x) = fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude=_isleaf) -(::LuxAMDGPUDevice)(x) = fmap(x -> adapt(LuxAMDGPUAdaptor(), x), x; exclude=_isleaf) -(::LuxMetalDevice)(x) = fmap(x -> adapt(LuxMetalAdaptor(), x), x; exclude=_isleaf) +(::LuxCPUDevice)(x) = fmap(Base.Fix1(adapt, LuxCPUAdaptor()), x; exclude=_isleaf) +(::LuxCUDADevice)(x) = fmap(Base.Fix1(adapt, LuxCUDAAdaptor()), x; exclude=_isleaf) +(::LuxAMDGPUDevice)(x) = fmap(Base.Fix1(adapt, LuxAMDGPUAdaptor()), x; exclude=_isleaf) +(::LuxMetalDevice)(x) = fmap(Base.Fix1(adapt, LuxMetalAdaptor()), x; exclude=_isleaf) for dev in (LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice) @eval begin diff --git a/test/Project.toml b/test/Project.toml index 71a2921..9aa4125 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/component_arrays.jl b/test/component_arrays.jl new file mode 100644 index 0000000..3825a22 --- /dev/null +++ b/test/component_arrays.jl @@ -0,0 +1,17 @@ +using LuxDeviceUtils, ComponentArrays, Random + +@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin + dev = LuxCPUDevice() + ps = (; weight=randn(10, 1), bias=randn(1)) + + ps_ca = ps |> ComponentArray + + ps_ca_dev = ps_ca |> dev + + @test ps_ca_dev isa ComponentArray + + @test ps_ca_dev.weight == ps.weight + @test ps_ca_dev.bias == ps.bias + + @test ps_ca_dev == (ps |> dev |> ComponentArray) +end diff --git a/test/runtests.jl b/test/runtests.jl index aa9c898..0e10e2a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,4 +47,10 @@ end Aqua.test_all(LuxDeviceUtils; piracy=false) end end + + @testset "Others" begin + @safetestset "Component Arrays" begin + include("component_arrays.jl") + end + end end