Skip to content

Commit

Permalink
Merge branch 'main' into blockedunitrange_nondual
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Nov 14, 2024
2 parents 229412c + 57994ff commit 60adbee
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 1 deletion.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.63"
version = "0.3.64"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ include("blocksparsearrayinterface/broadcast.jl")
include("blocksparsearrayinterface/map.jl")
include("blocksparsearrayinterface/arraylayouts.jl")
include("blocksparsearrayinterface/views.jl")
include("blocksparsearrayinterface/cat.jl")
include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
Expand All @@ -17,6 +18,7 @@ include("abstractblocksparsearray/sparsearrayinterface.jl")
include("abstractblocksparsearray/broadcast.jl")
include("abstractblocksparsearray/map.jl")
include("abstractblocksparsearray/linearalgebra.jl")
include("abstractblocksparsearray/cat.jl")
include("blocksparsearray/defaults.jl")
include("blocksparsearray/blocksparsearray.jl")
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# TODO: Change to `AnyAbstractBlockSparseArray`.
function Base.cat(as::BlockSparseArrayLike...; dims)
# TODO: Use `sparse_cat` instead, currently
# that erroneously allocates too many blocks that are
# zero and shouldn't be stored.
return blocksparse_cat(as...; dims)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths
using NDTensors.SparseArrayInterface: SparseArrayInterface, allocate_cat_output, sparse_cat!

# TODO: Maybe move to `SparseArrayInterfaceBlockArraysExt`.
# TODO: Handle dual graded unit ranges, for example in a new `SparseArrayInterfaceGradedAxesExt`.
function SparseArrayInterface.axis_cat(
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange
)
return blockedrange(vcat(blocklengths(a1), blocklengths(a2)))
end

# that erroneously allocates too many blocks that are
# zero and shouldn't be stored.
function blocksparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
sparse_cat!(blocks(a_dest), blocks.(as)...; dims)
return a_dest
end

# TODO: Delete this in favor of `sparse_cat`, currently
# that erroneously allocates too many blocks that are
# zero and shouldn't be stored.
function blocksparse_cat(as::AbstractArray...; dims)
a_dest = allocate_cat_output(as...; dims)
blocksparse_cat!(a_dest, as...; dims)
return a_dest
end
27 changes: 27 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,33 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
@test a1' * a2 Array(a1)' * Array(a2)
@test dot(a1, a2) a1' * a2
end
@testset "cat" begin
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)]))))
a2 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a2[Block(1, 2)] = dev(randn(elt, size(@view(a2[Block(1, 2)]))))

a_dest = cat(a1, a2; dims=1)
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 2)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(3, 2)] == a2[Block(1, 2)]

a_dest = cat(a1, a2; dims=2)
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3], [2, 3, 2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(1, 4)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(1, 4)] == a2[Block(1, 2)]

a_dest = cat(a1, a2; dims=(1, 2))
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3, 2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 4)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(3, 4)] == a2[Block(1, 2)]
end
@testset "TensorAlgebra" begin
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ include("sparsearrayinterface/broadcast.jl")
include("sparsearrayinterface/conversion.jl")
include("sparsearrayinterface/wrappers.jl")
include("sparsearrayinterface/zero.jl")
include("sparsearrayinterface/cat.jl")
include("sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl")
include("abstractsparsearray/abstractsparsearray.jl")
include("abstractsparsearray/abstractsparsematrix.jl")
Expand All @@ -24,6 +25,7 @@ include("abstractsparsearray/broadcast.jl")
include("abstractsparsearray/map.jl")
include("abstractsparsearray/baseinterface.jl")
include("abstractsparsearray/convert.jl")
include("abstractsparsearray/cat.jl")
include("abstractsparsearray/SparseArrayInterfaceSparseArraysExt.jl")
include("abstractsparsearray/SparseArrayInterfaceLinearAlgebraExt.jl")
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# TODO: Change to `AnyAbstractSparseArray`.
function Base.cat(as::SparseArrayLike...; dims)
return sparse_cat(as...; dims)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
unval(x) = x
unval(::Val{x}) where {x} = x

