Skip to content

Commit

Permalink
More tests, simplify wrapper code
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 16, 2024
1 parent 6d2e03c commit 98af42a
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 87 deletions.
29 changes: 25 additions & 4 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ function storedlength end
function storedpairs end
function storedvalues end

# Generic functionality for converting to a
# dense array, trying to preserve information
# about the array (such as which device it is on).
# TODO: Maybe call `densecopy`?
# TODO: Make sure this actually preserves the device,
# maybe use `TypeParameterAccessors.unwrap_array_type`.
# TODO: Turn into an `@interface` function.
function densearray(a::AbstractArray)

Check warning on line 24 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L24

Added line #L24 was not covered by tests
# TODO: `set_ndims(unwrap_array_type(a), ndims(a))(a)`
# Maybe define `densetype(a) = set_ndims(unwrap_array_type(a), ndims(a))`.
# Or could use `unspecify_parameters(unwrap_array_type(a))(a)`.
return Array(a)

Check warning on line 28 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L28

Added line #L28 was not covered by tests
end

# Minimal interface for `SparseArrayInterface`.
# Fallbacks for dense/non-sparse arrays.
@interface ::AbstractArrayInterface isstored(a::AbstractArray, I::Int...) = true
Expand All @@ -32,8 +46,8 @@ end
@interface ::AbstractArrayInterface function setunstoredindex!(

Check warning on line 46 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L46

Added line #L46 was not covered by tests
a::AbstractArray, value, I::Int...
)
setindex!(a, value, I...)
return a
# TODO: Make this a `MethodError`?
return error("Not implemented.")

Check warning on line 50 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L50

Added line #L50 was not covered by tests
end

# TODO: Use `Base.to_indices`?
Expand Down Expand Up @@ -116,10 +130,17 @@ end

@interface ::AbstractSparseArrayInterface storedvalues(a::AbstractArray) = StoredValues(a)

Check warning on line 131 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L131

Added line #L131 was not covered by tests

@interface ::AbstractSparseArrayInterface function eachstoredindex(as::AbstractArray...)
@interface ::AbstractSparseArrayInterface function eachstoredindex(
a1::AbstractArray, a2::AbstractArray, a_rest::AbstractArray...
)
# TODO: Make this more customizable, say with a function
# `combine/promote_storedindices(a1, a2)`.
return union(eachstoredindex.(as)...)
return union(eachstoredindex.((a1, a2, a_rest...))...)
end

@interface ::AbstractSparseArrayInterface function eachstoredindex(a::AbstractArray)

Check warning on line 141 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L141

Added line #L141 was not covered by tests
# TODO: Use `MethodError`?
return error("Not implemented.")

Check warning on line 143 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L143

Added line #L143 was not covered by tests
end

# We restrict to `I::Vararg{Int,N}` to allow more general functions to handle trailing
Expand Down
177 changes: 96 additions & 81 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
parentvalue_to_value(a::AbstractArray, value) = value
value_to_parentvalue(a::AbstractArray, value) = value

Check warning on line 2 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L2

Added line #L2 was not covered by tests
eachstoredparentindex(a::AbstractArray) = eachstoredindex(parent(a))
storedparentvalues(a::AbstractArray) = storedvalues(parent(a))
parentindex_to_index(a::AbstractArray, I::CartesianIndex) = error()
function parentindex_to_index(a::AbstractArray, I::Int...)
return Tuple(parentindex_to_index(a, CartesianIndex(I)))

Check warning on line 7 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L4-L7

Added lines #L4 - L7 were not covered by tests
end
index_to_parentindex(a::AbstractArray, I::CartesianIndex) = error()

Check warning on line 9 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L9

Added line #L9 was not covered by tests
function index_to_parentindex(a::AbstractArray, I::Int...)
return Tuple(index_to_parentindex(a, CartesianIndex(I)))
end

function cartesianindex_reverse(I::CartesianIndex)
return CartesianIndex(reverse(Tuple(I)))
end
Expand All @@ -7,115 +20,117 @@ tuple_oneto(n) = ntuple(identity, n)
# https://github.com/jipolanco/StaticPermutations.jl?
genperm(v, perm) = map(j -> v[j], perm)

## TODO: Use this and something similar for `Dictionary` to make a faster
## implementation of `storedvalues(::SubArray)`.
## function valuesview(d::Dict, keys)
## return @view d.vals[[Base.ht_keyindex(d, key) for key in keys]]
## end
using LinearAlgebra: Adjoint
function parentindex_to_index(a::Adjoint, I::CartesianIndex)
return cartesianindex_reverse(I)
end
function index_to_parentindex(a::Adjoint, I::CartesianIndex)
return cartesianindex_reverse(I)
end
function parentvalue_to_value(a::Adjoint, value)
return adjoint(value)
end
function value_to_parentvalue(a::Adjoint, value)
return adjoint(value)

