Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 16, 2024
1 parent bd54ec3 commit e38dc08
Show file tree
Hide file tree
Showing 16 changed files with 191 additions and 180 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ julia> Pkg.add("BlockSparseArrays")

````julia
using BlockArrays: BlockArrays, BlockedVector, Block, blockedrange
using BlockSparseArrays: BlockSparseArray, block_storedlength
using BlockSparseArrays: BlockSparseArray, blockstoredlength
using Test: @test, @test_broken

function main()
Expand All @@ -62,13 +62,13 @@ function main()
]
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)

@test block_storedlength(b) == 2
@test blockstoredlength(b) == 2

# Blocks with discontiguous underlying data
d_blocks = randn.(nz_block_sizes)
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)

@test block_storedlength(b) == 2
@test blockstoredlength(b) == 2

# Access a block
@test b[Block(1, 1)] == d_blocks[1]
Expand All @@ -92,7 +92,7 @@ function main()
@test b + b Array(b) + Array(b)
@test b + b isa BlockSparseArray
# TODO: Fix this, broken.
@test_broken block_storedlength(b + b) == 2
@test_broken blockstoredlength(b + b) == 2

scaled_b = 2b
@test scaled_b 2Array(b)
Expand Down
8 changes: 4 additions & 4 deletions examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ julia> Pkg.add("BlockSparseArrays")
# ## Examples

using BlockArrays: BlockArrays, BlockedVector, Block, blockedrange
using BlockSparseArrays: BlockSparseArray, block_storedlength
using BlockSparseArrays: BlockSparseArray, blockstoredlength
using Test: @test, @test_broken

function main()
Expand All @@ -67,13 +67,13 @@ function main()
]
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)

@test block_storedlength(b) == 2
@test blockstoredlength(b) == 2

## Blocks with discontiguous underlying data
d_blocks = randn.(nz_block_sizes)
b = BlockSparseArray(nz_blocks, d_blocks, i_axes)

@test block_storedlength(b) == 2
@test blockstoredlength(b) == 2

## Access a block
@test b[Block(1, 1)] == d_blocks[1]
Expand All @@ -97,7 +97,7 @@ function main()
@test b + b Array(b) + Array(b)
@test b + b isa BlockSparseArray
## TODO: Fix this, broken.
@test_broken block_storedlength(b + b) == 2
@test_broken blockstoredlength(b + b) == 2

