diff --git a/base/linalg/bidiag.jl b/base/linalg/bidiag.jl index 25f6fd8bae858..e9abd67a17dfd 100644 --- a/base/linalg/bidiag.jl +++ b/base/linalg/bidiag.jl @@ -270,15 +270,15 @@ function triu!(M::Bidiagonal, k::Integer=0) return M end -function diag(M::Bidiagonal{T}, n::Integer=0) where T +function diag(M::Bidiagonal, n::Integer=0) + # every branch call similar(..., ::Int) to make sure the + # same vector type is returned independent of n if n == 0 - return M.dv - elseif n == 1 - return M.uplo == 'U' ? M.ev : zeros(T, size(M,1)-1) - elseif n == -1 - return M.uplo == 'L' ? M.ev : zeros(T, size(M,1)-1) + return copy!(similar(M.dv, length(M.dv)), M.dv) + elseif (n == 1 && M.uplo == 'U') || (n == -1 && M.uplo == 'L') + return copy!(similar(M.ev, length(M.ev)), M.ev) elseif -size(M,1) <= n <= size(M,1) - return zeros(T, size(M,1)-abs(n)) + return fill!(similar(M.dv, size(M,1)-abs(n)), 0) else throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ", "and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix"))) diff --git a/base/linalg/diagonal.jl b/base/linalg/diagonal.jl index 2953cd14dfa67..173a2236cf0e9 100644 --- a/base/linalg/diagonal.jl +++ b/base/linalg/diagonal.jl @@ -322,7 +322,18 @@ transpose(D::Diagonal) = Diagonal(transpose.(D.diag)) adjoint(D::Diagonal{<:Number}) = conj(D) adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag)) -diag(D::Diagonal) = D.diag +function diag(D::Diagonal, k::Integer=0) + # every branch call similar(..., ::Int) to make sure the + # same vector type is returned independent of k + if k == 0 + return copy!(similar(D.diag, length(D.diag)), D.diag) + elseif -size(D,1) <= k <= size(D,1) + return fill!(similar(D.diag, size(D,1)-abs(k)), 0) + else + throw(ArgumentError(string("requested diagonal, $k, must be at least $(-size(D, 1)) ", + "and at most $(size(D, 2)) for an $(size(D, 1))-by-$(size(D, 2)) matrix"))) + end +end trace(D::Diagonal) = sum(D.diag) det(D::Diagonal) = prod(D.diag) logdet(D::Diagonal{<:Real}) = sum(log, D.diag) diff --git a/base/linalg/tridiag.jl b/base/linalg/tridiag.jl index 32596a9713a9e..9a3ec4b8bde60 100644 --- a/base/linalg/tridiag.jl +++ b/base/linalg/tridiag.jl @@ -130,14 +130,16 @@ broadcast(::typeof(ceil), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = Sym transpose(M::SymTridiagonal) = M #Identity operation adjoint(M::SymTridiagonal) = conj(M) -function diag(M::SymTridiagonal{T}, n::Integer=0) where T +function diag(M::SymTridiagonal, n::Integer=0) + # every branch call similar(..., ::Int) to make sure the + # same vector type is returned independent of n absn = abs(n) if absn == 0 - return M.dv + return copy!(similar(M.dv, length(M.dv)), M.dv) elseif absn==1 - return M.ev + return copy!(similar(M.ev, length(M.ev)), M.ev) elseif absn <= size(M,1) - return zeros(T,size(M,1)-absn) + return fill!(similar(M.dv, size(M,1)-absn), 0) else throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ", "and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix"))) @@ -535,14 +537,16 @@ transpose(M::Tridiagonal) = Tridiagonal(M.du, M.d, M.dl) adjoint(M::Tridiagonal) = conj(transpose(M)) function diag(M::Tridiagonal{T}, n::Integer=0) where T + # every branch call similar(..., ::Int) to make sure the + # same vector type is returned independent of n if n == 0 - return M.d + return copy!(similar(M.d, length(M.d)), M.d) elseif n == -1 - return M.dl + return copy!(similar(M.dl, length(M.dl)), M.dl) elseif n == 1 - return M.du + return copy!(similar(M.du, length(M.du)), M.du) elseif abs(n) <= size(M,1) - return zeros(T,size(M,1)-abs(n)) + return fill!(similar(M.d, size(M,1)-abs(n)), 0) else throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ", "and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix"))) diff --git a/test/linalg/bidiag.jl b/test/linalg/bidiag.jl index 6594ad02e40af..bfd745c286500 100644 --- a/test/linalg/bidiag.jl +++ b/test/linalg/bidiag.jl @@ -216,10 +216,18 @@ srand(1) end end - @testset "Diagonals" begin - @test diag(T,2) == zeros(elty, n-2) + @testset "diag" begin + @test (@inferred diag(T))::typeof(dv) == dv + @test (@inferred diag(T, uplo == :U ? 1 : -1))::typeof(dv) == ev + @test (@inferred diag(T,2))::typeof(dv) == zeros(elty, n-2) @test_throws ArgumentError diag(T, -n - 1) - @test_throws ArgumentError diag(T, n + 1) + @test_throws ArgumentError diag(T, n + 1) + # test diag with another wrapped vector type + gdv, gev = GenericArray(dv), GenericArray(ev) + G = Bidiagonal(gdv, gev, uplo) + @test (@inferred diag(G))::typeof(gdv) == gdv + @test (@inferred diag(G, uplo == :U ? 1 : -1))::typeof(gdv) == gev + @test (@inferred diag(G,2))::typeof(gdv) == GenericArray(zeros(elty, n-2)) end @testset "Eigensystems" begin diff --git a/test/linalg/diagonal.jl b/test/linalg/diagonal.jl index 0d6e12b694901..3ba9c7c757099 100644 --- a/test/linalg/diagonal.jl +++ b/test/linalg/diagonal.jl @@ -39,7 +39,6 @@ srand(1) @test Array(imag(D)) == imag(DM) @test parent(D) == dd - @test diag(D) == dd @test D[1,1] == dd[1] @test D[1,2] == 0 @@ -51,6 +50,18 @@ srand(1) end end + @testset "diag" begin + @test_throws ArgumentError diag(D, n+1) + @test_throws ArgumentError diag(D, -n-1) + @test (@inferred diag(D))::typeof(dd) == dd + @test (@inferred diag(D, 0))::typeof(dd) == dd + @test (@inferred diag(D, 1))::typeof(dd) == zeros(elty, n-1) + DG = Diagonal(GenericArray(dd)) + @test (@inferred diag(DG))::typeof(GenericArray(dd)) == GenericArray(dd) + @test (@inferred diag(DG, 1))::typeof(GenericArray(dd)) == GenericArray(zeros(elty, n-1)) + end + + @testset "Simple unary functions" begin for op in (-,) @test op(D)==op(DM) diff --git a/test/linalg/tridiag.jl b/test/linalg/tridiag.jl index 3649ab415e5b8..c69d08745eef9 100644 --- a/test/linalg/tridiag.jl +++ b/test/linalg/tridiag.jl @@ -153,14 +153,17 @@ guardsrand(123) do @test_throws ArgumentError A[2, 3] = 1 # test assignment on the superdiagonal end end - @testset "Diagonal extraction" begin - @test diag(A, 1) === (mat_type == Tridiagonal ? du : dl) - @test diag(A, -1) === dl - @test diag(A, 0) === d - @test diag(A) === d - @test diag(A, n - 1) == zeros(elty, 1) + @testset "diag" begin + @test (@inferred diag(A))::typeof(d) == d + @test (@inferred diag(A, 0))::typeof(d) == d + @test (@inferred diag(A, 1))::typeof(d) == (mat_type == Tridiagonal ? du : dl) + @test (@inferred diag(A, -1))::typeof(d) == dl + @test (@inferred diag(A, n-1))::typeof(d) == zeros(elty, 1) @test_throws ArgumentError diag(A, -n - 1) @test_throws ArgumentError diag(A, n + 1) + GA = mat_type == Tridiagonal ? mat_type(GenericArray.((dl, d, du))...) : mat_type(GenericArray.((d, dl))...) + @test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d) + @test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl) end @testset "Idempotent tests" begin for func in (conj, transpose, adjoint)