Skip to content

Commit

Permalink
rewrite diag(::StructuredMatrix[, k=0]) (#24324)
Browse files Browse the repository at this point in the history
- diag(A::StructuredMatrix[, k=0]) now return a new vector,
   such that it does not alias the wrapped vector. This now
   behaves the same for e.g. Matrix, where we get a new vector.
   It is also consistent with getindex (diag is essentially a
   special case of getindex).

 - diag(A::StructuredMatrix{T,VectorType}[, k=0]) now call
   similar(..., ::Int) on every branch to make sure the same
   vector type is always returned, this is usually of type
   VectorType.
  • Loading branch information
fredrikekre authored Oct 26, 2017
1 parent 3709531 commit 4f8438b
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 26 deletions.
14 changes: 7 additions & 7 deletions base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@ function triu!(M::Bidiagonal, k::Integer=0)
return M
end

function diag(M::Bidiagonal{T}, n::Integer=0) where T
function diag(M::Bidiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
if n == 0
return M.dv
elseif n == 1
return M.uplo == 'U' ? M.ev : zeros(T, size(M,1)-1)
elseif n == -1
return M.uplo == 'L' ? M.ev : zeros(T, size(M,1)-1)
return copy!(similar(M.dv, length(M.dv)), M.dv)
elseif (n == 1 && M.uplo == 'U') || (n == -1 && M.uplo == 'L')
return copy!(similar(M.ev, length(M.ev)), M.ev)
elseif -size(M,1) <= n <= size(M,1)
return zeros(T, size(M,1)-abs(n))
return fill!(similar(M.dv, size(M,1)-abs(n)), 0)
else
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
Expand Down
13 changes: 12 additions & 1 deletion base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,18 @@ transpose(D::Diagonal) = Diagonal(transpose.(D.diag))
adjoint(D::Diagonal{<:Number}) = conj(D)
adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag))

diag(D::Diagonal) = D.diag
function diag(D::Diagonal, k::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of k
if k == 0
return copy!(similar(D.diag, length(D.diag)), D.diag)
elseif -size(D,1) <= k <= size(D,1)
return fill!(similar(D.diag, size(D,1)-abs(k)), 0)
else
throw(ArgumentError(string("requested diagonal, $k, must be at least $(-size(D, 1)) ",
"and at most $(size(D, 2)) for an $(size(D, 1))-by-$(size(D, 2)) matrix")))
end
end
trace(D::Diagonal) = sum(D.diag)
det(D::Diagonal) = prod(D.diag)
logdet(D::Diagonal{<:Real}) = sum(log, D.diag)
Expand Down
20 changes: 12 additions & 8 deletions base/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,16 @@ broadcast(::typeof(ceil), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = Sym
transpose(M::SymTridiagonal) = M #Identity operation
adjoint(M::SymTridiagonal) = conj(M)

function diag(M::SymTridiagonal{T}, n::Integer=0) where T
function diag(M::SymTridiagonal, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
absn = abs(n)
if absn == 0
return M.dv
return copy!(similar(M.dv, length(M.dv)), M.dv)
elseif absn==1
return M.ev
return copy!(similar(M.ev, length(M.ev)), M.ev)
elseif absn <= size(M,1)
return zeros(T,size(M,1)-absn)
return fill!(similar(M.dv, size(M,1)-absn), 0)
else
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
Expand Down Expand Up @@ -535,14 +537,16 @@ transpose(M::Tridiagonal) = Tridiagonal(M.du, M.d, M.dl)
adjoint(M::Tridiagonal) = conj(transpose(M))

function diag(M::Tridiagonal{T}, n::Integer=0) where T
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
if n == 0
return M.d
return copy!(similar(M.d, length(M.d)), M.d)
elseif n == -1
return M.dl
return copy!(similar(M.dl, length(M.dl)), M.dl)
elseif n == 1
return M.du
return copy!(similar(M.du, length(M.du)), M.du)
elseif abs(n) <= size(M,1)
return zeros(T,size(M,1)-abs(n))
return fill!(similar(M.d, size(M,1)-abs(n)), 0)
else
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
Expand Down
14 changes: 11 additions & 3 deletions test/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,18 @@ srand(1)
end
end

@testset "Diagonals" begin
@test diag(T,2) == zeros(elty, n-2)
@testset "diag" begin
@test (@inferred diag(T))::typeof(dv) == dv
@test (@inferred diag(T, uplo == :U ? 1 : -1))::typeof(dv) == ev
@test (@inferred diag(T,2))::typeof(dv) == zeros(elty, n-2)
@test_throws ArgumentError diag(T, -n - 1)
@test_throws ArgumentError diag(T, n + 1)
@test_throws ArgumentError diag(T, n + 1)
# test diag with another wrapped vector type
gdv, gev = GenericArray(dv), GenericArray(ev)
G = Bidiagonal(gdv, gev, uplo)
@test (@inferred diag(G))::typeof(gdv) == gdv
@test (@inferred diag(G, uplo == :U ? 1 : -1))::typeof(gdv) == gev
@test (@inferred diag(G,2))::typeof(gdv) == GenericArray(zeros(elty, n-2))
end

@testset "Eigensystems" begin
Expand Down
13 changes: 12 additions & 1 deletion test/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ srand(1)
@test Array(imag(D)) == imag(DM)

@test parent(D) == dd
@test diag(D) == dd
@test D[1,1] == dd[1]
@test D[1,2] == 0

Expand All @@ -51,6 +50,18 @@ srand(1)
end
end

@testset "diag" begin
@test_throws ArgumentError diag(D, n+1)
@test_throws ArgumentError diag(D, -n-1)
@test (@inferred diag(D))::typeof(dd) == dd
@test (@inferred diag(D, 0))::typeof(dd) == dd
@test (@inferred diag(D, 1))::typeof(dd) == zeros(elty, n-1)
DG = Diagonal(GenericArray(dd))
@test (@inferred diag(DG))::typeof(GenericArray(dd)) == GenericArray(dd)
@test (@inferred diag(DG, 1))::typeof(GenericArray(dd)) == GenericArray(zeros(elty, n-1))
end


@testset "Simple unary functions" begin
for op in (-,)
@test op(D)==op(DM)
Expand Down
15 changes: 9 additions & 6 deletions test/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,17 @@ guardsrand(123) do
@test_throws ArgumentError A[2, 3] = 1 # test assignment on the superdiagonal
end
end
@testset "Diagonal extraction" begin
@test diag(A, 1) === (mat_type == Tridiagonal ? du : dl)
@test diag(A, -1) === dl
@test diag(A, 0) === d
@test diag(A) === d
@test diag(A, n - 1) == zeros(elty, 1)
@testset "diag" begin
@test (@inferred diag(A))::typeof(d) == d
@test (@inferred diag(A, 0))::typeof(d) == d
@test (@inferred diag(A, 1))::typeof(d) == (mat_type == Tridiagonal ? du : dl)
@test (@inferred diag(A, -1))::typeof(d) == dl
@test (@inferred diag(A, n-1))::typeof(d) == zeros(elty, 1)
@test_throws ArgumentError diag(A, -n - 1)
@test_throws ArgumentError diag(A, n + 1)
GA = mat_type == Tridiagonal ? mat_type(GenericArray.((dl, d, du))...) : mat_type(GenericArray.((d, dl))...)
@test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d)
@test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl)
end
@testset "Idempotent tests" begin
for func in (conj, transpose, adjoint)
Expand Down

0 comments on commit 4f8438b

Please sign in to comment.