scaled_b = 2b
@test scaled_b 2Array(b)
Expand Down
28 changes: 14 additions & 14 deletions ext/BlockSparseArraysGradedUnitRangesExt/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using Test: @test, @testset
using BlockArrays:
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using BlockSparseArrays: BlockSparseArray, block_stored_length
using BlockSparseArrays: BlockSparseArray, blockstoredlength
using GradedUnitRanges:
GradedUnitRanges,
GradedOneTo,
Expand All @@ -13,7 +13,7 @@ using GradedUnitRanges:
gradedrange,
isdual
using LabelledNumbers: label
using SparseArraysBase: stored_length
using SparseArraysBase: storedlength
using SymmetrySectors: U1
using TensorAlgebra: fusedims, splitdims
using LinearAlgebra: adjoint
Expand All @@ -40,8 +40,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(b) == (4, 4, 4, 4)
@test blocksize(b) == (2, 2, 2, 2)
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
@test stored_length(b) == 32
@test block_stored_length(b) == 2
@test storedlength(b) == 32
@test blockstoredlength(b) == 2
for i in 1:ndims(a)
@test axes(b, i) isa GradedOneTo
end
Expand All @@ -58,8 +58,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(b) == (4, 4, 4, 4)
@test blocksize(b) == (2, 2, 2, 2)
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
@test stored_length(b) == 256
@test block_stored_length(b) == 16
@test storedlength(b) == 256
@test blockstoredlength(b) == 16
for i in 1:ndims(a)
@test axes(b, i) isa BlockedOneTo{Int}
end
Expand All @@ -71,8 +71,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
b = a[2:3, 2:3, 2:3, 2:3]
@test size(b) == (2, 2, 2, 2)
@test blocksize(b) == (2, 2, 2, 2)
@test stored_length(b) == 2
@test block_stored_length(b) == 2
@test storedlength(b) == 2
@test blockstoredlength(b) == 2
for i in 1:ndims(a)
@test axes(b, i) isa GradedOneTo
end
Expand Down Expand Up @@ -156,7 +156,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_stored_length(b) == 2
@test blockstoredlength(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedOneTo
Expand All @@ -177,7 +177,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_stored_length(b) == 2
@test blockstoredlength(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedUnitRange
Expand All @@ -204,7 +204,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_stored_length(b) == 2
@test blockstoredlength(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedUnitRangeDual
Expand All @@ -229,7 +229,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_stored_length(b) == 2
@test blockstoredlength(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedUnitRangeDual
Expand All @@ -255,7 +255,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_stored_length(b) == 2
@test blockstoredlength(b) == 2
@test Array(b) == 2 * Array(a)
@test a[:, :] isa BlockSparseArray
for i in 1:2
Expand All @@ -280,7 +280,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a[i] = randn(elt, size(a[i]))
end
b = 2 * a'
@test block_stored_length(b) == 2
@test blockstoredlength(b) == 2
@test Array(b) == 2 * Array(a)'
for ax in axes(b)
@test ax isa typeof(dual(r))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module BlockSparseArraysTensorAlgebraExt
using BlockArrays: AbstractBlockedUnitRange
using ..BlockSparseArrays: AbstractBlockSparseArray, block_reshape
using ..BlockSparseArrays: AbstractBlockSparseArray, blockreshape
using GradedUnitRanges: tensor_product
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion

Expand All @@ -13,12 +13,12 @@ TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()
function TensorAlgebra.fusedims(
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
)
return block_reshape(a, axes)
return blockreshape(a, axes)

Check warning on line 16 in ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl#L16

Added line #L16 was not covered by tests
end

function TensorAlgebra.splitdims(
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
)
return block_reshape(a, axes)
return blockreshape(a, axes)

Check warning on line 22 in ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl#L22

Added line #L22 was not covered by tests
end
end
20 changes: 8 additions & 12 deletions src/BlockArraysExtensions/BlockArraysExtensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,17 +264,17 @@ function blocks_to_cartesianindices(d::Dictionary{<:Block})
return Dictionary(blocks_to_cartesianindices(eachindex(d)), d)
end

function block_reshape(a::AbstractArray, dims::Tuple{Vararg{Vector{Int}}})
return block_reshape(a, blockedrange.(dims))
function blockreshape(a::AbstractArray, dims::Tuple{Vararg{Vector{Int}}})
return blockreshape(a, blockedrange.(dims))

Check warning on line 268 in src/BlockArraysExtensions/BlockArraysExtensions.jl

View check run for this annotation

Codecov / codecov/patch

src/BlockArraysExtensions/BlockArraysExtensions.jl#L267-L268

Added lines #L267 - L268 were not covered by tests
end

function block_reshape(a::AbstractArray, dims::Vararg{Vector{Int}})
return block_reshape(a, dims)
function blockreshape(a::AbstractArray, dims::Vararg{Vector{Int}})
return blockreshape(a, dims)

Check warning on line 272 in src/BlockArraysExtensions/BlockArraysExtensions.jl

View check run for this annotation

Codecov / codecov/patch

src/BlockArraysExtensions/BlockArraysExtensions.jl#L271-L272

Added lines #L271 - L272 were not covered by tests
end

tuple_oneto(n) = ntuple(identity, n)

function block_reshape(a::AbstractArray, axes::Tuple{Vararg{AbstractUnitRange}})
function blockreshape(a::AbstractArray, axes::Tuple{Vararg{AbstractUnitRange}})

Check warning on line 277 in src/BlockArraysExtensions/BlockArraysExtensions.jl

View check run for this annotation

Codecov / codecov/patch

src/BlockArraysExtensions/BlockArraysExtensions.jl#L277

Added line #L277 was not covered by tests
reshaped_blocks_a = reshape(blocks(a), blocklength.(axes))
reshaped_a = similar(a, axes)
for I in eachstoredindex(reshaped_blocks_a)

Check warning on line 280 in src/BlockArraysExtensions/BlockArraysExtensions.jl

View check run for this annotation

Codecov / codecov/patch

src/BlockArraysExtensions/BlockArraysExtensions.jl#L280

Added line #L280 was not covered by tests
Expand All @@ -285,8 +285,8 @@ function block_reshape(a::AbstractArray, axes::Tuple{Vararg{AbstractUnitRange}})
return reshaped_a
end

function block_reshape(a::AbstractArray, axes::Vararg{AbstractUnitRange})
return block_reshape(a, axes)
function blockreshape(a::AbstractArray, axes::Vararg{AbstractUnitRange})
return blockreshape(a, axes)

Check warning on line 289 in src/BlockArraysExtensions/BlockArraysExtensions.jl

View check run for this annotation

Codecov / codecov/patch

src/BlockArraysExtensions/BlockArraysExtensions.jl#L288-L289

Added lines #L288 - L289 were not covered by tests
end

function cartesianindices(axes::Tuple, b::Block)
Expand Down Expand Up @@ -473,10 +473,6 @@ function findblocks(axis::AbstractUnitRange, range::AbstractUnitRange)
return findblock(axis, first(range)):findblock(axis, last(range))
end

function block_eachstoredindex(a::AbstractArray)
return Block.(Tuple.(eachstoredindex(blocks(a))))
end

_block(indices) = block(indices)
_block(indices::CartesianIndices) = Block(ntuple(Returns(1), ndims(indices)))

Expand Down Expand Up @@ -550,7 +546,7 @@ function SparseArraysBase.storedlength(a::BlockView)
# TODO: Store whether or not the block is stored already as
# a Bool in `BlockView`.
I = CartesianIndex(Int.(a.block))
# TODO: Use `block_eachstoredindex`.
# TODO: Use `eachblockstoredindex`.
if I eachstoredindex(blocks(parent(a)))
return storedlength(blocks(parent(a))[I])

Check warning on line 551 in src/BlockArraysExtensions/BlockArraysExtensions.jl

View check run for this annotation

Codecov / codecov/patch

src/BlockArraysExtensions/BlockArraysExtensions.jl#L550-L551

Added lines #L550 - L551 were not covered by tests
end
Expand Down
1 change: 0 additions & 1 deletion src/abstractblocksparsearray/abstractblocksparsearray.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using BlockArrays:
BlockArrays, AbstractBlockArray, Block, BlockIndex, BlockedUnitRange, blocks
using SparseArraysBase: sparse_getindex, sparse_setindex!

# TODO: Delete this. This function was replaced
# by `stored_length` but is still used in `NDTensors`.
Expand Down
9 changes: 4 additions & 5 deletions src/abstractblocksparsearray/cat.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# TODO: Change to `AnyAbstractBlockSparseArray`.
using Derive: @interface, interface

# TODO: Define with `@derive`.
function Base.cat(as::AnyAbstractBlockSparseArray...; dims)
# TODO: Use `sparse_cat` instead, currently
# that erroneously allocates too many blocks that are
# zero and shouldn't be stored.
return blocksparse_cat(as...; dims)
return @interface interface(as...) cat(as...; dims)

Check warning on line 5 in src/abstractblocksparsearray/cat.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/cat.jl#L5

Added line #L5 was not covered by tests
end
15 changes: 4 additions & 11 deletions src/abstractblocksparsearray/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,13 @@ using ArrayLayouts: LayoutArray
using BlockArrays: blockisequal
using Derive: @interface, interface
using LinearAlgebra: Adjoint, Transpose
using SparseArraysBase:
SparseArraysBase,
SparseArrayStyle,
sparse_map!,
sparse_copy!,
sparse_copyto!,
sparse_permutedims!,
sparse_mapreduce,
sparse_iszero,
sparse_isreal
using SparseArraysBase: SparseArraysBase, SparseArrayStyle

# Returns `Vector{<:CartesianIndices}`
function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray})
combined_axes = combine_axes(axes.(as)...)
stored_blocked_cartesianindices_as = map(as) do a
return blocked_cartesianindices(axes(a), combined_axes, block_eachstoredindex(a))
return blocked_cartesianindices(axes(a), combined_axes, eachblockstoredindex(a))
end
return (stored_blocked_cartesianindices_as...)
end
Expand Down Expand Up @@ -102,11 +93,13 @@ end

# TODO: Move to `blocksparsearrayinterface/map.jl`.
@interface ::AbstractBlockSparseArrayInterface function Base.iszero(a::AbstractArray)

Check warning on line 95 in src/abstractblocksparsearray/map.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/map.jl#L95

Added line #L95 was not covered by tests
# TODO: Just call `iszero(blocks(a))`?
return @interface interface(blocks(a)) iszero(blocks(a))

Check warning on line 97 in src/abstractblocksparsearray/map.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/map.jl#L97

Added line #L97 was not covered by tests
end

# TODO: Move to `blocksparsearrayinterface/map.jl`.
@interface ::AbstractBlockSparseArrayInterface function Base.isreal(a::AbstractArray)

Check warning on line 101 in src/abstractblocksparsearray/map.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/map.jl#L101

Added line #L101 was not covered by tests
# TODO: Just call `isreal(blocks(a))`?
return @interface interface(blocks(a)) isreal(blocks(a))

Check warning on line 103 in src/abstractblocksparsearray/map.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/map.jl#L103

Added line #L103 was not covered by tests
end

Expand Down
7 changes: 5 additions & 2 deletions src/abstractblocksparsearray/sparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using BlockArrays: Block
using SparseArraysBase: SparseArraysBase, eachstoredindex, storedlength, storedvalues

# Structure storing the block sparse storage
# TODO: Delete this in favor of `storedvalues(blocks(a))`,
# and rename `storedblocks(a)` and/or `eachstoredblock(a)`.
struct BlockSparseStorage{Arr<:AbstractBlockSparseArray}
array::Arr
end
Expand Down Expand Up @@ -29,11 +31,12 @@ function Base.iterate(s::BlockSparseStorage, args...)
return iterate(values(s), args...)
end

## TODO: Delete this, define `getstoredindex`, etc.
## function SparseArraysBase.sparse_storage(a::AbstractBlockSparseArray)
## TODO: Bring back this deifinition but check that it makes sense.
## function SparseArraysBase.storedvaluese(a::AbstractBlockSparseArray)
## return BlockSparseStorage(a)
## end

# TODO: Turn this into an `@interface ::AbstractBlockSparseArrayInterface` function.
function SparseArraysBase.storedlength(a::AnyAbstractBlockSparseArray)
return sum(storedlength, storedvalues(blocks(a)); init=zero(Int))

Check warning on line 41 in src/abstractblocksparsearray/sparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/sparsearrayinterface.jl#L40-L41

Added lines #L40 - L41 were not covered by tests
end
4 changes: 2 additions & 2 deletions src/abstractblocksparsearray/views.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ function BlockArrays.viewblock(
a::AbstractBlockSparseArray{<:Any,N}, block::Vararg{Block{1},N}
) where {N}
I = CartesianIndex(Int.(block))
# TODO: Use `block_eachstoredindex`.
# TODO: Use `eachblockstoredindex`.
if I eachstoredindex(blocks(a))
return blocks(a)[I]
end
Expand Down Expand Up @@ -185,7 +185,7 @@ function BlockArrays.viewblock(
block::Vararg{Block{1},N},
) where {T,N}
I = CartesianIndex(Int.(block))
# TODO: Use `block_eachstoredindex`.
# TODO: Use `eachblockstoredindex`.
if I eachstoredindex(blocks(a))

Check warning on line 189 in src/abstractblocksparsearray/views.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/views.jl#L189

Added line #L189 was not covered by tests
return blocks(a)[I]
end
Expand Down
23 changes: 7 additions & 16 deletions src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Adapt: Adapt, WrappedArray
using ArrayLayouts: zero!
using BlockArrays:
BlockArrays,
AbstractBlockVector,
Expand All @@ -8,7 +9,7 @@ using BlockArrays:
blockedrange,
mortar,
unblock
using Derive: Derive, zero!
using Derive: Derive
using SplitApplyCombine: groupcount
using TypeParameterAccessors: similartype

Expand Down Expand Up @@ -142,24 +143,14 @@ function Base.setindex!(a::AnyAbstractBlockSparseArray{<:Any,1}, value, I::Block
return a
end

function Base.fill!(a::AbstractBlockSparseArray, value)
if iszero(value)
# This drops all of the blocks.
@interface interface(blocks(a)) zero!(blocks(a))
return a
end
@interface interface(blocks(a)) fill!(blocks(a), value)
return a
# TODO: Use `@derive`.
function ArrayLayouts.zero!(a::AnyAbstractBlockSparseArray)
return @interface interface(a) zero!(a)
end

# TODO: Use `@derive`.
function Base.fill!(a::AnyAbstractBlockSparseArray, value)
# TODO: Even if `iszero(value)`, this doesn't drop
# blocks from `a`, and additionally allocates
# new blocks filled with zeros, unlike
# `fill!(a::AbstractBlockSparseArray, value)`.
# Consider changing that behavior when possible.
@interface interface(blocks(a)) fill!(blocks(a), value)
return a
return @interface interface(a) fill!(a, value)

Check warning on line 153 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L153

Added line #L153 was not covered by tests
end

# Needed by `BlockArrays` matrix multiplication interface
Expand Down
Loading

0 comments on commit e38dc08

Please sign in to comment.