From 9fc409574f67028e8d60a14b1c002cd261b08f45 Mon Sep 17 00:00:00 2001 From: Lucas Date: Mon, 10 Jun 2024 19:26:12 +0200 Subject: [PATCH] add partial_transpose --- TODO | 2 +- src/partial_tra.jl | 64 +++++++++++++++++++++++++++++++++++++++++++++ test/partial_tra.jl | 28 ++++++++++++++++++++ 3 files changed, 93 insertions(+), 1 deletion(-) diff --git a/TODO b/TODO index e5c5227..f017194 100644 --- a/TODO +++ b/TODO @@ -1,4 +1,4 @@ -MA - Partial transpose +MA - Partial transpose (LP) MA - Permute systems MA - DPS hierarchy MA - See-saw diff --git a/src/partial_tra.jl b/src/partial_tra.jl index 957db48..44f90db 100644 --- a/src/partial_tra.jl +++ b/src/partial_tra.jl @@ -99,3 +99,67 @@ Takes the partial trace of matrix `X` with subsystem dimensions `dims` over the """ partial_trace(X::AbstractMatrix, remove::Integer, dims::Vector{<:Integer}) = partial_trace(X, [remove], dims) export partial_trace + +""" + partial_trasnpose(X::AbstractMatrix, transp::Vector, dims::Vector) + +Takes the partial transpose of matrix `X` with subsystem dimensions `dims` on the subsystems in `transp`. +""" +function partial_transpose(X::AbstractMatrix{T}, transp::Vector{<:Integer}, dims::Vector{<:Integer}) where {T} + isempty(transp) && return X + length(transp) == length(dims) && return LA.transpose(X) + + keep = Vector{eltype(transp)}(undef, length(dims) - length(transp)) # Systems kept + counter = 0 + for i = 1:length(dims) + if !(i in transp) + counter += 1 + keep[counter] = i + end + end + + dY = prod(dims) # Dimension of the final output Y + + Y = Matrix{T}(undef, (dY, dY)) # Final output Y + + tXi = Vector{Int64}(undef, length(dims)) # Tensor indexing of X for row + tXj = Vector{Int64}(undef, length(dims)) # Tensor indexing of X for column + + tYi = Vector{Int64}(undef, length(dims)) # Tensor indexing of Y for row + tYj = Vector{Int64}(undef, length(dims)) # Tensor indexing of Y for column + + @views tXikeep = tXi[keep] + @views tXitransp = tXi[transp] + @views tXjkeep = tXj[keep] + @views tXjtransp = tXj[transp] + + @views tYikeep = tYi[keep] + @views tYitransp = tYi[transp] + @views tYjkeep = tYj[keep] + @views tYjtransp = tYj[transp] + + for i in 1:dY + _tidx!(tXi, i, dims) + for j in 1:i + _tidx!(tXj, j, dims) + + tYikeep .= tXikeep + tYitransp .= tXjtransp + + tYjkeep .= tXjkeep + tYjtransp .= tXitransp + + Yi, Yj = _idx(tYi, dims), _idx(tYj, dims) + Y[Yi, Yj] = X[i, j] + Y[Yj, Yi] = X[j, i] + end + end + return Y +end +""" + partial_trasnpose(X::AbstractMatrix, transp::Vector, dims::Vector) + +Takes the partial transpose of matrix `X` with subsystem dimensions `dims` on the subsystem `transp`. +""" +partial_transpose(X::AbstractMatrix, transp::Integer, dims::Vector{<:Integer}) = partial_transpose(X, [transp], dims) +export partial_transpose \ No newline at end of file diff --git a/test/partial_tra.jl b/test/partial_tra.jl index c3d6aab..712f2a9 100644 --- a/test/partial_tra.jl +++ b/test/partial_tra.jl @@ -24,4 +24,32 @@ end end end + +@testset "Partial Transpose" begin + d1, d2, d3 = 2, 3, 4 + for R in [Float64, Double64, Float128, BigFloat] + for T in [R, Complex{R}] + a = randn(T, d1, d1) + b = randn(T, d2, d2) + c = randn(T, d3, d3) + ab = kron(a, b) + ac = kron(a, c) + bc = kron(b, c) + abc = kron(ab, c) + @test partial_transpose(ab, [1, 2], [d1, d2]) ≈ transpose(ab) + @test partial_transpose(ab, 2, [d1, d2]) ≈ kron(a, transpose(b)) + @test partial_transpose(ab, 1, [d1, d2]) ≈ kron(transpose(a), b) + @test partial_transpose(ab, Int64[], [d1, d2]) ≈ ab + @test partial_transpose(abc, [1, 2, 3], [d1, d2, d3]) ≈ transpose(abc) + @test partial_transpose(abc, [2, 3], [d1, d2, d3]) ≈ kron(a, transpose(b), transpose(c)) + @test partial_transpose(abc, [1, 3], [d1, d2, d3]) ≈ kron(transpose(a), b, transpose(c)) + @test partial_transpose(abc, [1, 2], [d1, d2, d3]) ≈ kron(transpose(a), transpose(b), c) + @test partial_transpose(abc, 3, [d1, d2, d3]) ≈ kron(ab, transpose(c)) + @test partial_transpose(abc, 2, [d1, d2, d3]) ≈ kron(a, transpose(b), c) + @test partial_transpose(abc, 1, [d1, d2, d3]) ≈ kron(transpose(a), bc) + @test partial_transpose(abc, Int64[], [d1, d2, d3]) ≈ abc + end + end +end + #TODO add test with JuMP variables