Skip to content

Commit

Permalink
Indexing fies for Hermitian matrix (#18)
Browse files Browse the repository at this point in the history
* Indexing fies for Hermitian matrix

* Add missing entry in doc
  • Loading branch information
blegat authored Oct 23, 2020
1 parent f20d084 commit bae44c1
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 22 deletions.
1 change: 1 addition & 0 deletions docs/src/atoms.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MomentMatrix
moment_matrix
SymMatrix
VectorizedHermitianMatrix
symmetric_setindex!
```

Expand Down
23 changes: 14 additions & 9 deletions src/hermitian_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@ It implement the `AbstractMatrix` interface except for `setindex!` as it might
break its symmetry. The [`symmetric_setindex!`](@ref) function should be used
instead.
"""
struct VectorizedHermitianMatrix{T, U} <: AbstractMatrix{U}
struct VectorizedHermitianMatrix{T, S, U} <: AbstractMatrix{U}
Q::Vector{T}
n::Int
end
function VectorizedHermitianMatrix{T, S}(Q::Vector{T}, n) where {T, S}
V = MA.promote_operation(*, Complex{S}, T)
U = MA.promote_operation(+, T, V)
VectorizedHermitianMatrix{T, S, U}(Q, n)
end
function VectorizedHermitianMatrix{T}(Q::Vector{T}, n) where T
# `typeof(im)` is `Complex{Bool}`
S = MA.promote_operation(*, Complex{Bool}, T)
U = MA.promote_operation(+, T, S)
VectorizedHermitianMatrix{T, U}(Q, n)
return VectorizedHermitianMatrix{T, Bool}(Q, n)
end
function VectorizedHermitianMatrix(Q::Vector{T}, n) where T
return VectorizedHermitianMatrix{T}(Q, n)
Expand Down Expand Up @@ -54,8 +57,8 @@ imag_map(Q::VectorizedHermitianMatrix, i, j) = imag_map(Q.n, i, j)

function vectorized_hermitian_matrix(::Type{T}, f, n, σ) where {T}
Q = Vector{T}(undef, trimap(n, n) + trimap(n - 1, n - 1))
for i in 1:n
for j in 1:i
for j in 1:n
for i in 1:j
x = f(σ[i], σ[j])
Q[trimap(i, j)] = real(x)
if i != j
Expand Down Expand Up @@ -88,14 +91,16 @@ function symmetric_setindex!(Q::VectorizedHermitianMatrix, value, i::Integer, j:
end
end

function Base.getindex(Q::VectorizedHermitianMatrix{T, U}, i::Integer, j::Integer) where {T, U}
I, J = max(i, j), min(i, j)
function Base.getindex(Q::VectorizedHermitianMatrix{T, S, U}, i::Integer, j::Integer) where {T, S, U}
I, J = min(i, j), max(i, j)
r = Q.Q[trimap(I, J)]
if i == j
return convert(U, r)
else
c = Q.Q[imag_map(Q, I, J)]
return r + im * (i < j ? c : -c)
# If `c` is `MathOptInterface.SingleVariable`, `-c` is not defined so
# we prefer calling `-one(S)`.
return r + ((i < j ? one(S) : -one(S)) * im) * c
end
end
Base.getindex(Q::VectorizedHermitianMatrix, I::Tuple) = Q[I...]
Expand Down
14 changes: 6 additions & 8 deletions src/symmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,13 @@ Base.map(f::Function, Q::SymMatrix) = SymMatrix(map(f, Q.Q), Q.n)
# div(n*(n+1), 2) - div((n-i+1)*(n-i+2), 2) + j-i+1
#end

# j <= i
function trimap(i, j)
div((i - 1) * i, 2) + j
end
# i <= j
trimap(i, j) = div(j * (j - 1), 2) + i

function trimat(::Type{T}, f, n, σ) where {T}
Q = Vector{T}(undef, trimap(n, n))
for i in 1:n
for j in 1:i
for j in 1:n
for i in 1:j
Q[trimap(i, j)] = f(σ[i], σ[j])
end
end
Expand All @@ -49,11 +47,11 @@ Base.size(Q::SymMatrix) = (Q.n, Q.n)
Set `Q[i, j]` and `Q[j, i]` to the value `value`.
"""
function symmetric_setindex!(Q::SymMatrix, value, i::Integer, j::Integer)
Q.Q[trimap(max(i, j), min(i, j))] = value
Q.Q[trimap(min(i, j), max(i, j))] = value
end

function Base.getindex(Q::SymMatrix, i::Integer, j::Integer)
return Q.Q[trimap(max(i, j), min(i, j))]
return Q.Q[trimap(min(i, j), max(i, j))]
end
Base.getindex(Q::SymMatrix, I::Tuple) = Q[I...]
Base.getindex(Q::SymMatrix, I::CartesianIndex) = Q[I.I]
10 changes: 10 additions & 0 deletions test/hermitian_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,14 @@ using MultivariateMoments
@test S.n == 2
@test S.Q == [3, 4, 5, 2]
@test_throws ErrorException symmetric_setindex!(S, im, 1, 1)
M = [1 2 + 3im 4 + 5im
2 - 3im 6 7 + 8im
4 - 5im 7 - 8im 9]
N = MultivariateMoments.vectorized_hermitian_matrix(Int, (i, j) -> M[i, j], 3, 3:-1:1)
@test Matrix(N) == M[3:-1:1, 3:-1:1]
symmetric_setindex!(N, 4 - 5im, 3, 1)
@test Matrix(N) != M[3:-1:1, 3:-1:1]
M[3, 1] = 4 + 5im
M[1, 3] = 4 - 5im
@test Matrix(N) == M[3:-1:1, 3:-1:1]
end
20 changes: 15 additions & 5 deletions test/hermitian_poly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,20 @@ using MultivariateMoments

@testset "VectorizedHermitianMatrix with polynomial" begin
Mod.@polyvar x y
Q = VectorizedHermitianMatrix([x, x, x, y], 2)
function _tests(Q)
for i in 1:2, j in 1:2
@test (@inferred Q[i, j]) isa eltype(Q)
end
@test x == @inferred Q[1, 1]
@test x + im * y == @inferred Q[1, 2]
@test x - im * y == @inferred Q[2, 1]
@test x == @inferred Q[2, 2]
end
q = [x, x, x, y]
Q = VectorizedHermitianMatrix(q, 2)
@test eltype(Q) == polynomialtype(x * y, Complex{Int})
@test x == @inferred Q[1, 1]
@test x + im * y == @inferred Q[1, 2]
@test x - im * y == @inferred Q[2, 1]
@test x == @inferred Q[2, 2]
_tests(Q)
R = VectorizedHermitianMatrix{eltype(q), Float64}(q, 2)
@test eltype(R) == polynomialtype(x * y, Complex{Float64})
_tests(R)
end

0 comments on commit bae44c1

Please sign in to comment.