Skip to content

Commit

Permalink
reduce diff
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Oct 29, 2024
1 parent 5c5abd6 commit d7e3a8d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 10 deletions.
8 changes: 2 additions & 6 deletions NDTensors/src/lib/GradedAxes/src/fusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,13 @@ end
# Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`.
# Get the permutation for sorting, then group by common elements.
# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]
blockmergesort(g::AbstractUnitRange) = g
function blockmergesortperm(a::AbstractUnitRange)
return Block.(groupsortperm(blocklabels(a)))
return Block.(groupsortperm(blocklabels(nondual(a))))
end

# Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`.
invblockperm(a::Vector{<:Block{1}}) = Block.(invperm(Int.(a)))

function blockmergesortperm(a::GradedUnitRangeDual)
return Block.(groupsortperm(blocklabels(nondual(a))))
end

function blockmergesort(g::AbstractGradedUnitRange)
glabels = blocklabels(g)
gblocklengths = blocklengths(g)
Expand All @@ -101,6 +96,7 @@ function blockmergesort(g::AbstractGradedUnitRange)
end

blockmergesort(g::GradedUnitRangeDual) = flip(blockmergesort(flip(g)))
blockmergesort(g::AbstractUnitRange) = g

# fusion_product produces a sorted, non-dual GradedUnitRange
function fusion_product(g1, g2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@
using BlockArrays: Block, blocksize
using Compat: Returns
using NDTensors.BlockSparseArrays: BlockSparseArray
using NDTensors.GradedAxes: GradedAxes, gradedrange
using NDTensors.GradedAxes: gradedrange
using NDTensors.SparseArrayInterface: densearray
using NDTensors.SymmetrySectors: U1
using NDTensors.TensorAlgebra: contract
using Random: randn!
using Test: @test, @testset

#TODO remove once fuse_labels is defined in Sectors
GradedAxes.fuse_labels(m::U1, n::U1) = U1(m.n + n.n)

function randn_blockdiagonal(elt::Type, axes::Tuple)
a = BlockSparseArray{elt}(axes)
blockdiaglength = minimum(blocksize(a))
Expand Down

0 comments on commit d7e3a8d

Please sign in to comment.