From 13a31985856edc811c12e37ff298bcf90a8b0394 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 13 Nov 2024 13:32:31 -0500 Subject: [PATCH] Add cat implementation --- .../src/BlockSparseArrays.jl | 2 + .../src/abstractblocksparsearray/cat.jl | 7 ++ .../src/blocksparsearrayinterface/cat.jl | 26 ++++++++ .../src/SparseArrayInterface.jl | 2 + .../src/abstractsparsearray/cat.jl | 4 ++ .../src/sparsearrayinterface/cat.jl | 64 +++++++++++++++++++ .../src/sparsearrayinterface/indexing.jl | 2 + 7 files changed, 107 insertions(+) create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/cat.jl create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/cat.jl create mode 100644 NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/cat.jl create mode 100644 NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/cat.jl diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index d0a1e4cdd7..dc43ba560a 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -7,6 +7,7 @@ include("blocksparsearrayinterface/broadcast.jl") include("blocksparsearrayinterface/map.jl") include("blocksparsearrayinterface/arraylayouts.jl") include("blocksparsearrayinterface/views.jl") +include("blocksparsearrayinterface/cat.jl") include("abstractblocksparsearray/abstractblocksparsearray.jl") include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl") include("abstractblocksparsearray/abstractblocksparsematrix.jl") @@ -17,6 +18,7 @@ include("abstractblocksparsearray/sparsearrayinterface.jl") include("abstractblocksparsearray/broadcast.jl") include("abstractblocksparsearray/map.jl") include("abstractblocksparsearray/linearalgebra.jl") +include("abstractblocksparsearray/cat.jl") include("blocksparsearray/defaults.jl") include("blocksparsearray/blocksparsearray.jl") include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl") diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/cat.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/cat.jl new file mode 100644 index 0000000000..eac4ea1b02 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/cat.jl @@ -0,0 +1,7 @@ +# TODO: Change to `AnyAbstractBlockSparseArray`. +function Base.cat(as::BlockSparseArrayLike...; dims) + # TODO: Use `sparse_cat` instead, currently + # that erroneously allocates too many blocks that are + # zero and shouldn't be stored. + return blocksparse_cat(as...; dims) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/cat.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/cat.jl new file mode 100644 index 0000000000..22d1a24a02 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/cat.jl @@ -0,0 +1,26 @@ +using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths +using NDTensors.SparseArrayInterface: SparseArrayInterface, allocate_cat_output, sparse_cat! + +# TODO: Maybe move to `SparseArrayInterfaceBlockArraysExt`. +# TODO: Handle dual graded unit ranges, for example in a new `SparseArrayInterfaceGradedAxesExt`. +function SparseArrayInterface.axis_cat( + a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange +) + return blockedrange(vcat(blocklengths(a1), blocklengths(a2))) +end + +# that erroneously allocates too many blocks that are +# zero and shouldn't be stored. +function blocksparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims) + sparse_cat!(blocks(a_dest), blocks.(as)...; dims) + return a_dest +end + +# TODO: Delete this in favor of `sparse_cat`, currently +# that erroneously allocates too many blocks that are +# zero and shouldn't be stored. +function blocksparse_cat(as::AbstractArray...; dims) + a_dest = allocate_cat_output(as...; dims) + blocksparse_cat!(a_dest, as...; dims) + return a_dest +end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl b/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl index 33647bf476..f192225f27 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl @@ -12,6 +12,7 @@ include("sparsearrayinterface/broadcast.jl") include("sparsearrayinterface/conversion.jl") include("sparsearrayinterface/wrappers.jl") include("sparsearrayinterface/zero.jl") +include("sparsearrayinterface/cat.jl") include("sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl") include("abstractsparsearray/abstractsparsearray.jl") include("abstractsparsearray/abstractsparsematrix.jl") @@ -24,6 +25,7 @@ include("abstractsparsearray/broadcast.jl") include("abstractsparsearray/map.jl") include("abstractsparsearray/baseinterface.jl") include("abstractsparsearray/convert.jl") +include("abstractsparsearray/cat.jl") include("abstractsparsearray/SparseArrayInterfaceSparseArraysExt.jl") include("abstractsparsearray/SparseArrayInterfaceLinearAlgebraExt.jl") end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/cat.jl b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/cat.jl new file mode 100644 index 0000000000..a9db504e38 --- /dev/null +++ b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/cat.jl @@ -0,0 +1,4 @@ +# TODO: Change to `AnyAbstractSparseArray`. +function Base.cat(as::SparseArrayLike...; dims) + return sparse_cat(as...; dims) +end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/cat.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/cat.jl new file mode 100644 index 0000000000..9f2b3179a5 --- /dev/null +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/cat.jl @@ -0,0 +1,64 @@ +unval(x) = x +unval(::Val{x}) where {x} = x + +# TODO: Assert that `a1` and `a2` start at one. +axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2)) +function axis_cat( + a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... +) + return axis_cat(axis_cat(a1, a2), a_rest...) +end +function cat_axes(as::AbstractArray...; dims) + return ntuple(length(first(axes.(as)))) do dim + return if dim in unval(dims) + axis_cat(map(axes -> axes[dim], axes.(as))...) + else + axes(first(as))[dim] + end + end +end + +function allocate_cat_output(as::AbstractArray...; dims) + eltype_dest = promote_type(eltype.(as)...) + axes_dest = cat_axes(as...; dims) + # TODO: Promote the block types of the inputs rather than using + # just the first input. + # TODO: Make this customizable with `cat_similar`. + # TODO: Base the zero element constructor on those of the inputs, + # for example block sparse arrays. + return similar(first(as), eltype_dest, axes_dest...) +end + +# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857 +# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation +# This is very similar to the `Base.cat` implementation but handles zero values better. +function cat_offset!( + a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims +) + inds = ntuple(ndims(a_dest)) do dim + dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim) + end + a_dest[inds...] = a1 + new_offsets = ntuple(ndims(a_dest)) do dim + dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim] + end + cat_offset!(a_dest, new_offsets, a_rest...; dims) + return a_dest +end +function cat_offset!(a_dest::AbstractArray, offsets; dims) + return a_dest +end + +# TODO: Define a generic `cat!` function. +function sparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims) + offsets = ntuple(zero, ndims(a_dest)) + # TODO: Fill `a_dest` with zeros if needed. + cat_offset!(a_dest, offsets, as...; dims) + return a_dest +end + +function sparse_cat(as::AbstractArray...; dims) + a_dest = allocate_cat_output(as...; dims) + sparse_cat!(a_dest, as...; dims) + return a_dest +end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl index e7fe9462b2..1119fb93ef 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl @@ -144,6 +144,8 @@ function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex{1}) end # Slicing +# TODO: Make this handle more general slicing operations, +# base it off of `ArrayLayouts.sub_materialize`. function sparse_setindex!(a::AbstractArray, value, I::AbstractUnitRange...) inds = CartesianIndices(I) for i in stored_indices(value)