Check warning on line 34 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L33-L34

Added lines #L33 - L34 were not covered by tests
end

perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip
function index_to_parentindex(a::PermutedDimsArray, I::CartesianIndex)
return CartesianIndex(genperm(I, iperm(a)))
end
function parentindex_to_index(a::PermutedDimsArray, I::CartesianIndex)
return CartesianIndex(genperm(I, perm(a)))
end

using Base: ReshapedArray
function parentindex_to_index(a::ReshapedArray, I::CartesianIndex)
return CartesianIndices(size(a))[LinearIndices(parent(a))[I]]

Check warning on line 48 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L47-L48

Added lines #L47 - L48 were not covered by tests
end
function index_to_parentindex(a::ReshapedArray, I::CartesianIndex)
return CartesianIndices(parent(a))[LinearIndices(size(a))[I]]

Check warning on line 51 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L50-L51

Added lines #L50 - L51 were not covered by tests
end

function eachstoredparentindex(a::SubArray)
return filter(eachstoredindex(parent(a))) do I
return all(d -> I[d] parentindices(a)[d], 1:ndims(parent(a)))
end
end
@interface ::AbstractSparseArrayInterface function storedvalues(a::SubArray)
function index_to_parentindex(a::SubArray, I::CartesianIndex)
return CartesianIndex(Base.reindex(parentindices(a), Tuple(I)))
end
function parentindex_to_index(a::SubArray, I::CartesianIndex)
nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d
return !(parentindices(a)[d] isa Real)
end
return CartesianIndex(
map(nonscalardims) do d
return findfirst(==(I[d]), parentindices(a)[d])
end,
)
end
## TODO: Use this and something similar for `Dictionary` to make a faster
## implementation of `storedvalues(::SubArray)`.
## function valuesview(d::Dict, keys)
## return @view d.vals[[Base.ht_keyindex(d, key) for key in keys]]
## end
function storedparentvalues(a::SubArray)

Check warning on line 77 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L77

Added line #L77 was not covered by tests
# We use `StoredValues` rather than `@view`/`SubArray` so that
# it gets interpreted as a dense array.
return StoredValues(parent(a), collect(eachstoredparentindex(a)))

Check warning on line 80 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L80

Added line #L80 was not covered by tests
end
@interface ::AbstractSparseArrayInterface function isstored(a::SubArray, I::Int...)
return isstored(parent(a), Base.reindex(parentindices(a), I)...)

using LinearAlgebra: Transpose
function parentindex_to_index(a::Transpose, I::CartesianIndex)
return cartesianindex_reverse(I)

Check warning on line 85 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L84-L85

Added lines #L84 - L85 were not covered by tests
end
@interface ::AbstractSparseArrayInterface function getstoredindex(a::SubArray, I::Int...)
return getstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
function index_to_parentindex(a::Transpose, I::CartesianIndex)
return cartesianindex_reverse(I)

Check warning on line 88 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L87-L88

Added lines #L87 - L88 were not covered by tests
end
@interface ::AbstractSparseArrayInterface function getunstoredindex(a::SubArray, I::Int...)
return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
function parentvalue_to_value(a::Transpose, value)
return transpose(value)

Check warning on line 91 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L90-L91

Added lines #L90 - L91 were not covered by tests
end
@interface ::AbstractSparseArrayInterface function eachstoredindex(a::SubArray)
nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d
return !(parentindices(a)[d] isa Real)
end
return collect((
CartesianIndex(
map(nonscalardims) do d
return findfirst(==(I[d]), parentindices(a)[d])
end,
) for I in eachstoredparentindex(a)
))
end

perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip

