diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 53c883bf54..41f7fc3478 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -1,6 +1,6 @@ @eval module $(gensym()) using Compat: Returns -using Test: @test, @testset, @test_broken +using Test: @test, @testset using BlockArrays: AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored @@ -217,10 +217,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test size(a[I, I]) == (1, 1) @test isdual(axes(a[I, :], 2)) @test isdual(axes(a[:, I], 1)) - @test_broken isdual(axes(a[I, :], 1)) - @test_broken isdual(axes(a[:, I], 2)) - @test_broken isdual(axes(a[I, I], 1)) - @test_broken isdual(axes(a[I, I], 2)) + @test isdual(axes(a[I, :], 1)) + @test isdual(axes(a[:, I], 2)) + @test isdual(axes(a[I, I], 1)) + @test isdual(axes(a[I, I], 2)) end @testset "dual GradedUnitRange" begin @@ -243,10 +243,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test size(a[I, I]) == (1, 1) @test isdual(axes(a[I, :], 2)) @test isdual(axes(a[:, I], 1)) - @test_broken isdual(axes(a[I, :], 1)) - @test_broken isdual(axes(a[:, I], 2)) - @test_broken isdual(axes(a[I, I], 1)) - @test_broken isdual(axes(a[I, I], 2)) + @test isdual(axes(a[I, :], 1)) + @test isdual(axes(a[:, I], 2)) + @test isdual(axes(a[I, I], 1)) + @test isdual(axes(a[I, I], 2)) end @testset "dual BlockedUnitRange" begin # self dual diff --git a/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl b/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl index 7edd09bf84..ba17c175f0 100644 --- a/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl +++ b/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl @@ -2,6 +2,7 @@ module GradedAxes include("blockedunitrange.jl") include("gradedunitrange.jl") include("dual.jl") +include("labelledunitrangedual.jl") include("gradedunitrangedual.jl") include("onetoone.jl") include("fusion.jl") diff --git a/NDTensors/src/lib/GradedAxes/src/dual.jl b/NDTensors/src/lib/GradedAxes/src/dual.jl index ca985e30a0..877ba1a857 100644 --- a/NDTensors/src/lib/GradedAxes/src/dual.jl +++ b/NDTensors/src/lib/GradedAxes/src/dual.jl @@ -1,5 +1,5 @@ -# default behavior: self-dual -dual(r::AbstractUnitRange) = r +# default behavior: any object is self-dual +dual(x) = x nondual(r::AbstractUnitRange) = r isdual(::AbstractUnitRange) = false @@ -11,4 +11,5 @@ label_dual(x) = label_dual(LabelledStyle(x), x) label_dual(::NotLabelled, x) = x label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x))) +flip(a::AbstractUnitRange) = dual(label_dual(a)) flip(g::AbstractGradedUnitRange) = dual(gradedrange(label_dual.(blocklengths(g)))) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 0bd35707a7..76eaf42692 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -68,6 +68,7 @@ end # == is just a range comparison that ignores labels. Need dedicated function to check equality. struct NoLabel end blocklabels(r::AbstractUnitRange) = Fill(NoLabel(), blocklength(r)) +blocklabels(la::LabelledUnitRange) = [label(la)] function LabelledNumbers.labelled_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange) return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2)) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index 217d4b401f..97c8a96d71 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -31,11 +31,15 @@ function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer) end function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1}) - return label_dual(getindex(nondual(a), indices)) + return dual(getindex(nondual(a), indices)) end function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange) - return label_dual(getindex(nondual(a), indices)) + return dual(getindex(nondual(a), indices)) +end + +function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockIndexRange) + return dual(nondual(a)[indices]) end # fix ambiguity @@ -49,20 +53,51 @@ function BlockArrays.blocklengths(a::GradedUnitRangeDual) return dual.(blocklengths(nondual(a))) end -function gradedunitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices) +# TODO: Move this to a `BlockArraysExtensions` library. +function blockedunitrange_getindices( + a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}} +) a_indices = getindex(nondual(a), indices) - return mortar([label_dual(b) for b in blocks(a_indices)]) + v = mortar(dual.(blocks(a_indices))) + # flip v to stay consistent with other cases where axes(v) are used + return flip_blockvector(v) end -# TODO: Move this to a `BlockArraysExtensions` library. -function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}}) - return gradedunitrangedual_getindices_blocks(a, indices) +function blockedunitrange_getindices( + a::GradedUnitRangeDual, + indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}, +) + v = mortar(map(b -> a[b], blocks(indices))) + # GradedOneTo appears in mortar + # flip v axis to preserve dual information + # axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)[1:1]])) + return flip_blockvector(v) end function blockedunitrange_getindices( - a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}} + a::GradedUnitRangeDual, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}} ) - return gradedunitrangedual_getindices_blocks(a, indices) + # Without converting `indices` to `Vector`, + # mapping `indices` outputs a `BlockVector` + # which is harder to reason about. + vblocks = map(index -> a[index], Vector(indices)) + # We pass `length.(blocks)` to `mortar` in order + # to pass block labels to the axes of the output, + # if they exist. This makes it so that + # `only(axes(a[indices])) isa `GradedUnitRange` + # if `a isa `GradedUnitRange`, for example. + + v = mortar(vblocks, length.(vblocks)) + # GradedOneTo appears in mortar + # flip v axis to preserve dual information + # axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)])) + return flip_blockvector(v) +end + +function flip_blockvector(v::BlockVector) + block_axes = flip.(axes(v)) + flipped = mortar(vec.(blocks(v)), block_axes) + return flipped end Base.axes(a::GradedUnitRangeDual) = axes(nondual(a)) diff --git a/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl new file mode 100644 index 0000000000..466d64945b --- /dev/null +++ b/NDTensors/src/lib/GradedAxes/src/labelledunitrangedual.jl @@ -0,0 +1,49 @@ +# LabelledUnitRangeDual is obtained by slicing a GradedUnitRangeDual with a block + +using ..LabelledNumbers: LabelledNumbers, label, labelled, unlabel + +struct LabelledUnitRangeDual{T,NondualUnitRange<:AbstractUnitRange{T}} <: + AbstractUnitRange{T} + nondual_unitrange::NondualUnitRange +end + +dual(a::LabelledUnitRange) = LabelledUnitRangeDual(a) +nondual(a::LabelledUnitRangeDual) = a.nondual_unitrange +dual(a::LabelledUnitRangeDual) = nondual(a) +label_dual(::IsLabelled, a::LabelledUnitRangeDual) = dual(label_dual(nondual(a))) +isdual(::LabelledUnitRangeDual) = true +blocklabels(la::LabelledUnitRangeDual) = [label(la)] + +LabelledNumbers.label(a::LabelledUnitRangeDual) = dual(label(nondual(a))) +LabelledNumbers.unlabel(a::LabelledUnitRangeDual) = unlabel(nondual(a)) +LabelledNumbers.LabelledStyle(::LabelledUnitRangeDual) = IsLabelled() + +for f in [:first, :getindex, :last, :length, :step] + @eval Base.$f(a::LabelledUnitRangeDual, args...) = + labelled($f(unlabel(a), args...), label(a)) +end + +# fix ambiguities +Base.getindex(a::LabelledUnitRangeDual, i::Integer) = dual(nondual(a)[i]) +function Base.getindex(a::LabelledUnitRangeDual, indices::AbstractUnitRange{<:Integer}) + return dual(nondual(a)[indices]) +end + +function Base.iterate(a::LabelledUnitRangeDual, i) + i == last(a) && return nothing + next = convert(eltype(a), labelled(i + step(a), label(a))) + return (next, next) +end + +function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRangeDual) + println(io, typeof(a)) + return print(io, label(a), " => ", unlabel(a)) +end + +function Base.show(io::IO, a::LabelledUnitRangeDual) + return print(io, nameof(typeof(a)), " ", label(a), " => ", unlabel(a)) +end + +function Base.AbstractUnitRange{T}(a::LabelledUnitRangeDual) where {T} + return AbstractUnitRange{T}(nondual(a)) +end diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 18dbac045c..98b8838542 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -17,6 +17,7 @@ using NDTensors.GradedAxes: AbstractGradedUnitRange, GradedAxes, GradedUnitRangeDual, + LabelledUnitRangeDual, OneToOne, blocklabels, blockmergesortperm, @@ -27,7 +28,8 @@ using NDTensors.GradedAxes: gradedrange, isdual, nondual -using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal +using NDTensors.LabelledNumbers: + LabelledInteger, LabelledUnitRange, label, label_type, labelled, labelled_isequal, unlabel using Test: @test, @test_broken, @testset struct U1 n::Int @@ -58,6 +60,92 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n @test blockisequal(ad, a) end +@testset "LabelledUnitRangeDual" begin + la = labelled(1:2, U1(1)) + @test la isa LabelledUnitRange + @test label(la) == U1(1) + @test blocklabels(la) == [U1(1)] + @test unlabel(la) == 1:2 + @test la == 1:2 + @test !isdual(la) + @test labelled_isequal(la, la) + @test space_isequal(la, la) + @test label_type(la) == U1 + + @test iterate(la) == (1, 1) + @test iterate(la) == (1, 1) + @test iterate(la, 1) == (2, 2) + @test isnothing(iterate(la, 2)) + + lad = dual(la) + @test lad isa LabelledUnitRangeDual + @test label(lad) == U1(-1) + @test blocklabels(lad) == [U1(-1)] + @test unlabel(lad) == 1:2 + @test lad == 1:2 + @test labelled_isequal(lad, lad) + @test space_isequal(lad, lad) + @test !labelled_isequal(la, lad) + @test !space_isequal(la, lad) + @test isdual(lad) + @test nondual(lad) === la + @test dual(lad) === la + @test label_type(lad) == U1 + + @test iterate(lad) == (1, 1) + @test iterate(lad) == (1, 1) + @test iterate(lad, 1) == (2, 2) + @test isnothing(iterate(lad, 2)) + + lad2 = lad[1:1] + @test lad2 isa LabelledUnitRangeDual + @test label(lad2) == U1(-1) + @test unlabel(lad2) == 1:1 + + laf = flip(la) + @test laf isa LabelledUnitRangeDual + @test label(laf) == U1(1) + @test unlabel(laf) == 1:2 + @test labelled_isequal(la, laf) + @test !space_isequal(la, laf) + + ladf = flip(dual(la)) + @test ladf isa LabelledUnitRange + @test label(ladf) == U1(-1) + @test unlabel(ladf) == 1:2 + + lafd = dual(flip(la)) + @test lafd isa LabelledUnitRange + @test label(lafd) == U1(-1) + @test unlabel(lafd) == 1:2 + + # check default behavior for objects without dual + la = labelled(1:2, 'x') + lad = dual(la) + @test lad isa LabelledUnitRangeDual + @test label(lad) == 'x' + @test blocklabels(lad) == ['x'] + @test unlabel(lad) == 1:2 + @test lad == 1:2 + @test labelled_isequal(lad, lad) + @test space_isequal(lad, lad) + @test labelled_isequal(la, lad) + @test !space_isequal(la, lad) + @test isdual(lad) + @test nondual(lad) === la + @test dual(lad) === la + + laf = flip(la) + @test laf isa LabelledUnitRangeDual + @test label(laf) == 'x' + @test unlabel(laf) == 1:2 + + ladf = flip(lad) + @test ladf isa LabelledUnitRange + @test label(ladf) == 'x' + @test unlabel(ladf) == 1:2 +end + @testset "GradedUnitRangeDual" begin for a in [gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]] @@ -124,13 +212,21 @@ end @test blockmergesortperm(a) == [Block(1), Block(2)] @test blockmergesortperm(ad) == [Block(1), Block(2)] - @test_broken isdual(ad[Block(1)]) - @test_broken isdual(ad[Block(1)[1:1]]) + @test isdual(ad[Block(1)]) + @test isdual(ad[Block(1)[1:1]]) + @test ad[Block(1)] isa LabelledUnitRangeDual + @test ad[Block(1)[1:1]] isa LabelledUnitRangeDual + @test label(ad[Block(2)]) == U1(-1) + @test label(ad[Block(2)[1:1]]) == U1(-1) + I = mortar([Block(2)[1:1]]) g = ad[I] @test length(g) == 1 @test label(first(g)) == U1(-1) - @test_broken isdual(g[Block(1)]) + @test isdual(g[Block(1)]) + + @test isdual(axes(ad[[Block(1)]], 1)) # used in view(::BlockSparseVector, [Block(1)]) + @test isdual(axes(ad[mortar([Block(1)[1:1]])], 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]]) end end diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl index 03965f62f5..4f432c9226 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl @@ -52,3 +52,12 @@ function Base.iterate(a::LabelledUnitRange, i) next = convert(eltype(a), labelled(i + step(a), label(a))) return (next, next) end + +function Base.show(io::IO, ::MIME"text/plain", a::LabelledUnitRange) + println(io, typeof(a)) + return print(io, label(a), " => ", unlabel(a)) +end + +function Base.show(io::IO, a::LabelledUnitRange) + return print(io, nameof(typeof(a)), " ", label(a), " => ", unlabel(a)) +end