diff --git a/Project.toml b/Project.toml index 7fdcdead..e4954feb 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -25,6 +26,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] ComponentArraysGPUArraysExt = "GPUArrays" +ComponentArraysKernelAbstractionsExt = "KernelAbstractions" ComponentArraysOptimisersExt = "Optimisers" ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools" ComponentArraysReverseDiffExt = "ReverseDiff" @@ -40,6 +42,7 @@ ConstructionBase = "1" ForwardDiff = "0.10.36" Functors = "0.4.12, 0.5" GPUArrays = "10, 11" +KernelAbstractions = "0.9.29" LinearAlgebra = "1.10" Optimisers = "0.3, 0.4" RecursiveArrayTools = "3.8" diff --git a/ext/ComponentArraysGPUArraysExt.jl b/ext/ComponentArraysGPUArraysExt.jl index 3b9a56e5..9329b2b4 100644 --- a/ext/ComponentArraysGPUArraysExt.jl +++ b/ext/ComponentArraysGPUArraysExt.jl @@ -8,16 +8,24 @@ const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVecto const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax} const GPUComponentVecorMat{T,Ax} = Union{GPUComponentVector{T,Ax},GPUComponentMatrix{T,Ax}} -GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x)) +@static if pkgversion(GPUArrays) < v"11" + GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x)) -function Base.fill!(A::GPUComponentArray{T}, x) where {T} - length(A) == 0 && return A - GPUArrays.gpu_call(A, convert(T, x)) do ctx, a, val - idx = GPUArrays.@linearidx(a) - @inbounds a[idx] = val - return + function Base.fill!(A::GPUComponentArray{T}, x) where {T} + length(A) == 0 && return A + GPUArrays.gpu_call(A, convert(T, x)) do ctx, a, val + idx = GPUArrays.@linearidx(a) + @inbounds a[idx] = val + return + end + return A + end +else + function Base.fill!(A::GPUComponentArray{T}, x) where {T} + length(A) == 0 && return A + ComponentArrays.fill_componentarray_ka!(A, x) + return A end - A end LinearAlgebra.dot(x::GPUComponentArray, y::GPUComponentArray) = dot(getdata(x), getdata(y)) diff --git a/ext/ComponentArraysKernelAbstractionsExt.jl b/ext/ComponentArraysKernelAbstractionsExt.jl new file mode 100644 index 00000000..84d5c4aa --- /dev/null +++ b/ext/ComponentArraysKernelAbstractionsExt.jl @@ -0,0 +1,19 @@ +module ComponentArraysKernelAbstractionsExt + +using ComponentArrays: ComponentArrays, ComponentArray +using KernelAbstractions: KernelAbstractions, @kernel, @index + +KernelAbstractions.backend(x::ComponentArray) = KernelAbstractions.backend(getdata(x)) + +@kernel function ca_fill_kernel!(A, @Const(x)) + idx = @index(Global, Linear) + @inbounds A[idx] = x +end + +function ComponentArrays.fill_componentarray_ka!(A::ComponentArray{T}, x) where {T} + kernel! = ca_fill_kernel!(KernelAbstractions.get_backend(A)) + kernel!(A, x; ndrange=length(A)) + return A +end + +end diff --git a/src/componentarray.jl b/src/componentarray.jl index ecb0a632..76962f55 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -78,6 +78,7 @@ ComponentArray(x::ComponentArray) = x ComponentArray{T}(x::ComponentArray) where {T} = T.(x) (CA::Type{<:ComponentArray{T,N,A,Ax}})(x::ComponentArray) where {T,N,A,Ax} = ComponentArray(T.(getdata(x)), getaxes(x)) +function fill_componentarray_ka! end # defined in extensions ## Some aliases """