Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] Add DiagonalArrays submodule #1225

Merged
merged 24 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ function BlockArrays.viewblock(block_arr::BlockSparseArray, block)
# TODO: Make this `Zeros`?
## zero = zeros(eltype(block_arr), block_size)
return block_arr.blocks[blks...] # Fails because zero isn't defined
## return get_nonzero(block_arr.blocks, blks, zero)
end

function Base.getindex(block_arr::BlockSparseArray{T,N}, bi::BlockIndex{N}) where {T,N}
Expand Down
13 changes: 2 additions & 11 deletions NDTensors/src/BlockSparseArrays/src/sparsearray.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# TODO: Define a constructor with a default `zero`.
struct SparseArray{T,N,Zero} <: AbstractArray{T,N}
data::Dictionary{CartesianIndex{N},T}
dims::NTuple{N,Int64}
dims::NTuple{N,Int}
zero::Zero
end

Expand All @@ -20,13 +21,3 @@ end
function Base.getindex(a::SparseArray{T,N}, I::Vararg{Int,N}) where {T,N}
return getindex(a, CartesianIndex(I))
end

## # `getindex` but uses a default if the value is
## # structurally zero.
## function get_nonzero(a::SparseArray{T,N}, I::CartesianIndex{N}, zero) where {T,N}
## @boundscheck checkbounds(a, I)
## return get(a.data, I, zero)
## end
## function get_nonzero(a::SparseArray{T,N}, I::NTuple{N,Int}, zero) where {T,N}
## return get_nonzero(a, CartesianIndex(I), zero)
## end
49 changes: 49 additions & 0 deletions NDTensors/src/DiagonalArrays/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# DiagonalArrays.jl

A Julia `DiagonalArray` type.

````julia
using NDTensors.DiagonalArrays:
DiagonalArray,
densearray,
diagview,
diaglength,
getdiagindex,
setdiagindex!,
setdiag!,
diagcopyto!

d = DiagonalArray([1., 2, 3], 3, 4, 5)
@show d[1, 1, 1] == 1
@show d[2, 2, 2] == 2
@show d[1, 2, 1] == 0

d[2, 2, 2] = 22
@show d[2, 2, 2] == 22

@show diaglength(d) == 3
@show densearray(d) == d
@show getdiagindex(d, 2) == d[2, 2, 2]

setdiagindex!(d, 222, 2)
@show d[2, 2, 2] == 222

a = randn(3, 4, 5)
new_diag = randn(3)
setdiag!(a, new_diag)
diagcopyto!(d, a)

@show diagview(a) == new_diag
@show diagview(d) == new_diag
````

You can generate this README with:
```julia
using Literate
Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor())
```

---

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*

42 changes: 42 additions & 0 deletions NDTensors/src/DiagonalArrays/examples/README.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# # DiagonalArrays.jl
#
# A Julia `DiagonalArray` type.

using NDTensors.DiagonalArrays:
DiagonalArray,
densearray,
diagview,
diaglength,
getdiagindex,
setdiagindex!,
setdiag!,
diagcopyto!

d = DiagonalArray([1.0, 2, 3], 3, 4, 5)
@show d[1, 1, 1] == 1
@show d[2, 2, 2] == 2
@show d[1, 2, 1] == 0

d[2, 2, 2] = 22
@show d[2, 2, 2] == 22

@show diaglength(d) == 3
@show densearray(d) == d
@show getdiagindex(d, 2) == d[2, 2, 2]

setdiagindex!(d, 222, 2)
@show d[2, 2, 2] == 222

a = randn(3, 4, 5)
new_diag = randn(3)
setdiag!(a, new_diag)
diagcopyto!(d, a)

@show diagview(a) == new_diag
@show diagview(d) == new_diag

# You can generate this README with:
# ```julia
# using Literate
# Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor())
# ```
110 changes: 110 additions & 0 deletions NDTensors/src/DiagonalArrays/src/DiagonalArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
module DiagonalArrays

using Compat # allequal
using LinearAlgebra

export DiagonalArray

include("diagview.jl")

struct DefaultZero end

function (::DefaultZero)(eltype::Type, I::CartesianIndex)
return zero(eltype)
end

struct DiagonalArray{T,N,Diag<:AbstractVector{T},Zero} <: AbstractArray{T,N}
diag::Diag
dims::NTuple{N,Int}
zero::Zero
end

function DiagonalArray{T,N}(
diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}, zero=DefaultZero()
) where {T,N}
return DiagonalArray{T,N,typeof(diag),typeof(zero)}(diag, d, zero)
end

function DiagonalArray{T,N}(
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, zero=DefaultZero()
) where {T,N}
return DiagonalArray{T,N}(T.(diag), d, zero)
end

function DiagonalArray{T,N}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(diag, d)
end

function DiagonalArray{T}(
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, zero=DefaultZero()
) where {T,N}
return DiagonalArray{T,N}(diag, d, zero)
end

