Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update constraints.py #69

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 38 additions & 43 deletions pyci/rdm/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,24 @@ def find_closest_sdp(dm, constraint, alpha):
Value of the correct trace.

"""
#reshape 2DM (rank-4 tensor) into matrix (rank-2 tensor)
n = gamma.shape[1]
gamma_reshaped = np.reshape(gamma, (n**2, n**2))
#symmetrize if necessary
constrained = constraint(dm)
#constrained = constraint(dm)
L = constrained + constrained.conj().T
#find eigendecomposition
vals, vecs = np.linalg.eigh(L)
#calculate the shift, sigma0
sigma0 = calculate_shift(vals, alpha)

#calculate the closest semidefinite positive matrix with correct trace
L_closest = vals @ np.diag(vecs - sigma0) @ vecs.conj().T

L_closest = vecs @ np.diag(vals - sigma0) @ vecs.conj().T
#convert back into density matrix via reshaping
gamma_new = np.reshape(L_closest, (n, n, n, n))

# return the reconstructed density matrix
return constraint(L_closest).conj().T
return constraint(gamma_new).conj().T # ? check ?


def calculate_shift(eigenvalues, alpha):
Expand All @@ -81,8 +86,19 @@ def calculate_shift(eigenvalues, alpha):
def calc_P():
pass

def calc_Q():
pass
def calc_Q(gamma):
eye = np.eye(gamma.shape[0])
a_bar = np.einsum('abgb -> ag', gamma)
rho = 1/(N - 1) * a_bar
tr_gamma = np.einsum('aaaa', gamma)
term_1 = np.einsum('ag, bd -> abgd', eye, eye)
term_2 = np.einsum('ad, bg -> abgd', eye, eye)
term_3 = gamma
term_4 = np.einsum('ag, bd -> abgd', eye, rho)
term_5 = np.einsum('bg, ad -> abgd', eye, rho)
term_6 = np.einsum('ad, bg -> abgd', eye, rho)
term_7 = np.einsum('bd, ag -> abgd', eye, rho)
return (2*tr_gamma/(N * (N - 1)) * (term_1 - term_2)) + term_3 - term_4 + term_5 + term_6 - term_7

def calc_G(gamma, N, conjugate=False):
"""
Expand Down Expand Up @@ -114,13 +130,13 @@ def calc_G(gamma, N, conjugate=False):
rho = 1/(N - 1) * a_bar
if not conjugate:
return np.einsum('bd, ag -> abgd', eye, rho) - np.einsum('adgb -> abgd', gamma)
term_1 = 1/(N-1) *\
(np.einsum('bd, ag -> abgd', eye, a_bar) - np.einsum('ad, bg -> abgd', eye, a_bar) -\
np.einsum('bg, ad -> abgd', eye, a_bar) + np.einsum('ag, bd -> abgd', eye, a_bar)
)
term_2 = -np.einsum('adgb -> abgd', gamma) + np.einsum('bdga -> abgd', gamma) +\
np.einsum('agdb -> abgd', gamma) - np.einsum('bgda -> abgd', gamma)
return term_1 + term_2
else:
term_1 = 1/(N-1) *\
(np.einsum('bd, ag -> abgd', eye, a_bar) - np.einsum('ad, bg -> abgd', eye, a_bar) -\
np.einsum('bg, ad -> abgd', eye, a_bar) + np.einsum('ag, bd -> abgd', eye, a_bar))
term_2 = -np.einsum('adgb -> abgd', gamma) + np.einsum('bdga -> abgd', gamma) +\
np.einsum('agdb -> abgd', gamma) - np.einsum('bgda -> abgd', gamma)
return term_1 + term_2

