Skip to content

Commit

Permalink
incorporate partial trace
Browse files Browse the repository at this point in the history
  • Loading branch information
araujoms committed Jun 7, 2024
1 parent 184a940 commit 8890926
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 112 deletions.
109 changes: 0 additions & 109 deletions some_old_code/mpartiteLinAlg.jl

This file was deleted.

1 change: 1 addition & 0 deletions src/Ket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ include("games.jl")
include("measurements.jl")
include("entropy.jl")
include("norms.jl")
include("partial_tra.jl")

import Requires
function __init__()
Expand Down
2 changes: 1 addition & 1 deletion src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ _realeps(::Type{<:Real}) = 0
Zeroes out real or imaginary parts of `M` that are smaller than `tol`.
"""
function cleanup!(M::AbstractArray{T}; tol = _eps(T)) where {T<:Number} # SD: is it type stable?
function cleanup!(M::AbstractArray{T}; tol = _eps(T)) where {T<:Number}
wrapper = Base.typename(typeof(M)).wrapper
cleanup!(parent(M); tol)
return wrapper(M)
Expand Down
97 changes: 97 additions & 0 deletions src/partial_tra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
_tidx(idx::Integer, dims::Vector)
Converts a standard index `idx` to a tensor index [i₁, i₂, ...] with subsystems dimensions `dims`.
"""
function _tidx(idx::Integer, dims::Vector{<:Integer})
result = Vector{Int64}(undef, length(dims))
_tidx!(result, idx, dims)
return result
end

function _tidx!(tidx::AbstractVector{<:Integer}, idx::Integer, dims::Vector{<:Integer})
nsys = length(dims)
cidx = idx - 1 # Current index
dr = prod(dims)
for k = 1:nsys
# Everytime you increase a tensor index you shift by the product of remaining dimensions
dr ÷= dims[k]
tidx[k] = (cidx ÷ dr) + 1
cidx %= dr
end
return tidx
end

"""
_idx(tidx::Vector, dims::Vector)
Converts a tensor index `tidx` = [i₁, i₂, ...] with subsystems dimensions `dims` to a standard index.
"""
function _idx(tidx::Vector{<:Integer}, dims::Vector{<:Integer})
i = 1
shift = 1

for k in length(tidx):-1:1
i += (tidx[k] - 1) * shift
shift *= dims[k]
end
return i
end

"""
partial_trace(X::AbstractMatrix, remove::AbstractVector, dims::
AbstractVector)
Takes the partial trace of matrix `X` with subsystem dimensions `dims` over the subsystems in `remove`.
"""
function partial_trace(X::AbstractMatrix{T}, remove::Vector{<:Integer}, dims::Vector{<:Integer}) where {T}
isempty(remove) && return X
length(remove) == length(dims) && return fill(T(LA.tr(X)), 1, 1)

keep = Vector{eltype(remove)}(undef, length(dims) - length(remove)) # Systems kept
counter = 0
for i = 1:length(dims)
if !(i in remove)
counter += 1
keep[counter] = i
end
end
dimsY = dims[keep] # The tensor dimensions of Y
dimsR = dims[remove] # The tensor dimensions of the traced out systems
dY = prod(dimsY) # Dimension of Y
dR = prod(dimsR) # Dimension of system traced out

Y = Matrix{T}(undef, (dY, dY)) # Final output Y
tXi = Vector{Int64}(undef, length(dims)) # Tensor indexing of X for column
tXj = Vector{Int64}(undef, length(dims)) # Tensor indexing of X for row

@views tXikeep = tXi[keep]
@views tXiremove = tXi[remove]
@views tXjkeep = tXj[keep]
@views tXjremove = tXj[remove]

# We loop through Y and find the corresponding element
@inbounds for i = 1:dY
# Find current column tensor index for Y
_tidx!(tXikeep, i, dimsY)
for j = 1:dY
# Find current row tensor index for Y
_tidx!(tXjkeep, j, dimsY)

# Now loop through the diagonal of the traced out systems
val = zero(T)
for k = 1:dR
_tidx!(tXiremove, k, dimsR)
_tidx!(tXjremove, k, dimsR)

# Find (i,j) index of X that we are currently on and add it to total
Xi, Xj = _idx(tXi, dims), _idx(tXj, dims)
val += X[Xi, Xj]
end
Y[i, j] = val
end
end
return Y
end
partial_trace(X::AbstractMatrix, sys::Integer, dims::Vector{<:Integer}) = partial_trace(X, [sys], dims)
export partial_trace
27 changes: 27 additions & 0 deletions test/partial_tra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
@testset "Partial trace" begin
d1, d2, d3 = 2, 3, 4
for R in [Float64, Double64, Float128, BigFloat]
for T in [R, Complex{R}]
a = randn(T, d1, d1)
b = randn(T, d2, d2)
c = randn(T, d3, d3)
ab = kron(a, b)
ac = kron(a, c)
bc = kron(b, c)
abc = kron(ab, c)
@test partial_trace(ab, [1, 2], [d1, d2])[1] tr(ab)
@test partial_trace(ab, 2, [d1, d2]) a * tr(b)
@test partial_trace(ab, 1, [d1, d2]) b * tr(a)
@test partial_trace(ab, Int64[], [d1, d2]) ab
@test partial_trace(abc, [1, 2, 3], [d1, d2, d3])[1] tr(abc)
@test partial_trace(abc, [2, 3], [d1, d2, d3]) a * tr(b) * tr(c)
@test partial_trace(abc, [1, 3], [d1, d2, d3]) b * tr(a) * tr(c)
@test partial_trace(abc, [1, 2], [d1, d2, d3]) c * tr(a) * tr(b)
@test partial_trace(abc, 3, [d1, d2, d3]) ab * tr(c)
@test partial_trace(abc, 2, [d1, d2, d3]) ac * tr(b)
@test partial_trace(abc, 1, [d1, d2, d3]) bc * tr(a)
@test partial_trace(abc, Int64[], [d1, d2, d3]) abc
end
end
end
#TODO add test with JuMP variables
5 changes: 3 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ using Quadmath
using Test

include("basic.jl")
include("entropy.jl")
include("measurements.jl")
include("nonlocal.jl")
include("norms.jl")
include("partial_tra.jl")
include("random.jl")
include("nonlocal.jl")
include("entropy.jl")

0 comments on commit 8890926

Please sign in to comment.