Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BlockSparseArrays] Direct sum/cat #1579

Merged
merged 5 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.63"
version = "0.3.64"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ include("blocksparsearrayinterface/broadcast.jl")
include("blocksparsearrayinterface/map.jl")
include("blocksparsearrayinterface/arraylayouts.jl")
include("blocksparsearrayinterface/views.jl")
include("blocksparsearrayinterface/cat.jl")
include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
Expand All @@ -17,6 +18,7 @@ include("abstractblocksparsearray/sparsearrayinterface.jl")
include("abstractblocksparsearray/broadcast.jl")
include("abstractblocksparsearray/map.jl")
include("abstractblocksparsearray/linearalgebra.jl")
include("abstractblocksparsearray/cat.jl")
include("blocksparsearray/defaults.jl")
include("blocksparsearray/blocksparsearray.jl")
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# TODO: Change to `AnyAbstractBlockSparseArray`.
function Base.cat(as::BlockSparseArrayLike...; dims)
# TODO: Use `sparse_cat` instead, currently
# that erroneously allocates too many blocks that are
# zero and shouldn't be stored.
return blocksparse_cat(as...; dims)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths
using NDTensors.SparseArrayInterface: SparseArrayInterface, allocate_cat_output, sparse_cat!

# TODO: Maybe move to `SparseArrayInterfaceBlockArraysExt`.
# TODO: Handle dual graded unit ranges, for example in a new `SparseArrayInterfaceGradedAxesExt`.
function SparseArrayInterface.axis_cat(
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange
)
return blockedrange(vcat(blocklengths(a1), blocklengths(a2)))
end

# that erroneously allocates too many blocks that are
# zero and shouldn't be stored.
function blocksparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
sparse_cat!(blocks(a_dest), blocks.(as)...; dims)
return a_dest
end

# TODO: Delete this in favor of `sparse_cat`, currently
# that erroneously allocates too many blocks that are
# zero and shouldn't be stored.
function blocksparse_cat(as::AbstractArray...; dims)
a_dest = allocate_cat_output(as...; dims)
blocksparse_cat!(a_dest, as...; dims)
return a_dest
end
27 changes: 27 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,33 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
@test a1' * a2 ≈ Array(a1)' * Array(a2)
@test dot(a1, a2) ≈ a1' * a2
end
@testset "cat" begin
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)]))))
a2 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a2[Block(1, 2)] = dev(randn(elt, size(@view(a2[Block(1, 2)]))))

a_dest = cat(a1, a2; dims=1)
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 2)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(3, 2)] == a2[Block(1, 2)]

a_dest = cat(a1, a2; dims=2)
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3], [2, 3, 2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(1, 4)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(1, 4)] == a2[Block(1, 2)]

a_dest = cat(a1, a2; dims=(1, 2))
@test block_nstored(a_dest) == 2
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3, 2, 3])
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 4)])
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
@test a_dest[Block(3, 4)] == a2[Block(1, 2)]
end
@testset "TensorAlgebra" begin
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ include("sparsearrayinterface/broadcast.jl")
include("sparsearrayinterface/conversion.jl")
include("sparsearrayinterface/wrappers.jl")
include("sparsearrayinterface/zero.jl")
include("sparsearrayinterface/cat.jl")
include("sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl")
include("abstractsparsearray/abstractsparsearray.jl")
include("abstractsparsearray/abstractsparsematrix.jl")
Expand All @@ -24,6 +25,7 @@ include("abstractsparsearray/broadcast.jl")
include("abstractsparsearray/map.jl")
include("abstractsparsearray/baseinterface.jl")
include("abstractsparsearray/convert.jl")
include("abstractsparsearray/cat.jl")
include("abstractsparsearray/SparseArrayInterfaceSparseArraysExt.jl")
include("abstractsparsearray/SparseArrayInterfaceLinearAlgebraExt.jl")
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# TODO: Change to `AnyAbstractSparseArray`.
function Base.cat(as::SparseArrayLike...; dims)
return sparse_cat(as...; dims)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
unval(x) = x
unval(::Val{x}) where {x} = x

