-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into blockedunitrange_nondual
- Loading branch information
Showing
10 changed files
with
184 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "NDTensors" | ||
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" | ||
authors = ["Matthew Fishman <[email protected]>"] | ||
version = "0.3.63" | ||
version = "0.3.64" | ||
|
||
[deps] | ||
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
7 changes: 7 additions & 0 deletions
7
NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/cat.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# TODO: Change to `AnyAbstractBlockSparseArray`. | ||
function Base.cat(as::BlockSparseArrayLike...; dims) | ||
# TODO: Use `sparse_cat` instead, currently | ||
# that erroneously allocates too many blocks that are | ||
# zero and shouldn't be stored. | ||
return blocksparse_cat(as...; dims) | ||
end |
26 changes: 26 additions & 0 deletions
26
NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/cat.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths | ||
using NDTensors.SparseArrayInterface: SparseArrayInterface, allocate_cat_output, sparse_cat! | ||
|
||
# TODO: Maybe move to `SparseArrayInterfaceBlockArraysExt`. | ||
# TODO: Handle dual graded unit ranges, for example in a new `SparseArrayInterfaceGradedAxesExt`. | ||
function SparseArrayInterface.axis_cat( | ||
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange | ||
) | ||
return blockedrange(vcat(blocklengths(a1), blocklengths(a2))) | ||
end | ||
|
||
# that erroneously allocates too many blocks that are | ||
# zero and shouldn't be stored. | ||
function blocksparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims) | ||
sparse_cat!(blocks(a_dest), blocks.(as)...; dims) | ||
return a_dest | ||
end | ||
|
||
# TODO: Delete this in favor of `sparse_cat`, currently | ||
# that erroneously allocates too many blocks that are | ||
# zero and shouldn't be stored. | ||
function blocksparse_cat(as::AbstractArray...; dims) | ||
a_dest = allocate_cat_output(as...; dims) | ||
blocksparse_cat!(a_dest, as...; dims) | ||
return a_dest | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 4 additions & 0 deletions
4
NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/cat.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# TODO: Change to `AnyAbstractSparseArray`. | ||
function Base.cat(as::SparseArrayLike...; dims) | ||
return sparse_cat(as...; dims) | ||
end |
64 changes: 64 additions & 0 deletions
64
NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/cat.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
unval(x) = x | ||
unval(::Val{x}) where {x} = x | ||
|
||
# TODO: Assert that `a1` and `a2` start at one. | ||
axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2)) | ||
function axis_cat( | ||
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... | ||
) | ||
return axis_cat(axis_cat(a1, a2), a_rest...) | ||
end | ||
function cat_axes(as::AbstractArray...; dims) | ||
return ntuple(length(first(axes.(as)))) do dim | ||
return if dim in unval(dims) | ||
axis_cat(map(axes -> axes[dim], axes.(as))...) | ||
else | ||
axes(first(as))[dim] | ||
end | ||
end | ||
end | ||
|
||
function allocate_cat_output(as::AbstractArray...; dims) | ||
eltype_dest = promote_type(eltype.(as)...) | ||
axes_dest = cat_axes(as...; dims) | ||
# TODO: Promote the block types of the inputs rather than using | ||
# just the first input. | ||
# TODO: Make this customizable with `cat_similar`. | ||
# TODO: Base the zero element constructor on those of the inputs, | ||
# for example block sparse arrays. | ||
return similar(first(as), eltype_dest, axes_dest...) | ||
end | ||
|
||
# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857 | ||
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation | ||
# This is very similar to the `Base.cat` implementation but handles zero values better. | ||
function cat_offset!( | ||
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims | ||
) | ||
inds = ntuple(ndims(a_dest)) do dim | ||
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim) | ||
end | ||
a_dest[inds...] = a1 | ||
new_offsets = ntuple(ndims(a_dest)) do dim | ||
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim] | ||
end | ||
cat_offset!(a_dest, new_offsets, a_rest...; dims) | ||
return a_dest | ||
end | ||
function cat_offset!(a_dest::AbstractArray, offsets; dims) | ||
return a_dest | ||
end | ||
|
||
# TODO: Define a generic `cat!` function. | ||
function sparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims) | ||
offsets = ntuple(zero, ndims(a_dest)) | ||
# TODO: Fill `a_dest` with zeros if needed. | ||
cat_offset!(a_dest, offsets, as...; dims) | ||
return a_dest | ||
end | ||
|
||
function sparse_cat(as::AbstractArray...; dims) | ||
a_dest = allocate_cat_output(as...; dims) | ||
sparse_cat!(a_dest, as...; dims) | ||
return a_dest | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters