Skip to content

Commit

Permalink
Fix keytype and zero
Browse files Browse the repository at this point in the history
  • Loading branch information
Dani Pinyol committed Oct 25, 2024
1 parent 485fd4b commit 8c6a370
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,

import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex,
conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin,
float, getindex, imag, inv, kron, kron!, length, map, maximum, minimum, permute!, real,
float, getindex, imag, inv, keytype, kron, kron!, length, map, maximum, minimum, permute!, real,
rot180, rotl90, rotr90, setindex!, show, similar, size, sum, transpose,
vcat, hcat, hvcat, cat, vec, reverse, reverse!

Expand Down Expand Up @@ -84,7 +84,8 @@ if Base.USE_GPL_LIBS
include("solvers/spqr.jl")
end

zero(a::AbstractSparseArray) = spzeros(eltype(a), size(a)...)
keytype(::Type{A}) where {Tv, Ti, A<:AbstractSparseArray{Tv,Ti}} = Ti
zero(a::AbstractSparseArray) = spzeros(eltype(a), keytype(a), size(a)...)

LinearAlgebra.diagzero(D::Diagonal{<:AbstractSparseMatrix{T}},i,j) where {T} =
spzeros(T, size(D.diag[i], 1), size(D.diag[j], 2))
Expand Down
9 changes: 9 additions & 0 deletions test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ include("forbidproperties.jl")
### Data

spv_x1 = SparseVector(8, [2, 5, 6], [1.25, -0.75, 3.5])
spv_x1_32 = SparseVector(8, Int32[2, 5, 6], Float32[1.25, -0.75, 3.5])

@test isa(spv_x1, SparseVector{Float64,Int})

Expand Down Expand Up @@ -42,6 +43,14 @@ x1_full[SparseArrays.nonzeroinds(spv_x1)] = nonzeros(spv_x1)
@test @inferred size(y) == (@inferred(length(y))::Int8,)
end

@testset "Non default index type" begin
x = spv_x1_32
for func in [identity, copy, empty, similar, zero]
@test eltype(func(spv_x1_32)) == Float32
@test keytype(func(spv_x1_32)) == Int32
end
end

@testset "isstored" begin
x = spv_x1
stored_inds = [2, 5, 6]
Expand Down

0 comments on commit 8c6a370

Please sign in to comment.