diff --git a/Project.toml b/Project.toml index dcced62..281a398 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "NamedDimsArrays" uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" authors = ["ITensor developers and contributors"] -version = "0.2.0" +version = "0.3.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2" Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -13,8 +14,16 @@ SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" +[weakdeps] +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + +[extensions] +NamedDimsArraysBlockArraysExt = "BlockArrays" + [compat] Adapt = "4.1.1" +ArrayLayouts = "1.11.0" +BlockArrays = "1.3.0" BroadcastMapConversion = "0.1.2" Derive = "0.3.6" LinearAlgebra = "1.10" diff --git a/README.md b/README.md index ecc0fd9..5d863c1 100644 --- a/README.md +++ b/README.md @@ -32,33 +32,67 @@ julia> Pkg.add("NamedDimsArrays") ## Examples ````julia -using NamedDimsArrays: aligndims, dename, dimnames, named +using NamedDimsArrays: aligndims, dimnames, named, nameddimsindices, namedoneto, unname using TensorAlgebra: contract +using Test: @test # Named dimensions -i = named(2, "i") -j = named(2, "j") -k = named(2, "k") +i = namedoneto(2, "i") +j = namedoneto(2, "j") +k = namedoneto(2, "k") # Arrays with named dimensions -na1 = randn(i, j) -na2 = randn(j, k) +a1 = randn(i, j) +a2 = randn(j, k) -@show dimnames(na1) == ("i", "j") +@test dimnames(a1) == ("i", "j") +@test nameddimsindices(a1) == (i, j) +@test axes(a1) == (named(1:2, i), named(1:2, j)) +@test size(a1) == (named(2, i), named(2, j)) # Indexing -@show na1[j => 2, i => 1] == na1[1, 2] +@test a1[j => 2, i => 1] == a1[1, 2] +@test a1[j[2], i[1]] == a1[1, 2] # Tensor contraction -na_dest = contract(na1, na2) +a_dest = contract(a1, a2) -@show issetequal(dimnames(na_dest), ("i", "k")) -# `dename` removes the names and returns an `Array` -@show dename(na_dest, (i, k)) ≈ dename(na1) * dename(na2) +@test issetequal(nameddimsindices(a_dest), (i, k)) +# `unname` removes the names and returns an `Array` +@test unname(a_dest, (i, k)) ≈ unname(a1, (i, j)) * unname(a2, (j, k)) # Permute dimensions (like `ITensors.permute`) -na1 = aligndims(na1, (j, i)) -@show na1[i => 1, j => 2] == na1[2, 1] +a1′ = aligndims(a1, (j, i)) +@test a1′[i => 1, j => 2] == a1[i => 1, j => 2] +@test a1′[i[1], j[2]] == a1[i[1], j[2]] + +# Contiguous slicing +b1 = a1[i => 1:2, j => 1:1] +@test b1 == a1[i[1:2], j[1:1]] + +b2 = a2[j => 1:1, k => 1:2] +@test b2 == a2[j[1:1], k[1:2]] + +@test nameddimsindices(b1) == (i[1:2], j[1:1]) +@test nameddimsindices(b2) == (j[1:1], k[1:2]) + +b_dest = contract(b1, b2) + +@test issetequal(nameddimsindices(b_dest), (i, k)) + +# Non-contiguous slicing +c1 = a1[i[[2, 1]], j[[2, 1]]] +@test nameddimsindices(c1) == (i[[2, 1]], j[[2, 1]]) +@test unname(c1, (i[[2, 1]], j[[2, 1]])) == unname(a1, (i, j))[[2, 1], [2, 1]] +@test c1[i[2], j[1]] == a1[i[2], j[1]] +@test c1[2, 1] == a1[1, 2] + +a1[i[[2, 1]], j[[2, 1]]] = [22 21; 12 11] +@test a1[i[1], j[1]] == 11 + +x = randn(i[1:2], j[2:2]) +a1[i[1:2], j[2:2]] = x +@test a1[i[1], j[2]] == x[i[1], j[2]] ```` --- diff --git a/TODO.md b/TODO.md index 5208a5b..bef8805 100644 --- a/TODO.md +++ b/TODO.md @@ -1,20 +1,33 @@ +- Define `@align`/`@aligned` such that: +```julia +i = namedoneto(2, "i") +j = namedoneto(2, "j") +a = randn(i, j) +@align a[j, i] +@aligned a[j, i] +``` +aligns the dimensions (currently `a[j, i]` doesn't align the dimensions). +It could be written in terms of `align_getindex`/`align_view`. - `svd`, `eigen` (including tensor versions) -- `reshape`, `vec` -- `swapdimnames` -- `mapdimnames(f, a::AbstractNamedDimsArray)` (rename `replacedimnames(f, a)` to `mapdimnames(f, a)`, or have both?) +- `reshape`, `vec`, including fused dimension names. +- Dimension name set logic, i.e. `setdiffnameddimsindices(a::AbstractNamedDimsArray, b::AbstractNamedDimsArray)`, etc. +- `swapnameddimsindices` (written in terms of `mapnameddimsindices`/`replacenameddimsindices`). +- `mapnameddimsindices(f, a::AbstractNamedDimsArray)` (rename `replacenameddimsindices(f, a)` to `mapnameddimsindices(f, a)`, or have both?) - `cat` (define `CatName` as a combination of the input names?). - `canonize`/`flatten_array_wrappers` (https://github.com/mcabbott/NamedPlus.jl/blob/v0.0.5/src/permute.jl#L207) - - `nameddims(PermutedDimsArray(a, perm), dimnames)` -> `nameddims(a, dimnames[invperm(perm)])` - - `nameddims(transpose(a), dimnames)` -> `nameddims(a, reverse(dimnames))` - - `Transpose(nameddims(a, dimnames))` -> `nameddims(a, reverse(dimnames))` + - `nameddims(PermutedDimsArray(a, perm), nameddimsindices)` -> `nameddims(a, nameddimsindices[invperm(perm)])` + - `nameddims(transpose(a), nameddimsindices)` -> `nameddims(a, reverse(nameddimsindices))` + - `Transpose(nameddims(a, nameddimsindices))` -> `nameddims(a, reverse(nameddimsindices))` - etc. - `MappedName(old_name, name)`, acts like `Name(name)` but keeps track of the old name. - - `namedmap(a, ::Pair...)`: `namedmap(named(randn(2, 2, 2, 2), i, j, k, l), i => k, j => l)` + - `nameddimsmap(a, ::Pair...)`: `namedmap(named(randn(2, 2, 2, 2), i, j, k, l), i => k, j => l)` represents that the names map back and forth to each other for the sake of `transpose`, `tr`, `eigen`, etc. Operators are generally `namedmap(named(randn(2, 2), i, i'), i => i')`. - `prime(:i) = PrimedName(:i)`, `prime(:i, 2) = PrimedName(:i, 2)`, `prime(prime(:i)) = PrimedName(:i, 2)`, `Name(:i)' = prime(:i)`, etc. -- `transpose`/`adjoint` based on `swapdimnames` and `MappedName(old_name, new_name)`. + - Also `prime(f, a::AbstractNamedDimsArray)` where `f` is a filter function to determine + which dimensions to filter. +- `transpose`/`adjoint` based on `swapnameddimsindices` and `MappedName(old_name, new_name)`. - `adjoint` could make use of a lazy `ConjArray`. - `transpose(a, dimname1 => dimname1′, dimname2 => dimname2′)` like `https://github.com/mcabbott/NamedPlus.jl`. - Same as `replacedims(a, dimname1 => dimname1′, dimname1′ => dimname1, dimname2 => dimname2′, dimname2′ => dimname2)`. @@ -23,4 +36,5 @@ - Slicing: `nameddims(a, "i", "j")[1:2, 1:2] = nameddims(a[1:2, 1:2], Name(named(1:2, "i")), Name(named(1:2, "j")))`, i.e. the parent gets sliced and the new dimensions names are the named slice. - Should `NamedDimsArray` store the named axes rather than just the dimension names? - - Should `NamedDimsArray` have special axes types so that `axes(nameddims(a, "i", "j")) == axes(nameddims(a', "j", "i"))`? + - Should `NamedDimsArray` have special axes types so that `axes(nameddims(a, "i", "j")) == axes(nameddims(a', "j", "i"))`, + i.e. equality is based on `issetequal` and not dependent on the ordering of the dimensions? diff --git a/examples/README.jl b/examples/README.jl index 7d48576..3c0e38c 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -37,30 +37,64 @@ julia> Pkg.add("NamedDimsArrays") # ## Examples -using NamedDimsArrays: aligndims, dename, dimnames, named +using NamedDimsArrays: aligndims, dimnames, named, nameddimsindices, namedoneto, unname using TensorAlgebra: contract +using Test: @test ## Named dimensions -i = named(2, "i") -j = named(2, "j") -k = named(2, "k") +i = namedoneto(2, "i") +j = namedoneto(2, "j") +k = namedoneto(2, "k") ## Arrays with named dimensions -na1 = randn(i, j) -na2 = randn(j, k) +a1 = randn(i, j) +a2 = randn(j, k) -@show dimnames(na1) == ("i", "j") +@test dimnames(a1) == ("i", "j") +@test nameddimsindices(a1) == (i, j) +@test axes(a1) == (named(1:2, i), named(1:2, j)) +@test size(a1) == (named(2, i), named(2, j)) ## Indexing -@show na1[j => 2, i => 1] == na1[1, 2] +@test a1[j => 2, i => 1] == a1[1, 2] +@test a1[j[2], i[1]] == a1[1, 2] ## Tensor contraction -na_dest = contract(na1, na2) +a_dest = contract(a1, a2) -@show issetequal(dimnames(na_dest), ("i", "k")) -## `dename` removes the names and returns an `Array` -@show dename(na_dest, (i, k)) ≈ dename(na1) * dename(na2) +@test issetequal(nameddimsindices(a_dest), (i, k)) +## `unname` removes the names and returns an `Array` +@test unname(a_dest, (i, k)) ≈ unname(a1, (i, j)) * unname(a2, (j, k)) ## Permute dimensions (like `ITensors.permute`) -na1 = aligndims(na1, (j, i)) -@show na1[i => 1, j => 2] == na1[2, 1] +a1′ = aligndims(a1, (j, i)) +@test a1′[i => 1, j => 2] == a1[i => 1, j => 2] +@test a1′[i[1], j[2]] == a1[i[1], j[2]] + +## Contiguous slicing +b1 = a1[i => 1:2, j => 1:1] +@test b1 == a1[i[1:2], j[1:1]] + +b2 = a2[j => 1:1, k => 1:2] +@test b2 == a2[j[1:1], k[1:2]] + +@test nameddimsindices(b1) == (i[1:2], j[1:1]) +@test nameddimsindices(b2) == (j[1:1], k[1:2]) + +b_dest = contract(b1, b2) + +@test issetequal(nameddimsindices(b_dest), (i, k)) + +## Non-contiguous slicing +c1 = a1[i[[2, 1]], j[[2, 1]]] +@test nameddimsindices(c1) == (i[[2, 1]], j[[2, 1]]) +@test unname(c1, (i[[2, 1]], j[[2, 1]])) == unname(a1, (i, j))[[2, 1], [2, 1]] +@test c1[i[2], j[1]] == a1[i[2], j[1]] +@test c1[2, 1] == a1[1, 2] + +a1[i[[2, 1]], j[[2, 1]]] = [22 21; 12 11] +@test a1[i[1], j[1]] == 11 + +x = randn(i[1:2], j[2:2]) +a1[i[1:2], j[2:2]] = x +@test a1[i[1], j[2]] == x[i[1], j[2]] diff --git a/ext/NamedDimsArraysBlockArraysExt/NamedDimsArraysBlockArraysExt.jl b/ext/NamedDimsArraysBlockArraysExt/NamedDimsArraysBlockArraysExt.jl new file mode 100644 index 0000000..05d7d56 --- /dev/null +++ b/ext/NamedDimsArraysBlockArraysExt/NamedDimsArraysBlockArraysExt.jl @@ -0,0 +1,45 @@ +module NamedDimsArraysBlockArraysExt +using ArrayLayouts: ArrayLayouts +using BlockArrays: Block, BlockRange +using NamedDimsArrays: + AbstractNamedDimsArray, + AbstractNamedUnitRange, + named_getindex, + nameddims_getindex, + nameddims_view + +function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::Block{1}) + # TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead. + return named_getindex(r, I) +end + +function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::BlockRange{1}) + # TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead. + return named_getindex(r, I) +end + +const BlockIndex{N} = Union{Block{N},BlockRange{N},AbstractVector{<:Block{N}}} + +function Base.view(a::AbstractNamedDimsArray, I1::Block{1}, Irest::BlockIndex{1}...) + # TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead. + return nameddims_view(a, I1, Irest...) +end + +function Base.view(a::AbstractNamedDimsArray, I::Block) + # TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead. + return nameddims_view(a, Tuple(I)...) +end + +function Base.view(a::AbstractNamedDimsArray, I1::BlockIndex{1}, Irest::BlockIndex{1}...) + # TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead. + return nameddims_view(a, I1, Irest...) +end + +# Fix ambiguity error. +function Base.getindex( + a::AbstractNamedDimsArray, I1::BlockRange{1}, Irest::BlockRange{1}... +) + return ArrayLayouts.layout_getindex(a, I1, Irest...) +end + +end diff --git a/src/NamedDimsArrays.jl b/src/NamedDimsArrays.jl index d3fc0b9..c763551 100644 --- a/src/NamedDimsArrays.jl +++ b/src/NamedDimsArrays.jl @@ -4,6 +4,8 @@ include("isnamed.jl") include("randname.jl") include("abstractnamedinteger.jl") include("namedinteger.jl") +include("abstractnamedarray.jl") +include("namedarray.jl") include("abstractnamedunitrange.jl") include("namedunitrange.jl") include("abstractnameddimsarray.jl") diff --git a/src/abstractnamedarray.jl b/src/abstractnamedarray.jl new file mode 100644 index 0000000..7e14d24 --- /dev/null +++ b/src/abstractnamedarray.jl @@ -0,0 +1,79 @@ +using TypeParameterAccessors: unspecify_type_parameters + +abstract type AbstractNamedArray{T,N,Value<:AbstractArray,Name} <: AbstractArray{T,N} end + +const AbstractNamedVector{T,Value<:AbstractVector,Name} = AbstractNamedArray{T,1,Value,Name} +const AbstractNamedMatrix{T,Value<:AbstractVector,Name} = AbstractNamedArray{T,2,Value,Name} + +# Minimal interface. +dename(a::AbstractNamedArray) = throw(MethodError(dename, Tuple{typeof(a)})) +name(a::AbstractNamedArray) = throw(MethodError(name, Tuple{typeof(a)})) + +# This can be customized to output different named integer types, +# such as `namedarray(a::AbstractArray, name::IndexName) = Index(a, name)`. +namedarray(a::AbstractArray, name) = NamedArray(a, name) + +# Shorthand. +named(a::AbstractArray, name) = namedarray(a, name) + +# Derived interface. +# TODO: Use `Accessors.@set`? +setname(a::AbstractNamedArray, name) = namedarray(dename(a), name) + +# TODO: Use `TypeParameterAccessors`. +denametype(::Type{<:AbstractNamedArray{<:Any,<:Any,Value}}) where {Value} = Value +nametype(::Type{<:AbstractNamedArray{<:Any,<:Any,<:Any,Name}}) where {Name} = Name + +# Traits. +isnamed(::Type{<:AbstractNamedArray}) = true + +# TODO: Should they also have the same base type? +function Base.:(==)(a1::AbstractNamedArray, a2::AbstractNamedArray) + return name(a1) == name(a2) && dename(a1) == dename(a2) +end +function Base.hash(a::AbstractNamedArray, h::UInt) + h = hash(Symbol(unspecify_type_parameters(typeof(a))), h) + # TODO: Double check how this is handling blocking/sector information. + h = hash(dename(a), h) + return hash(name(a), h) +end + +named_getindex(a::AbstractArray, I...) = named(getindex(dename(a), I...), name(a)) + +# Array funcionality. +Base.size(a::AbstractNamedArray) = map(s -> named(s, name(a)), size(dename(a))) +Base.axes(a::AbstractNamedArray) = map(s -> named(s, name(a)), axes(dename(a))) +Base.eachindex(a::AbstractNamedArray) = eachindex(dename(a)) +function Base.getindex(a::AbstractNamedArray{<:Any,N}, I::Vararg{Int,N}) where {N} + return named_getindex(a, I...) +end +function Base.getindex(a::AbstractNamedArray, I::Int) + return named_getindex(a, I) +end +Base.isempty(a::AbstractNamedArray) = isempty(dename(a)) + +## function Base.AbstractArray{Int}(a::AbstractNamedArray) +## return AbstractArray{Int}(dename(a)) +## end +## +## Base.iterate(a::AbstractNamedArray) = isempty(a) ? nothing : (first(a), first(a)) +## function Base.iterate(a::AbstractNamedArray, i) +## i == last(a) && return nothing +## next = named(dename(i) + dename(step(a)), name(a)) +## return (next, next) +## end + +function randname(ang::AbstractRNG, a::AbstractNamedArray) + return named(dename(a), randname(name(a))) +end + +function Base.show(io::IO, a::AbstractNamedArray) + print(io, "named(", dename(a), ", ", repr(name(a)), ")") + return nothing +end +function Base.show(io::IO, mime::MIME"text/plain", a::AbstractNamedArray) + print(io, "named(\n") + show(io, mime, dename(a)) + print(io, ",\n ", repr(name(a)), ")") + return nothing +end diff --git a/src/abstractnameddimsarray.jl b/src/abstractnameddimsarray.jl index f2320a6..fc1aa00 100644 --- a/src/abstractnameddimsarray.jl +++ b/src/abstractnameddimsarray.jl @@ -4,6 +4,7 @@ using Derive: Derive, @derive, AbstractArrayInterface # https://github.com/ITensor/ITensors.jl # https://github.com/invenia/NamedDims.jl # https://github.com/mcabbott/NamedPlus.jl +# https://pytorch.org/docs/stable/named_tensor.html abstract type AbstractNamedDimsArrayInterface <: AbstractArrayInterface end @@ -14,76 +15,200 @@ abstract type AbstractNamedDimsArray{T,N} <: AbstractArray{T,N} end const AbstractNamedDimsVector{T} = AbstractNamedDimsArray{T,1} const AbstractNamedDimsMatrix{T} = AbstractNamedDimsArray{T,2} -Derive.interface(::Type{<:AbstractNamedDimsArray}) = AbstractNamedDimsArrayInterface() +Derive.interface(::Type{<:AbstractNamedDimsArray}) = NamedDimsArrayInterface() # Output the dimension names. -dimnames(a::AbstractArray) = throw(MethodError(dimnames, Tuple{typeof(a)})) +nameddimsindices(a::AbstractArray) = throw(MethodError(nameddimsindices, Tuple{typeof(a)})) # Unwrapping the names Base.parent(a::AbstractNamedDimsArray) = throw(MethodError(parent, Tuple{typeof(a)})) -dimnames(a::AbstractArray, dim::Int) = dimnames(a)[dim] +nameddimsindices(a::AbstractArray, dim::Int) = nameddimsindices(a)[dim] -dim(a::AbstractArray, n) = findfirst(==(name(n)), dimnames(a)) +function dimnames(a::AbstractNamedDimsArray) + return name.(nameddimsindices(a)) +end +function dimnames(a::AbstractNamedDimsArray, dim::Int) + return dimnames(a)[dim] +end + +function dim(a::AbstractArray, n) + dimname = to_dimname(a, n) + return findfirst(==(dimname), nameddimsindices(a)) +end dims(a::AbstractArray, ns) = map(n -> dim(a, n), ns) +dimname_isequal(x) = Base.Fix1(dimname_isequal, x) +dimname_isequal(x, y) = isequal(x, y) + +dimname_isequal(r1::AbstractNamedArray, r2::AbstractNamedArray) = isequal(r1, r2) +dimname_isequal(r1::AbstractNamedArray, r2) = name(r1) == r2 +dimname_isequal(r1, r2::AbstractNamedArray) = r1 == name(r2) + +dimname_isequal(r1::AbstractNamedArray, r2::Name) = name(r1) == name(r2) +dimname_isequal(r1::Name, r2::AbstractNamedArray) = name(r1) == name(r2) + +dimname_isequal(r1::AbstractNamedUnitRange, r2::AbstractNamedUnitRange) = isequal(r1, r2) +dimname_isequal(r1::AbstractNamedUnitRange, r2) = name(r1) == r2 +dimname_isequal(r1, r2::AbstractNamedUnitRange) = r1 == name(r2) + +dimname_isequal(r1::AbstractNamedUnitRange, r2::Name) = name(r1) == name(r2) +dimname_isequal(r1::Name, r2::AbstractNamedUnitRange) = name(r1) == name(r2) + +function to_nameddimsindices(a::AbstractArray, dims) + return to_nameddimsindices(a, axes(a), dims) +end +function to_nameddimsindices(a::AbstractArray, axes, dims) + return map((axis, dim) -> to_dimname(a, axis, dim), axes, dims) +end +function to_dimname(a::AbstractArray, axis, dim::AbstractNamedArray) + # TODO: Check `axis` and `dim` have the same shape? + return dim +end +function to_dimname(a::AbstractArray, axis, dim::AbstractNamedUnitRange) + # TODO: Check `axis` and `dim` have the same shape? + return dim +end +# This is the case where just the name of the axis +# was specified without a range, like: +# ```julia +# a = randn(named(2, "i"), named(2, "j")) +# aligndims(a, ("i", "j")) +# ``` +function to_dimname(a::AbstractArray, axis, dim) + return named(axis, dim) +end +function to_dimname(a::AbstractArray, axis, dim::Name) + return to_dimname(a, axis, name(dim)) +end + +function to_dimname(a::AbstractNamedDimsArray, dimname) + dim = findfirst(dimname_isequal(dimname), nameddimsindices(a)) + return to_dimname(a, axes(a, dim), dimname) +end + +function to_dimname(a::AbstractNamedDimsArray, axis, dim::AbstractNamedArray) + return dim +end +function to_dimname(a::AbstractNamedDimsArray, axis, dim::AbstractNamedUnitRange) + return dim +end +function to_dimname(a::AbstractNamedDimsArray, axis, dim) + return named(dename(axis), dim) +end +function to_dimname(a::AbstractNamedDimsArray, axis, dim::Name) + return to_dimname(a, axis, name(dim)) +end + +function to_nameddimsindices(a::AbstractNamedDimsArray, dims) + return map(dim -> to_dimname(a, dim), dims) +end + # Unwrapping the names (`NamedDimsArrays.jl` interface). # TODO: Use `IsNamed` trait? dename(a::AbstractNamedDimsArray) = parent(a) -function dename(a::AbstractNamedDimsArray, dimnames) - return dename(aligndims(a, dimnames)) +function dename(a::AbstractNamedDimsArray, nameddimsindices) + return dename(aligndims(a, nameddimsindices)) end -function denamed(a::AbstractNamedDimsArray, dimnames) - return dename(aligneddims(a, dimnames)) +function denamed(a::AbstractNamedDimsArray, nameddimsindices) + return dename(aligneddims(a, nameddimsindices)) end -unname(a::AbstractArray, dimnames) = dename(a, dimnames) -unnamed(a::AbstractArray, dimnames) = denamed(a, dimnames) +unname(a::AbstractArray, nameddimsindices) = dename(a, nameddimsindices) +unnamed(a::AbstractArray, nameddimsindices) = denamed(a, nameddimsindices) isnamed(::Type{<:AbstractNamedDimsArray}) = true -# Can overload this to get custom named dims array wrapper -# depending on the dimension name types, for example -# output an `ITensor` if the dimension names are `IndexName`s. -@traitfn function nameddims(a::AbstractArray::!(IsNamed), dims) - dimnames = name.(dims) - # TODO: Check the shape of `dename.(dims)` matches the shape of `a`. - # `mapreduce(typeof, promote_type, xs) == Base.promote_typeof(xs...)`. - return nameddimstype(eltype(dimnames))(a, dimnames) +# TODO: Move to `utils.jl` file. +# TODO: Use `Base.indexin`? +function getperm(x, y; isequal=isequal) + return map(yᵢ -> findfirst(isequal(yᵢ), x), y) end -@traitfn function nameddims(a::AbstractArray::IsNamed, dims) - return aligneddims(a, dims) + +# TODO: Move to `utils.jl` file. +function checked_indexin(x, y) + I = indexin(x, y) + return something.(I) end -function Base.view(a::AbstractArray, dimnames::AbstractName...) - return nameddims(a, dimnames) +function checked_indexin(x::Number, y) + return findfirst(==(x), y) end -function Base.getindex(a::AbstractArray, dimnames::AbstractName...) - return copy(@view(a[dimnames...])) + +function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange) + return findfirst(==(first(x)), y):findfirst(==(last(x)), y) end -Base.copy(a::AbstractNamedDimsArray) = nameddims(copy(dename(a)), dimnames(a)) +Base.copy(a::AbstractNamedDimsArray) = nameddims(copy(dename(a)), nameddimsindices(a)) + +# Generic constructor. +function nameddims(a::AbstractArray, nameddimsindices) + # TODO: Check the shape of `nameddimsindices` matches the shape of `a`. + return nameddimstype(eltype(nameddimsindices))( + a, to_nameddimsindices(a, nameddimsindices) + ) +end # Can overload this to get custom named dims array wrapper # depending on the dimension name types, for example # output an `ITensor` if the dimension names are `IndexName`s. nameddimstype(dimnametype::Type) = NamedDimsArray -Base.axes(a::AbstractNamedDimsArray) = map(named, axes(dename(a)), dimnames(a)) -Base.size(a::AbstractNamedDimsArray) = map(named, size(dename(a)), dimnames(a)) +Base.axes(a::AbstractNamedDimsArray) = map(named, axes(dename(a)), nameddimsindices(a)) +Base.size(a::AbstractNamedDimsArray) = map(named, size(dename(a)), nameddimsindices(a)) -Base.axes(a::AbstractArray, dimname::AbstractName) = axes(a, dim(a, dimname)) -Base.size(a::AbstractArray, dimname::AbstractName) = size(a, dim(a, dimname)) +# Circumvent issue when ndims isn't known at compile time. +function Base.axes(a::AbstractNamedDimsArray, d) + return d <= ndims(a) ? axes(a)[d] : OneTo(1) +end -setdimnames(a::AbstractNamedDimsArray, dimnames) = nameddims(dename(a), name.(dimnames)) -function replacedimnames(f, a::AbstractNamedDimsArray) - return setdimnames(a, replace(f, dimnames(a))) +# Circumvent issue when ndims isn't known at compile time. +function Base.size(a::AbstractNamedDimsArray, d) + return d <= ndims(a) ? size(a)[d] : OneTo(1) end -function replacedimnames(a::AbstractNamedDimsArray, replacements::Pair...) - replacement_names = map(replacements) do replacement - name(first(replacement)) => name(last(replacement)) - end - new_dimnames = replace(dimnames(a), replacement_names...) - return setdimnames(a, new_dimnames) + +# Circumvent issue when ndims isn't known at compile time. +Base.ndims(a::AbstractNamedDimsArray) = ndims(dename(a)) + +# Circumvent issue when eltype isn't known at compile time. +Base.eltype(a::AbstractNamedDimsArray) = eltype(dename(a)) + +Base.axes(a::AbstractNamedDimsArray, dimname::Name) = axes(a, dim(a, dimname)) +Base.size(a::AbstractNamedDimsArray, dimname::Name) = size(a, dim(a, dimname)) + +const NamedDimsIndices = Union{ + AbstractNamedUnitRange{<:Integer},AbstractNamedArray{<:Integer} +} +const NamedDimsAxis = AbstractNamedUnitRange{ + <:Integer,<:AbstractUnitRange,<:NamedDimsIndices +} + +to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims) +to_nameddimsaxis(ax::NamedDimsAxis) = ax +to_nameddimsaxis(I::NamedDimsIndices) = named(dename(only(axes(I))), I) + +function Base.similar( + a::AbstractArray, elt::Type, inds::Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}} +) + ax = to_nameddimsaxes(inds) + return nameddims(similar(unname(a), elt, dename.(ax)), name.(ax)) +end + +function setnameddimsindices(a::AbstractNamedDimsArray, nameddimsindices) + return nameddims(dename(a), nameddimsindices) +end +function replacenameddimsindices(f, a::AbstractNamedDimsArray) + return setnameddimsindices(a, replace(f, nameddimsindices(a))) +end +function replacenameddimsindices( + a::AbstractNamedDimsArray, + replacements::Pair{<:AbstractNamedUnitRange,<:AbstractNamedUnitRange}..., +) + return setnameddimsindices(a, replace(nameddimsindices(a), replacements...)) +end +function replacenameddimsindices(a::AbstractNamedDimsArray, replacements::Pair...) + old_nameddimsindices = to_nameddimsindices(a, first.(replacements)) + new_nameddimsindices = named.(dename.(old_nameddimsindices), last.(replacements)) + return replacenameddimsindices(a, (old_nameddimsindices .=> new_nameddimsindices)...) end # `Base.isempty(a::AbstractArray)` is defined as `length(a) == 0`, @@ -110,50 +235,52 @@ Base.IndexStyle(s1::IndexStyle, s2::NamedIndexCartesian) = NamedIndexCartesian() Base.IndexStyle(s1::NamedIndexCartesian, s2::IndexStyle) = NamedIndexCartesian() # Like CartesianIndex but with named dimensions. -struct NamedCartesianIndex{N,Index<:Tuple{Vararg{AbstractNamedInteger,N}}} <: +struct NamedDimsCartesianIndex{N,Index<:Tuple{Vararg{AbstractNamedInteger,N}}} <: Base.AbstractCartesianIndex{N} I::Index end -NamedCartesianIndex(I::AbstractNamedInteger...) = NamedCartesianIndex(I) -Base.Tuple(I::NamedCartesianIndex) = I.I -function Base.show(io::IO, I::NamedCartesianIndex) - print(io, "NamedCartesianIndex") +NamedDimsCartesianIndex(I::AbstractNamedInteger...) = NamedDimsCartesianIndex(I) +Base.Tuple(I::NamedDimsCartesianIndex) = I.I +function Base.show(io::IO, I::NamedDimsCartesianIndex) + print(io, "NamedDimsCartesianIndex") show(io, Tuple(I)) return nothing end # Like CartesianIndices but with named dimensions. -struct NamedCartesianIndices{ +struct NamedDimsCartesianIndices{ N, Indices<:Tuple{Vararg{AbstractNamedUnitRange,N}}, Index<:Tuple{Vararg{AbstractNamedInteger,N}}, -} <: AbstractNamedDimsArray{NamedCartesianIndex{N,Index},N} +} <: AbstractNamedDimsArray{NamedDimsCartesianIndex{N,Index},N} indices::Indices - function NamedCartesianIndices(indices::Tuple{Vararg{AbstractNamedUnitRange}}) + function NamedDimsCartesianIndices(indices::Tuple{Vararg{AbstractNamedUnitRange}}) return new{length(indices),typeof(indices),Tuple{eltype.(indices)...}}(indices) end end -Base.axes(I::NamedCartesianIndices) = map(only ∘ axes, I.indices) -Base.size(I::NamedCartesianIndices) = length.(I.indices) +Base.eltype(I::NamedDimsCartesianIndices) = eltype(typeof(I)) +Base.axes(I::NamedDimsCartesianIndices) = map(only ∘ axes, I.indices) +Base.size(I::NamedDimsCartesianIndices) = length.(I.indices) -function Base.getindex(a::NamedCartesianIndices{N}, I::Vararg{Int,N}) where {N} - index = map(a.indices, I) do r, i - return getindex(r, i) +function Base.getindex(a::NamedDimsCartesianIndices{N}, I::Vararg{Int,N}) where {N} + # TODO: Check if `nameddimsindices(a)` is correct here. + index = map(nameddimsindices(a), I) do r, i + return r[i] end - return NamedCartesianIndex(index) + return NamedDimsCartesianIndex(index) end -dimnames(I::NamedCartesianIndices) = name.(I.indices) -function Base.parent(I::NamedCartesianIndices) +nameddimsindices(I::NamedDimsCartesianIndices) = name.(I.indices) +function Base.parent(I::NamedDimsCartesianIndices) return CartesianIndices(dename.(I.indices)) end function Base.eachindex(::NamedIndexCartesian, a1::AbstractArray, a_rest::AbstractArray...) - all(a -> issetequal(dimnames(a1), dimnames(a)), a_rest) || - throw(NameMismatch("Dimension name mismatch $(dimnames.((a1, a_rest...))).")) + all(a -> issetequal(nameddimsindices(a1), nameddimsindices(a)), a_rest) || + throw(NameMismatch("Dimension name mismatch $(nameddimsindices.((a1, a_rest...))).")) # TODO: Check the shapes match. - return NamedCartesianIndices(axes(a1)) + return NamedDimsCartesianIndices(axes(a1)) end # Base version ignores dimension names. @@ -173,167 +300,346 @@ function Base.:(==)(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray) end end -# TODO: Move to `utils.jl` file. -# TODO: Use `Base.indexin`? -function getperm(x, y) - return map(yᵢ -> findfirst(isequal(yᵢ), x), y) -end - # Indexing. -function Base.getindex(a::AbstractNamedDimsArray{<:Any,N}, I::Vararg{Int,N}) where {N} - return getindex(dename(a), I...) -end -function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{AbstractNamedInteger,N} -) where {N} - return getindex(a, to_indices(a, I)...) + +# Scalar indexing + +function Base.getindex(a::AbstractNamedDimsArray, I1::Int, Irest::Int...) + return getindex(dename(a), I1, Irest...) end -function Base.getindex( - a::AbstractNamedDimsArray{<:Any,N}, I::NamedCartesianIndex{N} -) where {N} +function Base.getindex(a::AbstractNamedDimsArray, I::CartesianIndex) return getindex(a, to_indices(a, (I,))...) end function Base.getindex( - a::AbstractNamedDimsArray{<:Any,N}, I::Vararg{Pair{<:Any,Int},N} -) where {N} - return getindex(a, to_indices(a, I)...) + a::AbstractNamedDimsArray, I1::AbstractNamedInteger, Irest::AbstractNamedInteger... +) + I = (I1, Irest...) + # TODO: Check if this permuation should be inverted. + perm = getperm(name.(nameddimsindices(a)), name.(I)) + # TODO: Throw a `NameMismatch` error. + @assert isperm(perm) + I = map(p -> I[p], perm) + subinds = map(nameddimsindices(a), I) do dimname, i + return checked_indexin(dename(i), dename(dimname)) + end + return getindex(dename(a), subinds...) end +function Base.getindex(a::AbstractNamedDimsArray, I::NamedDimsCartesianIndex) + return getindex(a, Tuple(I)...) +end +function Base.getindex(a::AbstractNamedDimsArray, I1::Pair, Irest::Pair...) + I = (I1, Irest...) + nameddimsindices = to_nameddimsindices(a, first.(I)) + return getindex(a, map((i, name) -> name[i], last.(I), nameddimsindices)...) +end +function Base.getindex(a::AbstractNamedDimsArray) + return getindex(dename(a)) +end +# Linear indexing. function Base.getindex(a::AbstractNamedDimsArray, I::Int) return getindex(dename(a), I) end -function Base.setindex!( - a::AbstractNamedDimsArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - setindex!(dename(a), value, I...) + +function Base.setindex!(a::AbstractNamedDimsArray, value, I1::Int, Irest::Int...) + setindex!(dename(a), value, I1, Irest...) return a end -function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{AbstractNamedInteger,N} -) where {N} - setindex!(a, value, to_indices(a, I)...) +function Base.setindex!(a::AbstractNamedDimsArray, value, I::CartesianIndex) + setindex!(a, value, to_indices(a, (I,))...) return a end function Base.setindex!( - a::AbstractNamedDimsArray{<:Any,N}, value, I::NamedCartesianIndex{N} -) where {N} - setindex!(a, value, to_indices(a, (I,))...) + a::AbstractNamedDimsArray, value, I1::AbstractNamedInteger, Irest::AbstractNamedInteger... +) + I = (I1, Irest...) + # TODO: Check if this permuation should be inverted. + perm = getperm(name.(nameddimsindices(a)), name.(I)) + # TODO: Throw a `NameMismatch` error. + @assert isperm(perm) + I = map(p -> I[p], perm) + subinds = map(nameddimsindices(a), I) do dimname, i + return checked_indexin(dename(i), dename(dimname)) + end + return setindex!(dename(a), value, subinds...) +end +function Base.setindex!(a::AbstractNamedDimsArray, value, I::NamedDimsCartesianIndex) + setindex!(a, value, Tuple(I)...) return a end -function Base.setindex!( - a::AbstractNamedDimsArray{<:Any,N}, value, I::Vararg{Pair{<:Any,Int},N} -) where {N} - setindex!(a, value, to_indices(a, I)...) +function Base.setindex!(a::AbstractNamedDimsArray, value, I1::Pair, Irest::Pair...) + I = (I1, Irest...) + nameddimsindices = to_nameddimsindices(a, first.(I)) + setindex!(a, value, map((i, name) -> name[i], last.(I), nameddimsindices)...) + return a +end +function Base.setindex!(a::AbstractNamedDimsArray, value) + setindex!(dename(a), value) return a end +# Linear indexing. function Base.setindex!(a::AbstractNamedDimsArray, value, I::Int) setindex!(dename(a), value, I) return a end -# Handles permutation of indices to align dimension names. -function Base.to_indices( - a::AbstractArray{<:Any,N}, I::Tuple{Vararg{AbstractNamedInteger,N}} -) where {N} - # TODO: Check this permutation is correct (it may be the inverse of what we want). - # We unwrap the names twice in case named axes were passed as indices. - return dename.(map(i -> I[i], getperm(dimnames(a), name.(name.(I))))) + +function Base.isassigned(a::AbstractNamedDimsArray, I::Int...) + return isassigned(dename(a), I...) end -function Base.to_indices( - a::AbstractArray{<:Any,N}, I::Tuple{NamedCartesianIndex{N}} -) where {N} - return to_indices(a, Tuple(only(I))) + +# Slicing + +# Like `const ViewIndex = Union{Real,AbstractArray}`. +const NamedViewIndex = Union{AbstractNamedInteger,AbstractNamedUnitRange,AbstractNamedArray} + +using ArrayLayouts: ArrayLayouts, MemoryLayout + +abstract type AbstractNamedDimsArrayLayout <: MemoryLayout end +struct NamedDimsArrayLayout{ParentLayout} <: AbstractNamedDimsArrayLayout end + +function ArrayLayouts.MemoryLayout(arrtype::Type{<:AbstractNamedDimsArray}) + return NamedDimsArrayLayout{typeof(MemoryLayout(parenttype(arrtype)))}() +end + +function ArrayLayouts.sub_materialize(::NamedDimsArrayLayout, a, ax) + return copy(a) +end + +function Base.view(a::AbstractArray, I1::NamedViewIndex, Irest::NamedViewIndex...) + I = (I1, Irest...) + sub_dims = filter(dim -> I[dim] isa AbstractArray, ntuple(identity, ndims(a))) + sub_nameddimsindices = map(dim -> I[dim], sub_dims) + return nameddims(view(a, dename.(I)...), sub_nameddimsindices) end -# Support indexing with pairs `a[:i => 1, :j => 2]`. -function Base.to_index(a::AbstractNamedDimsArray, i::Pair{<:Any,Int}) - return named(last(i), first(i)) + +function Base.getindex(a::AbstractArray, I1::NamedViewIndex, Irest::NamedViewIndex...) + return copy(view(a, I1, Irest...)) end -function Base.isassigned(a::AbstractNamedDimsArray{<:Any,N}, I::Vararg{Int,N}) where {N} - return isassigned(parent(a), I...) + +function Base.view(a::AbstractArray, I1::Name, Irest::Name...) + return nameddims(a, name.((I1, Irest...))) end +function Base.view(a::AbstractNamedDimsArray, I1::Name, Irest::Name...) + return view(a, to_nameddimsindices(a, (I1, Irest...))...) +end + +function Base.getindex(a::AbstractArray, I1::Name, Irest::Name...) + return copy(view(a, I1, Irest...)) +end + +function Base.view(a::AbstractNamedDimsArray, I1::NamedViewIndex, Irest::NamedViewIndex...) + I = (I1, Irest...) + # TODO: Check if this permuation should be inverted. + perm = getperm(name.(nameddimsindices(a)), name.(I)) + # TODO: Throw a `NameMismatch` error. + @assert isperm(perm) + I = map(p -> I[p], perm) + sub_dims = filter(dim -> I[dim] isa AbstractArray, ntuple(identity, ndims(a))) + sub_nameddimsindices = map(dim -> I[dim], sub_dims) + subinds = map(nameddimsindices(a), I) do dimname, i + return checked_indexin(dename(i), dename(dimname)) + end + return nameddims(view(dename(a), subinds...), sub_nameddimsindices) +end + +function Base.getindex( + a::AbstractNamedDimsArray, I1::NamedViewIndex, Irest::NamedViewIndex... +) + return copy(view(a, I1, Irest...)) +end + +# Repeated definition of `Base.ViewIndex`. +const ViewIndex = Union{Real,AbstractArray} + +function nameddims_view(a::AbstractArray, I...) + sub_dims = filter(dim -> !(I[dim] isa Real), ntuple(identity, ndims(a))) + sub_nameddimsindices = map(dim -> nameddimsindices(a, dim)[I[dim]], sub_dims) + return nameddims(view(dename(a), I...), sub_nameddimsindices) +end + +function Base.view(a::AbstractNamedDimsArray, I::ViewIndex...) + return nameddims_view(a, I...) +end + +function nameddims_getindex(a::AbstractArray, I...) + return copy(view(a, I...)) +end + +function Base.getindex(a::AbstractNamedDimsArray, I::ViewIndex...) + return nameddims_getindex(a, I...) +end + +function Base.setindex!( + a::AbstractNamedDimsArray, + value::AbstractNamedDimsArray, + I1::NamedViewIndex, + Irest::NamedViewIndex..., +) + view(a, I1, Irest...) .= value + return a +end +function Base.setindex!( + a::AbstractNamedDimsArray, + value::AbstractArray, + I1::NamedViewIndex, + Irest::NamedViewIndex..., +) + I = (I1, Irest...) + setindex!(a, nameddims(value, I), I...) + return a +end +function Base.setindex!( + a::AbstractNamedDimsArray, + value::AbstractNamedDimsArray, + I1::ViewIndex, + Irest::ViewIndex..., +) + view(a, I1, Irest...) .= value + return a +end +function Base.setindex!( + a::AbstractNamedDimsArray, value::AbstractArray, I1::ViewIndex, Irest::ViewIndex... +) + setindex!(dename(a), value, I1, Irest...) + return a +end + +# Permute/align dimensions + function aligndims(a::AbstractArray, dims) + new_nameddimsindices = to_nameddimsindices(a, dims) # TODO: Check this permutation is correct (it may be the inverse of what we want). - perm = getperm(dimnames(a), name.(dims)) - return nameddims(permutedims(dename(a), perm), name.(dims)) + perm = getperm(nameddimsindices(a), new_nameddimsindices) + isperm(perm) || throw( + NameMismatch( + "Dimension name mismatch $(nameddimsindices(a)), $(new_nameddimsindices)." + ), + ) + return nameddims(permutedims(dename(a), perm), new_nameddimsindices) end function aligneddims(a::AbstractArray, dims) + new_nameddimsindices = to_nameddimsindices(a, dims) # TODO: Check this permutation is correct (it may be the inverse of what we want). - new_dimnames = name.(dims) - perm = getperm(dimnames(a), new_dimnames) - !isperm(perm) && - throw(NameMismatch("Dimension name mismatch $(dimnames(a)), $(new_dimnames).")) - return nameddims(PermutedDimsArray(dename(a), perm), new_dimnames) + perm = getperm(nameddimsindices(a), new_nameddimsindices) + isperm(perm) || throw( + NameMismatch( + "Dimension name mismatch $(nameddimsindices(a)), $(new_nameddimsindices)." + ), + ) + return nameddims(PermutedDimsArray(dename(a), perm), new_nameddimsindices) end +# Convenient constructors + using Random: Random, AbstractRNG -# Convenient constructors +# TODO: Come up with a better name for this. +_rand(args...) = Base.rand(args...) +function _rand( + rng::AbstractRNG, elt::Type, dims::Tuple{Base.OneTo{Int},Vararg{Base.OneTo{Int}}} +) + return Base.rand(rng, elt, length.(dims)) +end + +# TODO: Come up with a better name for this. +_randn(args...) = Base.randn(args...) +function _randn( + rng::AbstractRNG, elt::Type, dims::Tuple{Base.OneTo{Int},Vararg{Base.OneTo{Int}}} +) + return Base.randn(rng, elt, length.(dims)) +end + default_eltype() = Float64 -for f in [:rand, :randn] +for (f, f′) in [(:rand, :_rand), (:randn, :_randn)] @eval begin function Base.$f( rng::AbstractRNG, elt::Type{<:Number}, - dims::Tuple{AbstractNamedInteger,Vararg{AbstractNamedInteger}}, + inds::Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}, ) - a = $f(rng, elt, unname.(dims)) - return nameddims(a, name.(dims)) + ax = to_nameddimsaxes(inds) + a = $f′(rng, elt, dename.(ax)) + return nameddims(a, name.(ax)) end function Base.$f( rng::AbstractRNG, elt::Type{<:Number}, - dim1::AbstractNamedInteger, - dims::Vararg{AbstractNamedInteger}, + dims::Tuple{AbstractNamedInteger,Vararg{AbstractNamedInteger}}, ) - return $f(rng, elt, (dim1, dims...)) + return $f(rng, elt, Base.oneto.(dims)) + end + end + for dimtype in [:AbstractNamedInteger, :NamedDimsIndices] + @eval begin + function Base.$f( + rng::AbstractRNG, elt::Type{<:Number}, dim1::$dimtype, dims::Vararg{$dimtype} + ) + return $f(rng, elt, (dim1, dims...)) + end + Base.$f(elt::Type{<:Number}, dims::Tuple{$dimtype,Vararg{$dimtype}}) = + $f(Random.default_rng(), elt, dims) + Base.$f(elt::Type{<:Number}, dim1::$dimtype, dims::Vararg{$dimtype}) = + $f(elt, (dim1, dims...)) + Base.$f(dims::Tuple{$dimtype,Vararg{$dimtype}}) = $f(default_eltype(), dims) + Base.$f(dim1::$dimtype, dims::Vararg{$dimtype}) = $f((dim1, dims...)) end - Base.$f( - elt::Type{<:Number}, dims::Tuple{AbstractNamedInteger,Vararg{AbstractNamedInteger}} - ) = $f(Random.default_rng(), elt, dims) - Base.$f( - elt::Type{<:Number}, dim1::AbstractNamedInteger, dims::Vararg{AbstractNamedInteger} - ) = $f(elt, (dim1, dims...)) - Base.$f(dims::Tuple{AbstractNamedInteger,Vararg{AbstractNamedInteger}}) = - $f(default_eltype(), dims) - Base.$f(dim1::AbstractNamedInteger, dims::Vararg{AbstractNamedInteger}) = - $f((dim1, dims...)) end end for f in [:zeros, :ones] @eval begin function Base.$f( - elt::Type{<:Number}, dims::Tuple{AbstractNamedInteger,Vararg{AbstractNamedInteger}} + elt::Type{<:Number}, ax::Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}} ) - a = $f(elt, unname.(dims)) - return nameddims(a, name.(dims)) + ax = to_nameddimsaxes(inds) + a = $f(elt, dename.(ax)) + return nameddims(a, name.(ax)) end function Base.$f( - elt::Type{<:Number}, dim1::AbstractNamedInteger, dims::Vararg{AbstractNamedInteger} + elt::Type{<:Number}, dims::Tuple{AbstractNamedInteger,Vararg{AbstractNamedInteger}} ) - return $f(elt, (dim1, dims...)) + a = $f(elt, dename.(dims)) + return nameddims(a, Base.oneto.(dims)) + end + end + for dimtype in [:AbstractNamedInteger, :NamedDimsIndices] + @eval begin + function Base.$f(elt::Type{<:Number}, dim1::$dimtype, dims::Vararg{$dimtype}) + return $f(elt, (dim1, dims...)) + end + Base.$f(dims::Tuple{$dimtype,Vararg{$dimtype}}) = $f(default_eltype(), dims) + Base.$f(dim1::$dimtype, dims::Vararg{$dimtype}) = $f((dim1, dims...)) end - Base.$f(dims::Tuple{AbstractNamedInteger,Vararg{AbstractNamedInteger}}) = - $f(default_eltype(), dims) - Base.$f(dim1::AbstractNamedInteger, dims::Vararg{AbstractNamedInteger}) = - $f((dim1, dims...)) end end -function Base.fill(value, dims::Tuple{AbstractNamedInteger,Vararg{AbstractNamedInteger}}) - a = fill(value, unname.(dims)) - return nameddims(a, name.(dims)) +@eval begin + function Base.fill(value, inds::Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}) + ax = to_nameddimsaxes(inds) + a = fill(value, dename.(ax)) + return nameddims(a, name.(ax)) + end + function Base.fill(value, dims::Tuple{AbstractNamedInteger,Vararg{AbstractNamedInteger}}) + a = fill(value, dename.(dims)) + return nameddims(a, Base.oneto.(dims)) + end end -function Base.fill(value, dim1::AbstractNamedInteger, dims::Vararg{AbstractNamedInteger}) - return fill(value, (dim1, dims...)) +for dimtype in [:AbstractNamedInteger, :NamedDimsIndices] + @eval begin + function Base.fill(value, dim1::$dimtype, dims::Vararg{$dimtype}) + return fill(value, (dim1, dims...)) + end + end end +# Broadcasting + using Base.Broadcast: AbstractArrayStyle, Broadcasted, broadcast_shape, broadcasted, check_broadcast_shape, - combine_axes, - combine_eltypes + combine_axes using BroadcastMapConversion: Mapped, mapped abstract type AbstractNamedDimsArrayStyle{N} <: AbstractArrayStyle{N} end @@ -374,7 +680,7 @@ function Base.promote_shape( ax1::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}}, ax2::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}}, ) where {N} - perm = getperm(name.(ax1), name.(ax2)) + perm = getperm(ax1, ax2) ax2_aligned = map(i -> ax2[i], perm) ax_promoted = promote_shape(dename.(ax1), dename.(ax2_aligned)) return named.(ax_promoted, name.(ax1)) @@ -385,7 +691,7 @@ function Broadcast.check_broadcast_shape( ax1::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}}, ax2::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange,N}}, ) where {N} - perm = getperm(name.(ax1), name.(ax2)) + perm = getperm(ax1, ax2) ax2_aligned = map(i -> ax2[i], perm) check_broadcast_shape(dename.(ax1), dename.(ax2_aligned)) return nothing @@ -400,14 +706,15 @@ end # Dename and lazily permute the arguments using the reference # dimension names. -# TODO: Make a version that gets the dimnames from `m`. -function denamed(m::Mapped, dimnames) - return mapped(m.f, map(arg -> denamed(arg, dimnames), m.args)...) +# TODO: Make a version that gets the nameddimsindices from `m`. +function denamed(m::Mapped, nameddimsindices) + return mapped(m.f, map(arg -> denamed(arg, nameddimsindices), m.args)...) end function Base.similar(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}, elt::Type, ax::Tuple) - m′ = denamed(Mapped(bc), name.(ax)) - return nameddims(similar(m′, elt, dename.(ax)), name.(ax)) + nameddimsindices = name.(ax) + m′ = denamed(Mapped(bc), nameddimsindices) + return nameddims(similar(m′, elt, dename.(ax)), nameddimsindices) end function Base.copyto!( @@ -419,7 +726,7 @@ end function Base.map!(f, a_dest::AbstractNamedDimsArray, a_srcs::AbstractNamedDimsArray...) a′_dest = dename(a_dest) # TODO: Use `denamed` to do the permutations lazily. - a′_srcs = map(a_src -> dename(a_src, dimnames(a_dest)), a_srcs) + a′_srcs = map(a_src -> dename(a_src, nameddimsindices(a_dest)), a_srcs) map!(f, a′_dest, a′_srcs...) return a_dest end @@ -438,7 +745,8 @@ function LinearAlgebra.norm(a::AbstractNamedDimsArray; kwargs...) return norm(dename(a); kwargs...) end -# Printing. +# Printing + function Base.show(io::IO, mime::MIME"text/plain", a::AbstractNamedDimsArray) summary(io, a) println(io) @@ -449,6 +757,6 @@ end function Base.show(io::IO, a::AbstractNamedDimsArray) print(io, "nameddims(") show(io, dename(a)) - print(io, ", ", dimnames(a), ")") + print(io, ", ", nameddimsindices(a), ")") return nothing end diff --git a/src/abstractnamedinteger.jl b/src/abstractnamedinteger.jl index c557421..84e4e79 100644 --- a/src/abstractnamedinteger.jl +++ b/src/abstractnamedinteger.jl @@ -37,7 +37,8 @@ function Base.hash(i::AbstractNamedInteger, h::UInt) end abstract type AbstractName end -name(n::AbstractName) = throw(MethodError(name, Tuple{typeof(n)})) +# TODO: Decide if this is a good definition, probably not. +# name(n::AbstractName) = throw(MethodError(name, Tuple{typeof(n)})) Base.getindex(n::AbstractName, I) = named(I, name(n)) struct Name{Value} <: AbstractName @@ -60,6 +61,10 @@ fusednames(name1::FusedNames, name2::FusedNames) = FusedNames(generic_vcat(name1 fusednames(name1, name2::FusedNames) = fusednames(FusedNames((name1,)), name2) fusednames(name1::FusedNames, name2) = fusednames(name1, FusedNames((name2,))) +function Base.:(==)(n1::FusedNames, n2::FusedNames) + return mapreduce(==, &, n1.names, n2.names) +end + # Integer interface # TODO: Should this make a random name, or require defining a way # to combine names? @@ -89,11 +94,17 @@ function Base.string(i::AbstractNamedInteger; kwargs...) return "named($(string(dename(i); kwargs...)), $(repr(name(i))))" end +Base.Int(i::AbstractNamedInteger) = Int(dename(i)) + struct NameMismatch <: Exception message::String end NameMismatch() = NameMismatch("") +function randname(rng::AbstractRNG, i::AbstractNamedInteger) + return named(dename(i), randname(name(i))) +end + # Used in bounds checking when indexing with named dimensions. function Base.:<(i1::AbstractNamedInteger, i2::AbstractNamedInteger) name(i1) == name(i2) || throw(NameMismatch("Mismatched names $(name(i1)), $(name(i2))")) diff --git a/src/abstractnamedunitrange.jl b/src/abstractnamedunitrange.jl index 9942b1a..4f40e36 100644 --- a/src/abstractnamedunitrange.jl +++ b/src/abstractnamedunitrange.jl @@ -37,16 +37,32 @@ function Base.hash(r::AbstractNamedUnitRange, h::UInt) end # Unit range funcionality. -# TODO: Also customize `Base.getindex` to preserve the name. Base.first(r::AbstractNamedUnitRange) = named(first(dename(r)), name(r)) Base.last(r::AbstractNamedUnitRange) = named(last(dename(r)), name(r)) Base.length(r::AbstractNamedUnitRange) = named(length(dename(r)), name(r)) Base.size(r::AbstractNamedUnitRange) = (named(length(dename(r)), name(r)),) Base.axes(r::AbstractNamedUnitRange) = (named(only(axes(dename(r))), name(r)),) Base.step(r::AbstractNamedUnitRange) = named(step(dename(r)), name(r)) -Base.getindex(r::AbstractNamedUnitRange, i::Int) = named(getindex(dename(r), i), name(r)) +Base.getindex(r::AbstractNamedUnitRange, I::Int) = named_getindex(r, I) +# Fix ambiguity error. +function Base.getindex(r::AbstractNamedUnitRange, I::AbstractUnitRange{<:Integer}) + return named_getindex(r, I) +end +# Fix ambiguity error. +function Base.getindex(r::AbstractNamedUnitRange, I::Colon) + return named_getindex(r, I) +end +function Base.getindex(r::AbstractNamedUnitRange, I) + return named_getindex(r, I) +end Base.isempty(r::AbstractNamedUnitRange) = isempty(dename(r)) +function Base.AbstractUnitRange{Int}(r::AbstractNamedUnitRange) + return AbstractUnitRange{Int}(dename(r)) +end + +Base.oneto(length::AbstractNamedInteger) = named(Base.OneTo(dename(length)), name(length)) +namedoneto(length::Integer, name) = Base.oneto(named(length, name)) Base.iterate(r::AbstractNamedUnitRange) = isempty(r) ? nothing : (first(r), first(r)) function Base.iterate(r::AbstractNamedUnitRange, i) i == last(r) && return nothing @@ -54,7 +70,18 @@ function Base.iterate(r::AbstractNamedUnitRange, i) return (next, next) end +function randname(rng::AbstractRNG, r::AbstractNamedUnitRange) + return named(dename(r), randname(name(r))) +end + function Base.show(io::IO, r::AbstractNamedUnitRange) print(io, "named(", dename(r), ", ", repr(name(r)), ")") return nothing end + +struct NamedColon{Name} <: Function + name::Name +end +dename(c::NamedColon) = Colon() +name(c::NamedColon) = c.name +named(::Colon, name) = NamedColon(name) diff --git a/src/adapt.jl b/src/adapt.jl index 799f70d..25fcd4f 100644 --- a/src/adapt.jl +++ b/src/adapt.jl @@ -1,5 +1,5 @@ using Adapt: Adapt, adapt function Adapt.adapt_structure(to, a::AbstractNamedDimsArray) - return nameddims(adapt(to, dename(a)), dimnames(a)) + return nameddims(adapt(to, dename(a)), nameddimsindices(a)) end diff --git a/src/namedarray.jl b/src/namedarray.jl new file mode 100644 index 0000000..ec25baa --- /dev/null +++ b/src/namedarray.jl @@ -0,0 +1,9 @@ +struct NamedArray{T,N,Value<:AbstractArray{T,N},Name} <: + AbstractNamedArray{NamedInteger{T,Name},N,Value,Name} + value::Value + name::Name +end + +# Minimal interface. +dename(a::NamedArray) = a.value +name(a::NamedArray) = a.name diff --git a/src/nameddimsarray.jl b/src/nameddimsarray.jl index 2702aa1..9a635ec 100644 --- a/src/nameddimsarray.jl +++ b/src/nameddimsarray.jl @@ -1,9 +1,16 @@ using TypeParameterAccessors: TypeParameterAccessors, parenttype +# nameddimsindices should be a named slice. struct NamedDimsArray{T,N,Parent<:AbstractArray{T,N},DimNames} <: AbstractNamedDimsArray{T,N} parent::Parent - dimnames::DimNames + nameddimsindices::DimNames + function NamedDimsArray(parent::AbstractArray, dims) + nameddimsindices = to_nameddimsindices(parent, dims) + return new{eltype(parent),ndims(parent),typeof(parent),typeof(nameddimsindices)}( + parent, nameddimsindices + ) + end end const NamedDimsVector{T,Parent<:AbstractVector{T},DimNames} = NamedDimsArray{ @@ -13,12 +20,17 @@ const NamedDimsMatrix{T,Parent<:AbstractMatrix{T},DimNames} = NamedDimsArray{ T,2,Parent,DimNames } -function NamedDimsArray(a::AbstractNamedDimsArray, dimnames) - return NamedDimsArray(denamed(a, dimnames), dimnames) +# TODO: Delete this, and just wrap the input naively. +function NamedDimsArray(a::AbstractNamedDimsArray, nameddimsindices) + return error("Already named.") +end + +function NamedDimsArray(a::AbstractNamedDimsArray) + return NamedDimsArray(dename(a), nameddimsindices(a)) end # Minimal interface. -dimnames(a::NamedDimsArray) = a.dimnames +nameddimsindices(a::NamedDimsArray) = a.nameddimsindices Base.parent(a::NamedDimsArray) = a.parent function TypeParameterAccessors.position( diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index fe993ed..2c9f8d3 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -11,11 +11,11 @@ function TensorAlgebra.contract!( ) contract!( dename(a_dest), - dimnames(a_dest), + nameddimsindices(a_dest), dename(a1), - dimnames(a1), + nameddimsindices(a1), dename(a2), - dimnames(a2), + nameddimsindices(a2), α, β, ) @@ -25,8 +25,10 @@ end function TensorAlgebra.contract( a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray, α::Number=true ) - a_dest, dimnames_dest = contract(dename(a1), dimnames(a1), dename(a2), dimnames(a2), α) - return nameddims(a_dest, dimnames_dest) + a_dest, nameddimsindices_dest = contract( + dename(a1), nameddimsindices(a1), dename(a2), nameddimsindices(a2), α + ) + return nameddims(a_dest, nameddimsindices_dest) end function Base.:*(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray) @@ -45,11 +47,10 @@ function LinearAlgebra.mul!( end function TensorAlgebra.blockedperm(na::AbstractNamedDimsArray, nameddim_blocks::Tuple...) - # Extract names if named dimensions or axes were passed - dimname_blocks = map(group -> name.(group), nameddim_blocks) - dimnames_a = dimnames(na) + dimname_blocks = map(group -> to_nameddimsindices(na, group), nameddim_blocks) + nameddimsindices_a = nameddimsindices(na) perms = map(dimname_blocks) do dimname_block - return BaseExtensions.indexin(dimname_block, dimnames_a) + return BaseExtensions.indexin(dimname_block, nameddimsindices_a) end return blockedperm(perms...) end @@ -60,33 +61,37 @@ end # fusedims(a, (i, k) => "a", (j, l) => "b") # TODO: Rewrite in terms of `fusedims(a, .., (1, 3))` interface. function TensorAlgebra.fusedims(na::AbstractNamedDimsArray, fusions::Pair...) - dimnames_fuse = map(group -> name.(group), first.(fusions)) - dimnames_fused = map(name, last.(fusions)) - if sum(length, dimnames_fuse) < ndims(na) + nameddimsindices_fuse = map(group -> to_nameddimsindices(na, group), first.(fusions)) + nameddimsindices_fused = last.(fusions) + if sum(length, nameddimsindices_fuse) < ndims(na) # Not all names are specified - dimnames_unspecified = setdiff(dimnames(na), dimnames_fuse...) - dimnames_fuse = vcat(tuple.(dimnames_unspecified), collect(dimnames_fuse)) - dimnames_fused = vcat(dimnames_unspecified, collect(dimnames_fused)) + nameddimsindices_unspecified = setdiff(nameddimsindices(na), nameddimsindices_fuse...) + nameddimsindices_fuse = vcat( + tuple.(nameddimsindices_unspecified), collect(nameddimsindices_fuse) + ) + nameddimsindices_fused = vcat( + nameddimsindices_unspecified, collect(nameddimsindices_fused) + ) end - perm = blockedperm(na, dimnames_fuse...) + perm = blockedperm(na, nameddimsindices_fuse...) a_fused = fusedims(unname(na), perm) - return nameddims(a_fused, dimnames_fused) + return nameddims(a_fused, nameddimsindices_fused) end function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...) - fused_names = map(name, first.(splitters)) + splitters = to_nameddimsindices(na, first.(splitters)) .=> last.(splitters) split_namedlengths = last.(splitters) splitters_unnamed = map(splitters) do splitter fused_name, split_namedlengths = splitter - fused_dim = findfirst(isequal(fused_name), dimnames(na)) + fused_dim = findfirst(isequal(fused_name), nameddimsindices(na)) split_lengths = unname.(split_namedlengths) return fused_dim => split_lengths end a_split = splitdims(unname(na), splitters_unnamed...) - names_split = Any[tuple.(dimnames(na))...] + names_split = Any[tuple.(nameddimsindices(na))...] for splitter in splitters fused_name, split_namedlengths = splitter - fused_dim = findfirst(isequal(fused_name), dimnames(na)) + fused_dim = findfirst(isequal(fused_name), nameddimsindices(na)) split_names = name.(split_namedlengths) names_split[fused_dim] = split_names end @@ -95,23 +100,31 @@ function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...) end function LinearAlgebra.qr( - a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; positive=nothing + a::AbstractNamedDimsArray, + nameddimsindices_codomain, + nameddimsindices_domain; + positive=nothing, ) @assert isnothing(positive) || !positive # TODO: This should be `TensorAlgebra.qr` rather than overloading `LinearAlgebra.qr`. # TODO: Don't require wrapping in `Tuple`. q, r = qr( unname(a), - Tuple(dimnames(a)), - Tuple(name.(dimnames_codomain)), - Tuple(name.(dimnames_domain)), + Tuple(nameddimsindices(a)), + Tuple(to_nameddimsindices(a, nameddimsindices_codomain)), + Tuple(to_nameddimsindices(a, nameddimsindices_domain)), ) - name_qr = randname(dimnames(a)[1]) - dimnames_q = (name.(dimnames_codomain)..., name_qr) - dimnames_r = (name_qr, name.(dimnames_domain)...) - return nameddims(q, dimnames_q), nameddims(r, dimnames_r) + name_qr = randname(nameddimsindices(a)[1]) + nameddimsindices_q = (to_nameddimsindices(a, nameddimsindices_codomain)..., name_qr) + nameddimsindices_r = (name_qr, to_nameddimsindices(a, nameddimsindices_domain)...) + return nameddims(q, nameddimsindices_q), nameddims(r, nameddimsindices_r) end -function LinearAlgebra.qr(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...) - return qr(a, dimnames_codomain, setdiff(dimnames(a), name.(dimnames_codomain)); kwargs...) +function LinearAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...) + return qr( + a, + nameddimsindices_codomain, + setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain)); + kwargs..., + ) end diff --git a/test/Project.toml b/test/Project.toml index 7e3fdc5..0efab2b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,14 +1,16 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58" -SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" -TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" diff --git a/test/basics/test_adapt.jl b/test/basics/test_adapt.jl index 7d553d3..d973d84 100644 --- a/test/basics/test_adapt.jl +++ b/test/basics/test_adapt.jl @@ -2,6 +2,8 @@ using Adapt: adapt using NamedDimsArrays: nameddims using Test: @test, @testset +using NamedDimsArrays: nameddimsindices + @testset "Adapt (eltype=$elt)" for elt in (Float32, Float64, Complex{Float32}, Complex{Float64}) na = nameddims(randn(2, 2), ("i", "j")) diff --git a/test/basics/test_basic.jl b/test/basics/test_basics.jl similarity index 73% rename from test/basics/test_basic.jl rename to test/basics/test_basics.jl index abe9862..fa0a0b8 100644 --- a/test/basics/test_basic.jl +++ b/test/basics/test_basics.jl @@ -1,12 +1,11 @@ -using Test: @test, @test_throws, @testset using NamedDimsArrays: NamedDimsArrays, AbstractNamedDimsArray, AbstractNamedDimsMatrix, Name, NameMismatch, - NamedCartesianIndex, - NamedCartesianIndices, + NamedDimsCartesianIndex, + NamedDimsCartesianIndices, NamedDimsArray, NamedDimsMatrix, aligndims, @@ -21,10 +20,13 @@ using NamedDimsArrays: name, named, nameddims, - replacedimnames, - setdimnames, + nameddimsindices, + namedoneto, + replacenameddimsindices, + setnameddimsindices, unname, unnamed +using Test: @test, @test_throws, @testset @testset "NamedDimsArrays.jl" begin @testset "Basic functionality" begin @@ -36,13 +38,6 @@ using NamedDimsArrays: @test na isa AbstractNamedDimsMatrix{elt} @test na isa NamedDimsArray{elt} @test na isa AbstractNamedDimsArray{elt} - for na′ in (nameddims(na, ("j", "i")), NamedDimsArray(na, ("j", "i"))) - @test na′ isa NamedDimsMatrix{elt,<:PermutedDimsArray} - @test dimnames(na′) == ("j", "i") - @test na′ == na - end - @test_throws NameMismatch nameddims(na, ("j", "k")) - @test_throws NameMismatch NamedDimsArray(na, ("j", "k")) @test_throws MethodError dename(a) @test_throws MethodError dename(a, ("i", "j")) @test_throws MethodError denamed(a, ("i", "j")) @@ -52,11 +47,16 @@ using NamedDimsArrays: @test dename(na) == a si, sj = size(na) ai, aj = axes(na) - @test name(si) == "i" - @test name(sj) == "j" - @test name(ai) == "i" - @test name(aj) == "j" + i = namedoneto(3, "i") + j = namedoneto(4, "j") + @test name(si) == i + @test name(sj) == j + @test name(ai) == i + @test name(aj) == j @test isnamed(na) + @test nameddimsindices(na) == (i, j) + @test nameddimsindices(na, 1) == i + @test nameddimsindices(na, 2) == j @test dimnames(na) == ("i", "j") @test dimnames(na, 1) == "i" @test dimnames(na, 2) == "j" @@ -65,17 +65,28 @@ using NamedDimsArrays: @test dims(na, ("j", "i")) == (2, 1) @test na[1, 1] == a[1, 1] + for na′ in ( + similar(na, Float32, (j, i)), + similar(na, Float32, (aj, ai)), + similar(a, Float32, (j, i)), + similar(a, Float32, (aj, ai)), + ) + @test eltype(na′) === Float32 + @test nameddimsindices(na′) == (j, i) + @test na′ ≠ na + end + # getindex syntax i = Name("i") j = Name("j") @test a[i, j] == na @test @view(a[i, j]) == na @test na[j[1], i[2]] == a[2, 1] - @test dimnames(na[j, i]) == ("j", "i") + @test nameddimsindices(na[j, i]) == (named(1:3, "i"), named(1:4, "j")) @test na[j, i] == na @test @view(na[j, i]) == na - @test i[axes(a, 1)] == ai - @test j[axes(a, 2)] == aj + @test i[axes(a, 1)] == named(1:3, "i") + @test j[axes(a, 2)] == named(1:4, "j") @test axes(na, i) == ai @test axes(na, j) == aj @test size(na, i) == si @@ -115,35 +126,32 @@ using NamedDimsArrays: @test a′ isa PermutedDimsArray{elt} @test a′ == a' end - nb = setdimnames(na, ("k", "j")) - @test dimnames(nb) == ("k", "j") + nb = setnameddimsindices(na, ("k", "j")) + @test nameddimsindices(nb) == (named(1:3, "k"), named(1:4, "j")) @test dename(nb) == a - nb = replacedimnames(na, "i" => "k") - @test dimnames(nb) == ("k", "j") + nb = replacenameddimsindices(na, "i" => "k") + @test nameddimsindices(nb) == (named(1:3, "k"), named(1:4, "j")) @test dename(nb) == a - nb = replacedimnames(na, named(3, "i") => named(3, "k")) - @test dimnames(nb) == ("k", "j") + nb = replacenameddimsindices(na, named(1:3, "i") => named(1:3, "k")) + @test nameddimsindices(nb) == (named(1:3, "k"), named(1:4, "j")) @test dename(nb) == a - nb = replacedimnames(n -> n == "i" ? "k" : n, na) - @test dimnames(nb) == ("k", "j") + nb = replacenameddimsindices(n -> n == named(1:3, "i") ? named(1:3, "k") : n, na) + @test nameddimsindices(nb) == (named(1:3, "k"), named(1:4, "j")) @test dename(nb) == a - nb = setdimnames(na, named(3, "i") => named(3, "k")) + nb = setnameddimsindices(na, named(3, "i") => named(3, "k")) na[1, 1] = 11 @test na[1, 1] == 11 - @test size(na) == (named(3, "i"), named(4, "j")) - @test length(na) == named(12, fusednames("i", "j")) - @test axes(na) == (named(1:3, "i"), named(1:4, "j")) + @test size(na) == (named(3, named(1:3, "i")), named(4, named(1:4, "j"))) + @test length(na) == named(12, fusednames(named(1:3, "i"), named(1:4, "j"))) + @test axes(na) == (named(1:3, named(1:3, "i")), named(1:4, named(1:4, "j"))) @test randn(named.((3, 4), ("i", "j"))) isa NamedDimsArray @test na["i" => 1, "j" => 2] == a[1, 2] @test na["j" => 2, "i" => 1] == a[1, 2] na["j" => 2, "i" => 1] = 12 @test na[1, 2] == 12 @test na[j => 1, i => 2] == a[2, 1] - @test na[aj => 1, ai => 2] == a[2, 1] na[j => 1, i => 2] = 21 @test na[2, 1] == 21 - na[aj => 1, ai => 2] = 2211 - @test na[2, 1] == 2211 na′ = aligndims(na, ("j", "i")) @test unname(na′) isa Matrix{elt} @test a == permutedims(unname(na′), (2, 1)) @@ -156,21 +164,15 @@ using NamedDimsArrays: na′ = aligneddims(na, (j, i)) @test unname(na′) isa PermutedDimsArray{elt} @test a == permutedims(unname(na′), (2, 1)) - na′ = aligndims(na, (aj, ai)) - @test unname(na′) isa Matrix{elt} - @test a == permutedims(unname(na′), (2, 1)) - na′ = aligneddims(na, (aj, ai)) - @test unname(na′) isa PermutedDimsArray{elt} - @test a == permutedims(unname(na′), (2, 1)) na = nameddims(randn(elt, 2, 3), (:i, :j)) nb = nameddims(randn(elt, 3, 2), (:j, :i)) nc = zeros(elt, named.((2, 3), (:i, :j))) Is = eachindex(na, nb) - @test Is isa NamedCartesianIndices{2} - @test issetequal(dimnames(Is), (:i, :j)) + @test Is isa NamedDimsCartesianIndices{2} + @test issetequal(nameddimsindices(Is), (named(1:2, :i), named(1:3, :j))) for I in Is - @test I isa NamedCartesianIndex{2} + @test I isa NamedDimsCartesianIndex{2} @test issetequal(name.(Tuple(I)), (:i, :j)) nc[I] = na[I] + nb[I] end @@ -201,26 +203,26 @@ using NamedDimsArrays: for na in (zeros(elt, i, j), zeros(elt, (i, j))) @test eltype(na) === elt @test na isa NamedDimsArray - @test dimnames(na) == ("i", "j") + @test nameddimsindices(na) == Base.oneto.((i, j)) @test iszero(na) end for na in (fill(value, i, j), fill(value, (i, j))) @test eltype(na) === elt @test na isa NamedDimsArray - @test dimnames(na) == ("i", "j") + @test nameddimsindices(na) == Base.oneto.((i, j)) @test all(isequal(value), na) end for na in (rand(elt, i, j), rand(elt, (i, j))) @test eltype(na) === elt @test na isa NamedDimsArray - @test dimnames(na) == ("i", "j") + @test nameddimsindices(na) == Base.oneto.((i, j)) @test !iszero(na) @test all(x -> real(x) > 0, na) end for na in (randn(elt, i, j), randn(elt, (i, j))) @test eltype(na) === elt @test na isa NamedDimsArray - @test dimnames(na) == ("i", "j") + @test nameddimsindices(na) == Base.oneto.((i, j)) @test !iszero(na) end end @@ -230,20 +232,20 @@ using NamedDimsArrays: for na in (zeros(i, j), zeros((i, j))) @test eltype(na) === default_elt @test na isa NamedDimsArray - @test dimnames(na) == ("i", "j") + @test nameddimsindices(na) == Base.oneto.((i, j)) @test iszero(na) end for na in (rand(i, j), rand((i, j))) @test eltype(na) === default_elt @test na isa NamedDimsArray - @test dimnames(na) == ("i", "j") + @test nameddimsindices(na) == Base.oneto.((i, j)) @test !iszero(na) @test all(x -> real(x) > 0, na) end for na in (randn(i, j), randn((i, j))) @test eltype(na) === default_elt @test na isa NamedDimsArray - @test dimnames(na) == ("i", "j") + @test nameddimsindices(na) == Base.oneto.((i, j)) @test !iszero(na) end end diff --git a/test/basics/test_blockarraysext.jl b/test/basics/test_blockarraysext.jl new file mode 100644 index 0000000..ab8ce07 --- /dev/null +++ b/test/basics/test_blockarraysext.jl @@ -0,0 +1,21 @@ +using BlockArrays: Block +using BlockSparseArrays: BlockSparseArray +using NamedDimsArrays: dename, nameddims, nameddimsindices +using Test: @test, @testset + +@testset "NamedDimsArraysBlockArraysExt" begin + elt = Float64 + a = BlockSparseArray{elt}([2, 3], [2, 3]) + a[Block(2, 1)] = randn(elt, 3, 2) + a[Block(1, 2)] = randn(elt, 2, 3) + n = nameddims(a, ("i", "j")) + i, j = nameddimsindices(n) + @test dename(n[i[Block(2)], j[Block(1)]]) == a[Block(2, 1)] + @test dename(n[Block(2), Block(1)]) == a[Block(2, 1)] + @test dename(n[Block(2, 1)]) == a[Block(2, 1)] + @test dename(n[i[Block(2)], j[Block.(1:2)]]) == a[Block(2), Block.(1:2)] + @test dename(n[Block(2), Block.(1:2)]) == a[Block(2), Block.(1:2)] + @test dename(n[i[Block.(1:2)], j[Block(1)]]) == a[Block.(1:2), Block(1)] + @test dename(n[Block.(1:2), Block(1)]) == a[Block.(1:2), Block(1)] + @test dename(n[Block.(1:2), Block.(1:2)]) == a[Block.(1:2), Block.(1:2)] +end diff --git a/test/basics/test_tensoralgebra.jl b/test/basics/test_tensoralgebra.jl index ceaa4fb..6979df1 100644 --- a/test/basics/test_tensoralgebra.jl +++ b/test/basics/test_tensoralgebra.jl @@ -1,13 +1,13 @@ using LinearAlgebra: qr -using NamedDimsArrays: named, dename +using NamedDimsArrays: namedoneto, dename using TensorAlgebra: TensorAlgebra, contract, fusedims, splitdims using Test: @test, @testset, @test_broken elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra (eltype=$(elt))" for elt in elts @testset "contract" begin - i = named(2, "i") - j = named(2, "j") - k = named(2, "k") + i = namedoneto(2, "i") + j = namedoneto(2, "j") + k = namedoneto(2, "k") na1 = randn(elt, i, j) na2 = randn(elt, j, k) na_dest = contract(na1, na2) @@ -15,20 +15,24 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test dename(na_dest, (i, k)) ≈ dename(na1) * dename(na2) end @testset "fusedims" begin - i, j, k, l = named.((2, 3, 4, 5), ("i", "j", "k", "l")) + i, j, k, l = namedoneto.((2, 3, 4, 5), ("i", "j", "k", "l")) na = randn(elt, i, j, k, l) na_fused = fusedims(na, (k, i) => "a", (j, l) => "b") # Fuse all dimensions. - @test dename(na_fused, ("a", "b")) ≈ - reshape(dename(na, (k, i, j, l)), (dename(k) * dename(i), dename(j) * dename(l))) + @test dename(na_fused, ("a", "b")) ≈ reshape( + dename(na, (k, i, j, l)), + (dename(length(k)) * dename(length(i)), dename(length(j)) * dename(length(l))), + ) na_fused = fusedims(na, (k, i) => "a") # Fuse a subset of dimensions. - @test dename(na_fused, ("a", "j", "l")) ≈ - reshape(dename(na, (k, i, j, l)), (dename(k) * dename(i), dename(j), dename(l))) + @test dename(na_fused, ("a", "j", "l")) ≈ reshape( + dename(na, (k, i, j, l)), + (dename(length(k)) * dename(length(i)), dename(length(j)), dename(length(l))), + ) end @testset "splitdims" begin - a, b = named.((6, 20), ("a", "b")) - i, j, k, l = named.((2, 3, 4, 5), ("i", "j", "k", "l")) + a, b = namedoneto.((6, 20), ("a", "b")) + i, j, k, l = namedoneto.((2, 3, 4, 5), ("i", "j", "k", "l")) na = randn(elt, a, b) # Split all dimensions. na_split = splitdims(na, "a" => (k, i), "b" => (j, l)) @@ -41,7 +45,7 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end @testset "qr" begin dims = (2, 2, 2, 2) - i, j, k, l = named.(dims, ("i", "j", "k", "l")) + i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l")) na = randn(elt, i, j) # TODO: Should this be allowed?