From 3081540fb81ad9d26d4aa751de47753decfd93f0 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 5 Jan 2025 16:32:58 -0500 Subject: [PATCH] Improve logic for determining NamedDimsArray type from dimension name type --- Project.toml | 2 +- src/abstractnameddimsarray.jl | 31 ++++++++++++++++++++----------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 281a398..bc3ff1d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NamedDimsArrays" uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" authors = ["ITensor developers and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractnameddimsarray.jl b/src/abstractnameddimsarray.jl index fc1aa00..bf4eda6 100644 --- a/src/abstractnameddimsarray.jl +++ b/src/abstractnameddimsarray.jl @@ -140,18 +140,34 @@ end Base.copy(a::AbstractNamedDimsArray) = nameddims(copy(dename(a)), nameddimsindices(a)) +const NamedDimsIndices = Union{ + AbstractNamedUnitRange{<:Integer},AbstractNamedArray{<:Integer} +} +const NamedDimsAxis = AbstractNamedUnitRange{ + <:Integer,<:AbstractUnitRange,<:NamedDimsIndices +} + # Generic constructor. function nameddims(a::AbstractArray, nameddimsindices) # TODO: Check the shape of `nameddimsindices` matches the shape of `a`. - return nameddimstype(eltype(nameddimsindices))( - a, to_nameddimsindices(a, nameddimsindices) - ) + arrtype = mapreduce(nameddimsarraytype, combine_nameddimsarraytype, nameddimsindices) + return arrtype(a, to_nameddimsindices(a, nameddimsindices)) end # Can overload this to get custom named dims array wrapper # depending on the dimension name types, for example # output an `ITensor` if the dimension names are `IndexName`s. -nameddimstype(dimnametype::Type) = NamedDimsArray +nameddimsarraytype(nameddim) = nameddimsarraytype(typeof(nameddim)) +nameddimsarraytype(nameddimtype::Type) = NamedDimsArray +function nameddimsarraytype(nameddimtype::Type{<:NamedDimsIndices}) + return nameddimsarraytype(nametype(nameddimtype)) +end +function combine_nameddimsarraytype( + ::Type{<:AbstractNamedDimsArray}, ::Type{<:AbstractNamedDimsArray} +) + return NamedDimsArray +end +combine_nameddimsarraytype(::Type{T}, ::Type{T}) where {T<:AbstractNamedDimsArray} = T Base.axes(a::AbstractNamedDimsArray) = map(named, axes(dename(a)), nameddimsindices(a)) Base.size(a::AbstractNamedDimsArray) = map(named, size(dename(a)), nameddimsindices(a)) @@ -175,13 +191,6 @@ Base.eltype(a::AbstractNamedDimsArray) = eltype(dename(a)) Base.axes(a::AbstractNamedDimsArray, dimname::Name) = axes(a, dim(a, dimname)) Base.size(a::AbstractNamedDimsArray, dimname::Name) = size(a, dim(a, dimname)) -const NamedDimsIndices = Union{ - AbstractNamedUnitRange{<:Integer},AbstractNamedArray{<:Integer} -} -const NamedDimsAxis = AbstractNamedUnitRange{ - <:Integer,<:AbstractUnitRange,<:NamedDimsIndices -} - to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims) to_nameddimsaxis(ax::NamedDimsAxis) = ax to_nameddimsaxis(I::NamedDimsIndices) = named(dename(only(axes(I))), I)