@interface ::AbstractSparseArrayInterface storedvalues(a::PermutedDimsArray) =
storedvalues(parent(a))
@interface ::AbstractSparseArrayInterface function isstored(a::PermutedDimsArray, I::Int...)
return isstored(parent(a), genperm(I, iperm(a))...)
end
@interface ::AbstractSparseArrayInterface function getstoredindex(
a::PermutedDimsArray, I::Int...
)
return getstoredindex(parent(a), genperm(I, iperm(a))...)
end
@interface ::AbstractSparseArrayInterface function getunstoredindex(
a::PermutedDimsArray, I::Int...
)
return getunstoredindex(parent(a), genperm(I, iperm(a))...)
end
@interface ::AbstractSparseArrayInterface function setstoredindex!(
a::PermutedDimsArray, value, I::Int...
)
# TODO: Should this be `iperm(a)`?
setstoredindex!(parent(a), value, genperm(I, perm(a))...)
return a
end
@interface ::AbstractSparseArrayInterface function setunstoredindex!(
a::PermutedDimsArray, value, I::Int...
)
# TODO: Should this be `iperm(a)`?
setunstoredindex!(parent(a), value, genperm(I, perm(a))...)
return a
end
@interface ::AbstractSparseArrayInterface function eachstoredindex(a::PermutedDimsArray)
# TODO: Make lazy with `Iterators.map`.
return map(collect(eachstoredindex(parent(a)))) do I
return CartesianIndex(genperm(I, perm(a)))
end
function value_to_parentvalue(a::Transpose, value)
return transpose(value)

Check warning on line 94 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L93-L94

Added lines #L93 - L94 were not covered by tests
end

for (type, func) in ((:Adjoint, :adjoint), (:Transpose, :transpose))
# TODO: Turn these into `AbstractWrappedSparseArrayInterface` functions?
for type in (:Adjoint, :PermutedDimsArray, :ReshapedArray, :SubArray, :Transpose)
@eval begin
using LinearAlgebra: $type
@interface ::AbstractSparseArrayInterface storedvalues(a::$type) =
storedvalues(parent(a))
@interface ::AbstractSparseArrayInterface function isstored(a::$type, i::Int, j::Int)
return isstored(parent(a), j, i)
@interface ::AbstractSparseArrayInterface storedvalues(a::$type) = storedparentvalues(a)

Check warning on line 100 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L100

Added line #L100 was not covered by tests
@interface ::AbstractSparseArrayInterface function isstored(a::$type, I::Int...)
return isstored(parent(a), index_to_parentindex(a, I...)...)
end
@interface ::AbstractSparseArrayInterface function eachstoredindex(a::$type)
# TODO: Make lazy with `Iterators.map`.
return map(cartesianindex_reverse, collect(eachstoredindex(parent(a))))
return map(collect(eachstoredparentindex(a))) do I
return parentindex_to_index(a, I)
end
end
@interface ::AbstractSparseArrayInterface function getstoredindex(
a::$type, i::Int, j::Int
)
return $func(getstoredindex(parent(a), j, i))
@interface ::AbstractSparseArrayInterface function getstoredindex(a::$type, I::Int...)
return parentvalue_to_value(
a, getstoredindex(parent(a), index_to_parentindex(a, I...)...)
)
end
@interface ::AbstractSparseArrayInterface function getunstoredindex(
a::$type, i::Int, j::Int
)
return $func(getunstoredindex(parent(a), j, i))
@interface ::AbstractSparseArrayInterface function getunstoredindex(a::$type, I::Int...)
return parentvalue_to_value(
a, getunstoredindex(parent(a), index_to_parentindex(a, I...)...)
)
end
@interface ::AbstractSparseArrayInterface function setstoredindex!(

Check warning on line 120 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L120

Added line #L120 was not covered by tests
a::$type, value, i::Int, j::Int
a::$type, value, I::Int...
)
setstoredindex!(parent(a), $func(value), j, i)
setstoredindex!(

Check warning on line 123 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L123

Added line #L123 was not covered by tests
parent(a), value_to_parentvalue(a, value), index_to_parentindex(a, I...)...
)
return a

Check warning on line 126 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L126

Added line #L126 was not covered by tests
end
@interface ::AbstractSparseArrayInterface function setunstoredindex!(

Check warning on line 128 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L128

Added line #L128 was not covered by tests
a::$type, value, i::Int, j::Int
a::$type, value, I::Int...
)
setunstoredindex!(parent(a), $func(value), j, i)
setunstoredindex!(

Check warning on line 131 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L131

Added line #L131 was not covered by tests
parent(a), value_to_parentvalue(a, value), index_to_parentindex(a, I...)...
)
return a

Check warning on line 134 in src/wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/wrappers.jl#L134

Added line #L134 was not covered by tests
end
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand Down
3 changes: 1 addition & 2 deletions test/basics/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ arrayts = (Array, JLArray)
for I in ((1, 2), (CartesianIndex(1, 2),))
b = copy(a)
value = randn(elt)
@allowscalar setunstoredindex!(b, value, I...)
@allowscalar b[I...] == value
@test_throws ErrorException setunstoredindex!(b, value, I...)
end
end
Loading

0 comments on commit 98af42a

Please sign in to comment.