function DiagonalArray{T}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(diag, d)
end

function DiagonalArray(diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}) where {T,N}
return DiagonalArray{T,N}(diag, d)
end

function DiagonalArray(diag::AbstractVector{T}, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(diag, d)
end

# undef
function DiagonalArray{T,N}(::UndefInitializer, d::Tuple{Vararg{Int,N}}) where {T,N}
return DiagonalArray{T,N}(Vector{T}(undef, minimum(d)), d)
end

function DiagonalArray{T,N}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(undef, d)
end

function DiagonalArray{T}(::UndefInitializer, d::Tuple{Vararg{Int,N}}) where {T,N}
return DiagonalArray{T,N}(undef, d)
end

function DiagonalArray{T}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
return DiagonalArray{T,N}(undef, d)
end

Base.size(a::DiagonalArray) = a.dims

diagview(a::DiagonalArray) = a.diag
LinearAlgebra.diag(a::DiagonalArray) = copy(diagview(a))

function Base.getindex(a::DiagonalArray{T,N}, I::CartesianIndex{N}) where {T,N}
i = diagindex(a, I)
isnothing(i) && return a.zero(T, I)
return getdiagindex(a, i)
end

function Base.getindex(a::DiagonalArray{T,N}, I::Vararg{Int,N}) where {T,N}
return getindex(a, CartesianIndex(I))
end

function Base.setindex!(a::DiagonalArray{T,N}, v, I::CartesianIndex{N}) where {T,N}
i = diagindex(a, I)
isnothing(i) && return error("Can't set off-diagonal element of DiagonalArray")
setdiagindex!(a, v, i)
return a
end

function Base.setindex!(a::DiagonalArray{T,N}, v, I::Vararg{Int,N}) where {T,N}
a[CartesianIndex(I)] = v
return a
end

# Make dense.
function densearray(a::DiagonalArray)
# TODO: Check this works on GPU.
# TODO: Make use of `a.zero`?
d = similar(diagview(a), size(a))
fill!(d, zero(eltype(a)))
diagcopyto!(d, a)
return d
end

end
54 changes: 54 additions & 0 deletions NDTensors/src/DiagonalArrays/src/diagview.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Convert to an offset along the diagonal.
# Otherwise, return `nothing`.
function diagindex(a::AbstractArray{T,N}, I::CartesianIndex{N}) where {T,N}
!allequal(Tuple(I)) && return nothing
return first(Tuple(I))
end

function diagindex(a::AbstractArray{T,N}, I::Vararg{Int,N}) where {T,N}
return diagindex(a, CartesianIndex(I))
end

function getdiagindex(a::AbstractArray, i::Integer)
return diagview(a)[i]
end

function setdiagindex!(a::AbstractArray, v, i::Integer)
diagview(a)[i] = v
return a
end

function setdiag!(a::AbstractArray, v)
copyto!(diagview(a), v)
return a
end

function diaglength(a::AbstractArray)
# length(diagview(a))
return minimum(size(a))
end

function diagstride(A::AbstractArray)
s = 1
p = 1
for i in 1:(ndims(A) - 1)
p *= size(A, i)
s += p
end
return s
end

function diagindices(A::AbstractArray)
diaglength = minimum(size(A))
maxdiag = LinearIndices(A)[CartesianIndex(ntuple(Returns(diaglength), ndims(A)))]
return 1:diagstride(A):maxdiag
end

function diagview(A::AbstractArray)
return @view A[diagindices(A)]
end

function diagcopyto!(dest::AbstractArray, src::AbstractArray)
copyto!(diagview(dest), diagview(src))
return dest
end
10 changes: 10 additions & 0 deletions NDTensors/src/DiagonalArrays/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using Test
using NDTensors.DiagonalArrays

@testset "Test NDTensors.DiagonalArrays" begin
@testset "README" begin
@test include(
joinpath(pkgdir(DiagonalArrays), "src", "DiagonalArrays", "examples", "README.jl")
) isa Any
end
end
2 changes: 2 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ using TupleTools

include("SetParameters/src/SetParameters.jl")
using .SetParameters
include("DiagonalArrays/src/DiagonalArrays.jl")
using .DiagonalArrays
include("BlockSparseArrays/src/BlockSparseArrays.jl")
using .BlockSparseArrays
include("SmallVectors/src/SmallVectors.jl")
Expand Down
4 changes: 4 additions & 0 deletions NDTensors/test/DiagonalArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
using Test
using NDTensors

include(joinpath(pkgdir(NDTensors), "src", "DiagonalArrays", "test", "runtests.jl"))
1 change: 1 addition & 0 deletions NDTensors/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ end
@safetestset "NDTensors" begin
@testset "$filename" for filename in [
"BlockSparseArrays.jl",
"DiagonalArrays.jl",
"SetParameters.jl",
"SmallVectors.jl",
"SortedSets.jl",
Expand Down
Loading