diff --git a/Project.toml b/Project.toml index 5729f109..75ef585b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.11.4" +version = "0.11.5" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 542ac898..cbf9405a 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -380,10 +380,14 @@ const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}} isone(::SquareEye) = true -for f in (:permutedims, :inv, :triu, :triu!, :tril, :tril!) - @eval ($f)(IM::SquareEye) = IM +# These should actually be in StdLib, LinearAlgebra.jl, for all Diagonal +for f in (:permutedims, :triu, :triu!, :tril, :tril!) + @eval ($f)(IM::Diagonal{<:Any,<:AbstractFill}) = IM end +inv(IM::SquareEye) = IM +inv(IM::Diagonal{<:Any,<:AbstractFill}) = Diagonal(map(inv, IM.diag)) + Eye(n::Integer, m::Integer) = RectDiagonal(Ones(min(n,m)), n, m) Eye{T}(n::Integer, m::Integer) where T = RectDiagonal{T}(Ones{T}(min(n,m)), n, m) function Eye{T}((a,b)::NTuple{2,AbstractUnitRange{Int}}) where T diff --git a/test/runtests.jl b/test/runtests.jl index 7ba5c52b..37d32c82 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -926,11 +926,14 @@ end @testset "Eye identity ops" begin m = Eye(10) + D = Diagonal(Fill(2,10)) for op in (permutedims, inv) @test op(m) === m end + @test permutedims(D) ≡ D + @test inv(D) ≡ Diagonal(Fill(1/2,10)) - for m in (Eye(10), Eye(10, 10), Eye(10, 8), Eye(8, 10)) + for m in (Eye(10), Eye(10, 10), Eye(10, 8), Eye(8, 10), D) for op in (tril, triu, tril!, triu!) @test op(m) === m end