# TODO: Assert that `a1` and `a2` start at one.
axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
function axis_cat(
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
)
return axis_cat(axis_cat(a1, a2), a_rest...)
end
function cat_axes(as::AbstractArray...; dims)
return ntuple(length(first(axes.(as)))) do dim
return if dim in unval(dims)
axis_cat(map(axes -> axes[dim], axes.(as))...)
else
axes(first(as))[dim]
end
end
end

function allocate_cat_output(as::AbstractArray...; dims)
eltype_dest = promote_type(eltype.(as)...)
axes_dest = cat_axes(as...; dims)
# TODO: Promote the block types of the inputs rather than using
# just the first input.
# TODO: Make this customizable with `cat_similar`.
# TODO: Base the zero element constructor on those of the inputs,
# for example block sparse arrays.
return similar(first(as), eltype_dest, axes_dest...)
end

# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation
# This is very similar to the `Base.cat` implementation but handles zero values better.
function cat_offset!(
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims
)
inds = ntuple(ndims(a_dest)) do dim
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim)
end
a_dest[inds...] = a1
new_offsets = ntuple(ndims(a_dest)) do dim
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim]
end
cat_offset!(a_dest, new_offsets, a_rest...; dims)
return a_dest
end
function cat_offset!(a_dest::AbstractArray, offsets; dims)
return a_dest
end

# TODO: Define a generic `cat!` function.
function sparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
offsets = ntuple(zero, ndims(a_dest))
# TODO: Fill `a_dest` with zeros if needed.
cat_offset!(a_dest, offsets, as...; dims)
return a_dest
end

function sparse_cat(as::AbstractArray...; dims)
a_dest = allocate_cat_output(as...; dims)
sparse_cat!(a_dest, as...; dims)
return a_dest
end
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,31 @@ function sparse_setindex!(a::AbstractArray, value, I::Vararg{Int})
return a
end

# Fix ambiguity error
function sparse_setindex!(a::AbstractArray, value)
sparse_setindex!(a, value, CartesianIndex())
return a
end

# Linear indexing
function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex{1})
sparse_setindex!(a, value, CartesianIndices(a)[I])
return a
end

# Slicing
# TODO: Make this handle more general slicing operations,
# base it off of `ArrayLayouts.sub_materialize`.
function sparse_setindex!(a::AbstractArray, value, I::AbstractUnitRange...)
inds = CartesianIndices(I)
for i in stored_indices(value)
if i in CartesianIndices(inds)
a[inds[i]] = value[i]
end
end
return a
end

# Handle trailing indices
function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex)
t = Tuple(I)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,38 @@ using Test: @test, @testset
@test a_dest isa SparseArray{elt}
@test SparseArrayInterface.nstored(a_dest) == 2

# cat
a1 = SparseArray{elt}(2, 3)
a1[1, 2] = 12
a1[2, 1] = 21
a2 = SparseArray{elt}(2, 3)
a2[1, 1] = 11
a2[2, 2] = 22

a_dest = cat(a1, a2; dims=1)
@test size(a_dest) == (4, 3)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[3, 1] == a2[1, 1]
@test a_dest[4, 2] == a2[2, 2]

a_dest = cat(a1, a2; dims=2)
@test size(a_dest) == (2, 6)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[1, 4] == a2[1, 1]
@test a_dest[2, 5] == a2[2, 2]

a_dest = cat(a1, a2; dims=(1, 2))
@test size(a_dest) == (4, 6)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[3, 4] == a2[1, 1]
@test a_dest[4, 5] == a2[2, 2]

## # Sparse matrix of matrix multiplication
## TODO: Make this work, seems to require
## a custom zero constructor.
Expand Down

0 comments on commit 60adbee

Please sign in to comment.