From 67c93b9a19cfe00211be595b2d5c50cfcfcece09 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 13 Oct 2024 23:24:25 +0530 Subject: [PATCH] `diag` for `BandedMatrix`es for off-limit bands (#56065) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently, one can only obtain the `diag` for a `BandedMatrix` (such as a `Diagonal`) when the band index is bounded by the size of the matrix. This PR relaxes this requirement to match the behavior for arrays, where `diag` returns an empty vector for a large band index instead of throwing an error. ```julia julia> D = Diagonal(ones(4)) 4×4 Diagonal{Float64, Vector{Float64}}: 1.0 ⋅ ⋅ ⋅ ⋅ 1.0 ⋅ ⋅ ⋅ ⋅ 1.0 ⋅ ⋅ ⋅ ⋅ 1.0 julia> diag(D, 10) Float64[] julia> diag(Array(D), 10) Float64[] ``` Something similar for `SymTridiagonal` is being done in https://github.com/JuliaLang/julia/pull/56014 --- stdlib/LinearAlgebra/src/bidiag.jl | 11 ++++------- stdlib/LinearAlgebra/src/diagonal.jl | 11 ++++------- stdlib/LinearAlgebra/src/tridiag.jl | 13 +++++-------- stdlib/LinearAlgebra/test/bidiag.jl | 4 ++-- stdlib/LinearAlgebra/test/diagonal.jl | 4 ++-- stdlib/LinearAlgebra/test/tridiag.jl | 13 ++++--------- 6 files changed, 21 insertions(+), 35 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 381afd2f09a61..a34df37153cd2 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -404,20 +404,17 @@ end function diag(M::Bidiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n + v = similar(M.dv, max(0, length(M.dv)-abs(n))) if n == 0 - return copyto!(similar(M.dv, length(M.dv)), M.dv) + copyto!(v, M.dv) elseif (n == 1 && M.uplo == 'U') || (n == -1 && M.uplo == 'L') - return copyto!(similar(M.ev, length(M.ev)), M.ev) + copyto!(v, M.ev) elseif -size(M,1) <= n <= size(M,1) - v = similar(M.dv, size(M,1)-abs(n)) for i in eachindex(v) v[i] = M[BandIndex(n,i)] end - return v - else - throw(ArgumentError(LazyString(lazy"requested diagonal, $n, must be at least $(-size(M, 1)) ", - lazy"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix"))) end + return v end function +(A::Bidiagonal, B::Bidiagonal) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 0c93024f33a9a..aabfb3e8ba114 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -773,18 +773,15 @@ permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D function diag(D::Diagonal, k::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of k + v = similar(D.diag, max(0, length(D.diag)-abs(k))) if k == 0 - return copyto!(similar(D.diag, length(D.diag)), D.diag) - elseif -size(D,1) <= k <= size(D,1) - v = similar(D.diag, size(D,1)-abs(k)) + copyto!(v, D.diag) + else for i in eachindex(v) v[i] = D[BandIndex(k, i)] end - return v - else - throw(ArgumentError(LazyString(lazy"requested diagonal, $k, must be at least $(-size(D, 1)) ", - lazy"and at most $(size(D, 2)) for an $(size(D, 1))-by-$(size(D, 2)) matrix"))) end + return v end tr(D::Diagonal) = sum(tr, D.diag) det(D::Diagonal) = prod(det, D.diag) diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index c1af12514e020..d6382d2e16a43 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -662,22 +662,19 @@ issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y) function diag(M::Tridiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n + v = similar(M.d, max(0, length(M.d)-abs(n))) if n == 0 - return copyto!(similar(M.d, length(M.d)), M.d) + copyto!(v, M.d) elseif n == -1 - return copyto!(similar(M.dl, length(M.dl)), M.dl) + copyto!(v, M.dl) elseif n == 1 - return copyto!(similar(M.du, length(M.du)), M.du) + copyto!(v, M.du) elseif abs(n) <= size(M,1) - v = similar(M.d, size(M,1)-abs(n)) for i in eachindex(v) v[i] = M[BandIndex(n,i)] end - return v - else - throw(ArgumentError(LazyString(lazy"requested diagonal, $n, must be at least $(-size(M, 1)) ", - lazy"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix"))) end + return v end @inline function Base.isassigned(A::Tridiagonal, i::Int, j::Int) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 628e59debe8b7..df30748e042b5 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -398,8 +398,8 @@ Random.seed!(1) @test (@inferred diag(T))::typeof(dv) == dv @test (@inferred diag(T, uplo === :U ? 1 : -1))::typeof(dv) == ev @test (@inferred diag(T,2))::typeof(dv) == zeros(elty, n-2) - @test_throws ArgumentError diag(T, -n - 1) - @test_throws ArgumentError diag(T, n + 1) + @test isempty(@inferred diag(T, -n - 1)) + @test isempty(@inferred diag(T, n + 1)) # test diag with another wrapped vector type gdv, gev = GenericArray(dv), GenericArray(ev) G = Bidiagonal(gdv, gev, uplo) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 85fe963e3592b..f7a9ccb705de9 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -109,8 +109,8 @@ Random.seed!(1) end @testset "diag" begin - @test_throws ArgumentError diag(D, n+1) - @test_throws ArgumentError diag(D, -n-1) + @test isempty(@inferred diag(D, n+1)) + @test isempty(@inferred diag(D, -n-1)) @test (@inferred diag(D))::typeof(dd) == dd @test (@inferred diag(D, 0))::typeof(dd) == dd @test (@inferred diag(D, 1))::typeof(dd) == zeros(elty, n-1) diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index b1d52ab8c5679..aa3baec8f6be8 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -287,13 +287,8 @@ end @test (@inferred diag(A, 1))::typeof(d) == (mat_type == Tridiagonal ? du : dl) @test (@inferred diag(A, -1))::typeof(d) == dl @test (@inferred diag(A, n-1))::typeof(d) == zeros(elty, 1) - if A isa SymTridiagonal - @test isempty(@inferred diag(A, -n - 1)) - @test isempty(@inferred diag(A, n + 1)) - else - @test_throws ArgumentError diag(A, -n - 1) - @test_throws ArgumentError diag(A, n + 1) - end + @test isempty(@inferred diag(A, -n - 1)) + @test isempty(@inferred diag(A, n + 1)) GA = mat_type == Tridiagonal ? mat_type(GenericArray.((dl, d, du))...) : mat_type(GenericArray.((d, dl))...) @test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d) @test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl) @@ -527,8 +522,8 @@ end @test @inferred diag(A, -1) == fill(M, n-1) @test_broken diag(A, -2) == fill(M, n-2) @test_broken diag(A, 2) == fill(M, n-2) - @test_throws ArgumentError diag(A, n+1) - @test_throws ArgumentError diag(A, -n-1) + @test isempty(@inferred diag(A, n+1)) + @test isempty(@inferred diag(A, -n-1)) for n in 0:2 dv, ev = fill(M, n), fill(M, max(n-1,0))