Skip to content

Commit

Permalink
diag for BandedMatrixes for off-limit bands (JuliaLang#56065)
Browse files Browse the repository at this point in the history
Currently, one can only obtain the `diag` for a `BandedMatrix` (such as
a `Diagonal`) when the band index is bounded by the size of the matrix.
This PR relaxes this requirement to match the behavior for arrays, where
`diag` returns an empty vector for a large band index instead of
throwing an error.
```julia
julia> D = Diagonal(ones(4))
4×4 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅    ⋅    ⋅ 
  ⋅   1.0   ⋅    ⋅ 
  ⋅    ⋅   1.0   ⋅ 
  ⋅    ⋅    ⋅   1.0

julia> diag(D, 10)
Float64[]

julia> diag(Array(D), 10)
Float64[]
```
Something similar for `SymTridiagonal` is being done in
JuliaLang#56014
  • Loading branch information
jishnub authored Oct 13, 2024
1 parent 6029173 commit 67c93b9
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 35 deletions.
11 changes: 4 additions & 7 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,20 +404,17 @@ end
function diag(M::Bidiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
v = similar(M.dv, max(0, length(M.dv)-abs(n)))
if n == 0
return copyto!(similar(M.dv, length(M.dv)), M.dv)
copyto!(v, M.dv)
elseif (n == 1 && M.uplo == 'U') || (n == -1 && M.uplo == 'L')
return copyto!(similar(M.ev, length(M.ev)), M.ev)
copyto!(v, M.ev)
elseif -size(M,1) <= n <= size(M,1)
v = similar(M.dv, size(M,1)-abs(n))
for i in eachindex(v)
v[i] = M[BandIndex(n,i)]
end
return v
else
throw(ArgumentError(LazyString(lazy"requested diagonal, $n, must be at least $(-size(M, 1)) ",
lazy"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
end
return v
end

function +(A::Bidiagonal, B::Bidiagonal)
Expand Down
11 changes: 4 additions & 7 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -773,18 +773,15 @@ permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D
function diag(D::Diagonal, k::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of k
v = similar(D.diag, max(0, length(D.diag)-abs(k)))
if k == 0
return copyto!(similar(D.diag, length(D.diag)), D.diag)
elseif -size(D,1) <= k <= size(D,1)
v = similar(D.diag, size(D,1)-abs(k))
copyto!(v, D.diag)
else
for i in eachindex(v)
v[i] = D[BandIndex(k, i)]
end
return v
else
throw(ArgumentError(LazyString(lazy"requested diagonal, $k, must be at least $(-size(D, 1)) ",
lazy"and at most $(size(D, 2)) for an $(size(D, 1))-by-$(size(D, 2)) matrix")))
end
return v
end
tr(D::Diagonal) = sum(tr, D.diag)
det(D::Diagonal) = prod(det, D.diag)
Expand Down
13 changes: 5 additions & 8 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -662,22 +662,19 @@ issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y)
function diag(M::Tridiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
v = similar(M.d, max(0, length(M.d)-abs(n)))
if n == 0
return copyto!(similar(M.d, length(M.d)), M.d)
copyto!(v, M.d)
elseif n == -1
return copyto!(similar(M.dl, length(M.dl)), M.dl)
copyto!(v, M.dl)
elseif n == 1
return copyto!(similar(M.du, length(M.du)), M.du)
copyto!(v, M.du)
elseif abs(n) <= size(M,1)
v = similar(M.d, size(M,1)-abs(n))
for i in eachindex(v)
v[i] = M[BandIndex(n,i)]
end
return v
else
throw(ArgumentError(LazyString(lazy"requested diagonal, $n, must be at least $(-size(M, 1)) ",
lazy"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
end
return v
end

@inline function Base.isassigned(A::Tridiagonal, i::Int, j::Int)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,8 @@ Random.seed!(1)
@test (@inferred diag(T))::typeof(dv) == dv
@test (@inferred diag(T, uplo === :U ? 1 : -1))::typeof(dv) == ev
@test (@inferred diag(T,2))::typeof(dv) == zeros(elty, n-2)
@test_throws ArgumentError diag(T, -n - 1)
@test_throws ArgumentError diag(T, n + 1)
@test isempty(@inferred diag(T, -n - 1))
@test isempty(@inferred diag(T, n + 1))
# test diag with another wrapped vector type
gdv, gev = GenericArray(dv), GenericArray(ev)
G = Bidiagonal(gdv, gev, uplo)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ Random.seed!(1)
end

@testset "diag" begin
@test_throws ArgumentError diag(D, n+1)
@test_throws ArgumentError diag(D, -n-1)
@test isempty(@inferred diag(D, n+1))
@test isempty(@inferred diag(D, -n-1))
@test (@inferred diag(D))::typeof(dd) == dd
@test (@inferred diag(D, 0))::typeof(dd) == dd
@test (@inferred diag(D, 1))::typeof(dd) == zeros(elty, n-1)
Expand Down
13 changes: 4 additions & 9 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,8 @@ end
@test (@inferred diag(A, 1))::typeof(d) == (mat_type == Tridiagonal ? du : dl)
@test (@inferred diag(A, -1))::typeof(d) == dl
@test (@inferred diag(A, n-1))::typeof(d) == zeros(elty, 1)
if A isa SymTridiagonal
@test isempty(@inferred diag(A, -n - 1))
@test isempty(@inferred diag(A, n + 1))
else
@test_throws ArgumentError diag(A, -n - 1)
@test_throws ArgumentError diag(A, n + 1)
end
@test isempty(@inferred diag(A, -n - 1))
@test isempty(@inferred diag(A, n + 1))
GA = mat_type == Tridiagonal ? mat_type(GenericArray.((dl, d, du))...) : mat_type(GenericArray.((d, dl))...)
@test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d)
@test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl)
Expand Down Expand Up @@ -527,8 +522,8 @@ end
@test @inferred diag(A, -1) == fill(M, n-1)
@test_broken diag(A, -2) == fill(M, n-2)
@test_broken diag(A, 2) == fill(M, n-2)
@test_throws ArgumentError diag(A, n+1)
@test_throws ArgumentError diag(A, -n-1)
@test isempty(@inferred diag(A, n+1))
@test isempty(@inferred diag(A, -n-1))

for n in 0:2
dv, ev = fill(M, n), fill(M, max(n-1,0))
Expand Down

0 comments on commit 67c93b9

Please sign in to comment.