From 40e69d72dc37da92eb502cc5a226d56c4b1315d5 Mon Sep 17 00:00:00 2001
From: araujoms <maltusan@gmail.com>
Date: Tue, 18 Jun 2024 17:18:57 +0200
Subject: [PATCH] a bit more optimization

---
 src/partial_tra.jl | 27 ++++++++++++++++++---------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/src/partial_tra.jl b/src/partial_tra.jl
index 42725db..4cb986c 100644
--- a/src/partial_tra.jl
+++ b/src/partial_tra.jl
@@ -42,7 +42,7 @@ end
     partial_trace(X::AbstractMatrix, remove::Vector, dims::Vector)
 
 Takes the partial trace of matrix `X` with subsystem dimensions `dims` over the subsystems in `remove`.
-""" partial_trace
+""" partial_trace(X::AbstractMatrix, remove::Vector, dims::Vector)
 
 for (T, limit, wrapper) in
     [(:AbstractMatrix, :dY, :identity), (:(LA.Hermitian), :j, :(LA.Hermitian)), (:(LA.Symmetric), :j, :(LA.Symmetric))]
@@ -116,7 +116,7 @@ export partial_trace
     partial_transpose(X::AbstractMatrix, transp::Vector, dims::Vector)
 
 Takes the partial transpose of matrix `X` with subsystem dimensions `dims` on the subsystems in `transp`.
-""" partial_transpose
+""" partial_transpose(X::AbstractMatrix, transp::Vector, dims::Vector)
 
 for (T, wrapper) in
     [(:AbstractMatrix, :identity), (:(LA.Hermitian), :(LA.Hermitian)), (:(LA.Symmetric), :(LA.Symmetric))]
@@ -134,9 +134,8 @@ for (T, wrapper) in
                 end
             end
 
-            dY = prod(dims)                             # Dimension of the final output Y    
-
-            Y = similar(X, (dY, dY))                    # Final output Y
+            d = size(X, 1)                              # Dimension of the final output Y
+            Y = similar(X, (d, d))                      # Final output Y
 
             tXi = Vector{Int64}(undef, length(dims))    # Tensor indexing of X for row 
             tXj = Vector{Int64}(undef, length(dims))    # Tensor indexing of X for column
@@ -144,9 +143,9 @@ for (T, wrapper) in
             tYi = Vector{Int64}(undef, length(dims))    # Tensor indexing of Y for row 
             tYj = Vector{Int64}(undef, length(dims))    # Tensor indexing of Y for column
 
-            for j in 1:dY
+            @inbounds for j in 1:d
                 _tidx!(tYj, j, dims)
-                for i in 1:j
+                for i in 1:j-1
                     _tidx!(tYi, i, dims)
 
                     for k in keep
@@ -161,15 +160,25 @@ for (T, wrapper) in
 
                     Xi, Xj = _idx(tXi, dims), _idx(tXj, dims)
                     Y[i, j] = X[Xi, Xj]
-                    i != j && (Y[j, i] = X[Xj, Xi])
+                    Y[j, i] = X[Xj, Xi]
+                end
+                for k in keep
+                    tXj[k] = tYj[k]
+                end
+
+                for t in transp
+                    tXj[t] = tYj[t]
                 end
+
+                Xj = _idx(tXj, dims)
+                Y[j, j] = X[Xj, Xj]
             end
             return $wrapper(Y)
         end
     end
 end
 """
-    partial_transpose(X::AbstractMatrix, transp::Vector, dims::Vector)
+    partial_transpose(X::AbstractMatrix, transp::Integer, dims::Vector)
 
 Takes the partial transpose of matrix `X` with subsystem dimensions `dims` on the subsystem `transp`.
 """