# TODO: Assert that `a1` and `a2` start at one.
axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
function axis_cat(
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
)
return axis_cat(axis_cat(a1, a2), a_rest...)
end
function cat_axes(as::AbstractArray...; dims)
return ntuple(length(first(axes.(as)))) do dim
return if dim in unval(dims)
axis_cat(map(axes -> axes[dim], axes.(as))...)
else
axes(first(as))[dim]
end
end
end

function allocate_cat_output(as::AbstractArray...; dims)
eltype_dest = promote_type(eltype.(as)...)
axes_dest = cat_axes(as...; dims)
# TODO: Promote the block types of the inputs rather than using
# just the first input.
# TODO: Make this customizable with `cat_similar`.
# TODO: Base the zero element constructor on those of the inputs,
# for example block sparse arrays.
return similar(first(as), eltype_dest, axes_dest...)
end

# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation
# This is very similar to the `Base.cat` implementation but handles zero values better.
function cat_offset!(
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims
)
inds = ntuple(ndims(a_dest)) do dim
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim)
end
a_dest[inds...] = a1
new_offsets = ntuple(ndims(a_dest)) do dim
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim]
end
cat_offset!(a_dest, new_offsets, a_rest...; dims)
return a_dest
end
function cat_offset!(a_dest::AbstractArray, offsets; dims)
return a_dest
end

# TODO: Define a generic `cat!` function.
function sparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
offsets = ntuple(zero, ndims(a_dest))
# TODO: Fill `a_dest` with zeros if needed.
cat_offset!(a_dest, offsets, as...; dims)
return a_dest
end

function sparse_cat(as::AbstractArray...; dims)
a_dest = allocate_cat_output(as...; dims)
sparse_cat!(a_dest, as...; dims)
return a_dest
end
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,31 @@ function sparse_setindex!(a::AbstractArray, value, I::Vararg{Int})
return a
end

# Fix ambiguity error
function sparse_setindex!(a::AbstractArray, value)
sparse_setindex!(a, value, CartesianIndex())
return a
end

# Linear indexing
function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex{1})
sparse_setindex!(a, value, CartesianIndices(a)[I])
return a
end

# Slicing
# TODO: Make this handle more general slicing operations,
# base it off of `ArrayLayouts.sub_materialize`.
function sparse_setindex!(a::AbstractArray, value, I::AbstractUnitRange...)
inds = CartesianIndices(I)
for i in stored_indices(value)
if i in CartesianIndices(inds)
a[inds[i]] = value[i]
end
end
return a
end

# Handle trailing indices
function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex)
t = Tuple(I)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,38 @@ using Test: @test, @testset
@test a_dest isa SparseArray{elt}
@test SparseArrayInterface.nstored(a_dest) == 2

# cat
a1 = SparseArray{elt}(2, 3)
a1[1, 2] = 12
a1[2, 1] = 21
a2 = SparseArray{elt}(2, 3)
a2[1, 1] = 11
a2[2, 2] = 22

a_dest = cat(a1, a2; dims=1)
@test size(a_dest) == (4, 3)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[3, 1] == a2[1, 1]
@test a_dest[4, 2] == a2[2, 2]

a_dest = cat(a1, a2; dims=2)
@test size(a_dest) == (2, 6)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[1, 4] == a2[1, 1]
@test a_dest[2, 5] == a2[2, 2]

a_dest = cat(a1, a2; dims=(1, 2))
@test size(a_dest) == (4, 6)
@test SparseArrayInterface.nstored(a_dest) == 4
@test a_dest[1, 2] == a1[1, 2]
@test a_dest[2, 1] == a1[2, 1]
@test a_dest[3, 4] == a2[1, 1]
@test a_dest[4, 5] == a2[2, 2]

## # Sparse matrix of matrix multiplication
## TODO: Make this work, seems to require
## a custom zero constructor.
Expand Down
Loading