diff --git a/Project.toml b/Project.toml index 01a7c1fc..7fdcdead 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,13 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.15.18" +version = "0.15.19" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -13,8 +15,6 @@ StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -24,8 +24,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -ComponentArraysAdaptExt = "Adapt" -ComponentArraysConstructionBaseExt = "ConstructionBase" ComponentArraysGPUArraysExt = "GPUArrays" ComponentArraysOptimisersExt = "Optimisers" ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools" diff --git a/ext/ComponentArraysAdaptExt.jl b/ext/ComponentArraysAdaptExt.jl deleted file mode 100644 index 8e04c0a9..00000000 --- a/ext/ComponentArraysAdaptExt.jl +++ /dev/null @@ -1,13 +0,0 @@ -module ComponentArraysAdaptExt - -using ComponentArrays, Adapt - -function Adapt.adapt_structure(to, x::ComponentArray) - data = adapt(to, getdata(x)) - return ComponentArray(data, getaxes(x)) -end - -Adapt.adapt_storage(::Type{ComponentArray{T,N,A,Ax}}, xs::AT) where {T,N,A,Ax,AT<:AbstractArray} = - Adapt.adapt_storage(A, xs) - -end diff --git a/ext/ComponentArraysConstructionBaseExt.jl b/ext/ComponentArraysConstructionBaseExt.jl deleted file mode 100644 index db6eeb29..00000000 --- a/ext/ComponentArraysConstructionBaseExt.jl +++ /dev/null @@ -1,7 +0,0 @@ -module ComponentArraysConstructionBaseExt - -using ComponentArrays, ConstructionBase - -ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) = ComponentVector(x; patch...) - -end diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 0bf6e6ec..2f7c1bf8 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -2,6 +2,8 @@ module ComponentArrays import ChainRulesCore import StaticArrayInterface, ArrayInterface, Functors +import ConstructionBase +import Adapt using LinearAlgebra using StaticArraysCore: StaticArray, SArray, SVector, SMatrix @@ -9,7 +11,6 @@ using StaticArraysCore: StaticArray, SArray, SVector, SMatrix const FlatIdx = Union{Integer, CartesianIndex, CartesianIndices, AbstractArray{<:Integer}} const FlatOrColonIdx = Union{FlatIdx, Colon} - include("utils.jl") export fastindices # Deprecated diff --git a/src/componentarray.jl b/src/componentarray.jl index 00158204..ecb0a632 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -58,6 +58,14 @@ function ComponentArray(data, ax::AbstractAxis...) return LazyArray(ComponentArray(x, axs...) for x in part_data) end +function Adapt.adapt_structure(to, x::ComponentArray) + data = Adapt.adapt(to, getdata(x)) + return ComponentArray(data, getaxes(x)) +end + +Adapt.adapt_storage(::Type{ComponentArray{T,N,A,Ax}}, xs::AT) where {T,N,A,Ax,AT<:AbstractArray} = + Adapt.adapt_storage(A, xs) + # Entry from NamedTuple, Dict, or kwargs ComponentArray{T}(nt::NamedTuple) where T = ComponentArray(make_carray_args(T, nt)...) ComponentArray{T}(::NamedTuple{(), Tuple{}}) where T = ComponentArray(T[], (FlatAxis(),)) @@ -89,6 +97,8 @@ ComponentVector{T}(::UndefInitializer, ax) where {T} = ComponentArray{T}(undef, ComponentVector(data::AbstractVector, ax) = ComponentArray(data, ax) ComponentVector(data::AbstractArray, ax) = throw(DimensionMismatch("A `ComponentVector` must be initialized with a 1-dimensional array. This array is $(ndims(data))-dimensional.")) +ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) = ComponentVector(x; patch...) + # Add new fields to component Vector function ComponentArray(x::ComponentVector; kwargs...) return foldl((x1, kwarg) -> _maybe_add_field(x1, kwarg), (kwargs...,); init=x)