def calc_T1(gamma, N, conjugate):
"""
Expand Down Expand Up @@ -203,7 +219,9 @@ def calc_T1(gamma, N, conjugate):
np.einsum('be, agdz -> abgdez', eye, gamma) - np.einsum('ae, bgdz -> abgdez', eye, gamma) + \
np.einsum('gd, abez -> abgdez', eye, gamma) - np.einsum('bd, agez -> abgdez', eye, gamma) + \
np.einsum('ad, bgez -> abgdez', eye, gamma)
return term_1 + term_2 + term_3 + term_4 + term_5
k1 = term_1 + term_2 + term_3 + term_4 + term_5
DM = 1/(N-2) * np.einsum('ablgdl -> abgd', k1) # ????? check please
return DM

else:
tr_gamma = np.einsum('aaaaaa', gamma)
Expand All @@ -216,7 +234,7 @@ def calc_T1(gamma, N, conjugate):
gamma_ad = np.einsum('agdg -> ad', gamma_abgd)
gamma_bd = np.einsum('abda -> bd', gamma_abgd)

term_2 = - 2 / (2*N - 2)*\
term_2 = - 1 / (2*N - 2)*\
(np.einsum('bd, ag -> abgd', eye, gamma_ag) - np.einsum('ad, bg -> abgd', eye, gamma_bg) -\
np.einsum('bg, ad -> abgd', eye, gamma_ad) + np.einsum('ag, bd -> abgd', eye, gamma_bd))

Expand Down Expand Up @@ -265,33 +283,11 @@ def calc_T2(gamma, N, conjugate=False):
term_4 = np.einsum('bd, geza -> abgdez', eye, gamma)
term_5 = np.einsum('ae, gdzb -> abgdez', eye, gamma)
term_6 = np.einsum('be, gdza -> abgdez', eye, gamma)
return term_1 + term_2 - term_3 + term_4 + term_5 - term_6
a_dtilda = np.einsum('lkalkg -> ag', gamma)
a_tilda = np.einsum('lablgd -> abgd', gamma)
a_bar = np.einsum('ablgdl -> abgd', gamma)

term_1 = np.einsum('bd, ag -> abgd', eye, a_dtilda)
term_2 = np.einsum('ad, bg -> abgd', eye, a_dtilda)
term_3 = np.einsum('bg, ad -> abgd', eye, a_dtilda)
term_4 = np.einsum('ag, bd -> abgd', eye, a_dtilda)
# term_5 = a_bar
term_6 = np.einsum('dabg -> abgd', a_tilda)
term_7 = np.einsum('dbag -> abgd', a_tilda)
term_8 = np.einsum('gabd -> abgd', a_tilda)
term_9 = np.einsum('gbad -> abgd', a_tilda)
return 0.5/(N-1) * (term_1 - term_2 - term_3 + term_4) +\
a_bar - (term_6 - term_7 - term_8 + term_9)
eye = np.eye(gamma.shape[0])
rho = 1/(N-1) * np.einsum('abgb -> ag', gamma)
if not conjugate:
term_1 = np.einsum('ad, be, gz -> abgdez', eye, eye, rho) -\
np.einsum('ae, bd, gz -> abgdez', eye, eye, rho)
term_2 = np.einsum('gz, abde -> abgdez', eye, gamma)
term_3 = np.einsum('ad, gezb -> abgdez', eye, gamma)
term_4 = np.einsum('bd, geza -> abgdez', eye, gamma)
term_5 = np.einsum('ae, gdzb -> abgdez', eye, gamma)
term_6 = np.einsum('be, gdza -> abgdez', eye, gamma)
return term_1 + term_2 - term_3 + term_4 + term_5 - term_6
k1 = term_1 + term_2 - term_3 + term_4 + term_5 - term_6
DM = 1/(N-2) * np.einsum('ablgdl -> abgd', k1) # check please
return DM

else:
a_dtilda = np.einsum('lkalkg -> ag', gamma)
a_tilda = np.einsum('lablgd -> abgd', gamma)
a_bar = np.einsum('ablgdl -> abgd', gamma)
Expand All @@ -308,7 +304,6 @@ def calc_T2(gamma, N, conjugate=False):
return 0.5/(N-1) * (term_1 - term_2 - term_3 + term_4) +\
a_bar - (term_6 - term_7 - term_8 + term_9)


def calc_T2_prime():
pass

Loading