diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/cal_weight.py b/SIAB/opt_orb_pytorch_dpsi/IO/cal_weight.py
new file mode 100644
index 00000000..7580d671
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/IO/cal_weight.py
@@ -0,0 +1,70 @@
+import IO.read_istate
+import torch
+import re
+import functools
+import operator
+
+def cal_weight(info_weight, flag_same_band, stru_file_list=None):
+ """ weight[ist][ib] """
+
+ if "bands_file" in info_weight.keys():
+ if "bands_range" in info_weight.keys():
+ raise IOError('"bands_file" and "bands_range" only once')
+
+ weight = [] # weight[ist][ib]
+ for weight_stru, file_name in zip(info_weight["stru"], info_weight["bands_file"]):
+ occ = IO.read_istate.read_istate(file_name)
+ weight += [occ_k * weight_stru for occ_k in occ]
+
+ elif "bands_range" in info_weight.keys():
+ k_weight = read_k_weight(stru_file_list) # k_weight[ist][ik]
+ nbands = read_nbands(stru_file_list) # nbands[ist]
+
+ st_weight = [] # st_weight[ist][ib]
+ for weight_stru, bands_range, nbands_ist in zip(info_weight["stru"], info_weight["bands_range"], nbands):
+ st_weight_tmp = torch.zeros((nbands_ist,))
+ st_weight_tmp[:bands_range] = weight_stru
+ st_weight.append( st_weight_tmp )
+
+ weight = [] # weight[ist][ib]
+ for ist,_ in enumerate(k_weight):
+ for ik,_ in enumerate(k_weight[ist]):
+ weight.append(st_weight[ist] * k_weight[ist][ik])
+
+ else:
+ raise IOError('"bands_file" and "bands_range" must once')
+
+
+ if not flag_same_band:
+ for ist,_ in enumerate(weight):
+ weight[ist] = torch.tensordot(weight[ist], weight[ist], dims=0)
+
+
+ normalization = functools.reduce(operator.add, map(torch.sum, weight), 0)
+ weight = list(map(lambda x:x/normalization, weight))
+
+ return weight
+
+
+def read_k_weight(stru_file_list):
+ """ weight[ist][ik] """
+ weight = [] # weight[ist][ik]
+ for file_name in stru_file_list:
+ weight_k = [] # weight_k[ik]
+ with open(file_name,"r") as file:
+ data = re.compile(r"(.+)", re.S).search(file.read()).group(1).split("\n")
+ for line in data:
+ line = line.strip()
+ if line:
+ weight_k.append(float(line.split()[-1]))
+ weight.append(weight_k)
+ return weight
+
+
+def read_nbands(stru_file_list):
+ """ nbands[ib] """
+ nbands = []
+ for file_name in stru_file_list:
+ with open(file_name,"r") as file:
+ nbands.append(int(re.compile(r"(\d+)\s+nbands").search(file.read()).group(1)))
+ return nbands
diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/change_info.py b/SIAB/opt_orb_pytorch_dpsi/IO/change_info.py
new file mode 100644
index 00000000..e9012042
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/IO/change_info.py
@@ -0,0 +1,98 @@
+import addict
+import util
+import itertools
+
+def change_info(info_old, weight_old):
+ info_stru = [None] * info_old.Nst
+ for ist in range(len(info_stru)):
+ info_stru[ist] = addict.Dict()
+ for ist,Na in enumerate(info_old.Na):
+ info_stru[ist].Na = Na
+ for ist,weight in enumerate(weight_old):
+ info_stru[ist].weight = weight
+ info_stru[ist].Nb = weight.shape[0]
+ for ib in range(weight.shape[0], 0, -1):
+ if weight[ib-1]>0:
+ info_stru[ist].Nb_true = ib
+ break
+
+ info_element = addict.Dict()
+ for it_index,it in enumerate(info_old.Nt_all):
+ info_element[it].index = it_index
+ for it,Nu in info_old.Nu.items():
+ info_element[it].Nu = Nu
+ info_element[it].Nl = len(Nu)
+ for it,Rcut in info_old.Rcut.items():
+ info_element[it].Rcut = Rcut
+ for it,dr in info_old.dr.items():
+ info_element[it].dr = dr
+ for it,Ecut in info_old.Ecut.items():
+ info_element[it].Ecut = Ecut
+ for it,Ne in info_old.Ne.items():
+ info_element[it].Ne = Ne
+
+ info_opt = addict.Dict()
+ info_opt.lr = info_old.lr
+ info_opt.cal_T = info_old.cal_T
+ info_opt.cal_smooth = info_old.cal_smooth
+
+ return info_stru, info_element, info_opt
+
+ """
+ info_stru =
+ [{'Na': {'C': 1},
+ 'Nb': 6,
+ 'Nb_true': 4,
+ 'weight': tensor([0.0333, 0.0111, 0.0111, 0.0111, 0.0000, 0.0000])},
+ {'Na': {'C': 1},
+ 'Nb': 6,
+ 'Nb_true': 2,
+ 'weight': tensor([0.0667, 0.0667, 0.0000, 0.0000, 0.0000, 0.0000])},
+ {'Na': {'C': 1, 'O': 2},
+ 'Nb': 10,
+ 'Nb_true': 8,
+ 'weight': tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.0000, 0.0000])}]
+
+ info_element =
+ {'C': {
+ 'Ecut': 200,
+ 'Ne': 19,
+ 'Nl': 3,
+ 'Nu': [2, 2, 1],
+ 'Rcut': 6,
+ 'dr': 0.01,
+ 'index': 0},
+ 'O': {
+ 'Ecut': 200,
+ 'Ne': 19,
+ 'Nl': 3,
+ 'Nu': [3, 2, 1],
+ 'Rcut': 6,
+ 'dr': 0.01,
+ 'index': 1}}
+
+ info_opt =
+ {'cal_T': False,
+ 'cal_smooth': False,
+ 'lr': 0.01}
+ """
+
+
+def get_info_max(info_stru, info_element):
+ info_max = [None] * len(info_stru)
+ for ist in range(len(info_stru)):
+ Nt = info_stru[ist].Na.keys()
+ info_max[ist] = addict.Dict()
+ info_max[ist].Nt = len(Nt)
+ info_max[ist].Na = max((info_stru[ist].Na[it] for it in Nt))
+ info_max[ist].Nl = max([info_element[it].Nl for it in Nt])
+ info_max[ist].Nm = max((util.Nm(info_element[it].Nl-1) for it in Nt))
+ info_max[ist].Nu = max(itertools.chain.from_iterable([info_element[it].Nu for it in Nt]))
+ info_max[ist].Ne = max((info_element[it].Ne for it in Nt))
+ info_max[ist].Nb = info_stru[ist].Nb
+ return info_max
+
+ """
+ [{'Na': 2, 'Nb': 6, 'Ne': 19, 'Nl': 3, 'Nm': 5, 'Nt': 1, 'Nu': 2},
+ {'Na': 2, 'Nb': 6, 'Ne': 19, 'Nl': 3, 'Nm': 5, 'Nt': 1, 'Nu': 2}]
+ """
diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/func_C.py b/SIAB/opt_orb_pytorch_dpsi/IO/func_C.py
new file mode 100644
index 00000000..95a76035
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/IO/func_C.py
@@ -0,0 +1,99 @@
+from util import *
+import torch
+import numpy as np
+
+def random_C_init(info_element):
+ """ C[it][il][ie,iu] """
+ C = dict()
+ for it in info_element.keys():
+ C[it] = ND_list(info_element[it].Nl)
+ for il in range(info_element[it].Nl):
+ C[it][il] = torch.tensor(np.random.uniform(-1,1, (info_element[it].Ne, info_element[it].Nu[il])), dtype=torch.float64, requires_grad=True)
+ return C
+
+
+
+def read_C_init(file_name,info_element):
+ """ C[it][il][ie,iu] """
+ C = random_C_init(info_element)
+
+ with open(file_name,"r") as file:
+
+ for line in file:
+ if line.strip() == "":
+ line=None
+ break
+ ignore_line(file,1)
+
+ C_read_index = set()
+ while True:
+ line = file.readline().strip()
+ if line.startswith("Type"):
+ it,il,iu = file.readline().split();
+ il = int(il)
+ iu = int(iu)-1
+ C_read_index.add((it,il,iu))
+ line = file.readline().split()
+ for ie in range(info_element[it].Ne):
+ if not line: line = file.readline().split()
+ C[it][il].data[ie,iu] = float(line.pop(0))
+ elif line.startswith(""):
+ break;
+ else:
+ raise IOError("unknown line in read_C_init "+file_name+"\n"+line)
+ return C, C_read_index
+
+
+
+def copy_C(C,info_element):
+ C_copy = dict()
+ for it in info_element.keys():
+ C_copy[it] = ND_list(info_element[it].Nl)
+ for il in range(info_element[it].Nl):
+ C_copy[it][il] = C[it][il].clone()
+ return C_copy
+
+
+
+def write_C(file_name,C,Spillage):
+ with open(file_name,"w") as file:
+ print("", file=file)
+ #print("\tTotal number of radial orbitals.", file=file)
+ nTotal = 0
+ for it,C_t in C.items():
+ for il,C_tl in enumerate(C_t):
+ for iu in range(C_tl.size()[1]):
+ nTotal += 1
+ #nTotal = sum(info["Nu"][it])
+ print("\t %s Total number of radial orbitals."%nTotal , file=file)
+ #print("\tTotal number of radial orbitals.", file=file)
+ for it,C_t in C.items():
+ for il,C_tl in enumerate(C_t):
+ for iu in range(C_tl.size()[1]):
+ print("\tType\tL\tZeta-Orbital", file=file)
+ print(f"\t {it} \t{il}\t {iu+1}", file=file)
+ for ie in range(C_tl.size()[0]):
+ print("\t", '%18.14f'%C_tl[ie,iu].item(), file=file)
+ print("", file=file)
+ print("", file=file)
+ print("Left spillage = %.10e"%Spillage.item(), file=file)
+ print("", file=file)
+
+
+#def init_C(info):
+# """ C[it][il][ie,iu] """
+# C = ND_list(max(info.Nt))
+# for it in range(len(C)):
+# C[it] = ND_list(info.Nl[it])
+# for il in range(info.Nl[it]):
+# C[it][il] = torch.autograd.Variable( torch.Tensor( info.Ne, info.Nu[it][il] ), requires_grad = True )
+#
+# with open("C_init.dat","r") as file:
+# line = []
+# for it in range(len(C)):
+# for il in range(info.Nl[it]):
+# for i_n in range(info.Nu[it][il]):
+# for ie in range(info.Ne[it]):
+# if not line: line=file.readline().split()
+# C[it][il].data[ie,i_n] = float(line.pop(0))
+# return C
diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/print_QSV.py b/SIAB/opt_orb_pytorch_dpsi/IO/print_QSV.py
new file mode 100644
index 00000000..1b0f88c5
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/IO/print_QSV.py
@@ -0,0 +1,29 @@
+def print_V(V,file_name):
+ """ V[ist][ib] """
+ with open(file_name,"w") as file:
+ for V_s in V:
+ for V_sb in V_s:
+ print(1-V_sb.item(),end="\t",file=file)
+ print(file=file)
+
+def print_S(S,file_name):
+ """ S[ist][it1,it2][il1][il2][ia1*im1*in1,ia2*im2*in2] """
+ with open(file_name,"w") as file:
+ for ist,S_s in enumerate(S):
+ for (it1,it2),S_tt in S_s.items():
+ for il1,S_ttl in enumerate(S_tt):
+ for il2,S_ttll in enumerate(S_ttl):
+ print(ist,it1,it2,il1,il2,file=file)
+ print(S_ttll.real.numpy(),file=file)
+ print(S_ttll.imag.numpy(),"\n",file=file)
+
+def print_Q(Q,file_name):
+ """ Q[ist][it][il][ib,ia*im*iu] """
+ with open(file_name,"w") as file:
+ for ist,Q_s in enumerate(Q):
+ for it,Q_st in Q_s.items():
+ for il,Q_stl in enumerate(Q_st):
+ print(ist,it,il,file=file)
+ print(Q_stl.real.numpy(),file=file)
+ print(Q_stl.imag.numpy(),"\n",file=file)
+
\ No newline at end of file
diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/print_orbital.py b/SIAB/opt_orb_pytorch_dpsi/IO/print_orbital.py
new file mode 100644
index 00000000..f3d0aade
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/IO/print_orbital.py
@@ -0,0 +1,70 @@
+periodtable = { 'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7,
+ 'O': 8, 'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13,
+ 'Si': 14, 'P': 15, 'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19,
+ 'Ca': 20, 'Sc': 21, 'Ti': 22, 'V': 23, 'Cr': 24, 'Mn': 25,
+ 'Fe': 26, 'Co': 27, 'Ni': 28, 'Cu': 29, 'Zn': 30, 'Ga': 31,
+ 'Ge': 32, 'As': 33, 'Se': 34, 'Br': 35, 'Kr': 36, 'Rb': 37,
+ 'Sr': 38, 'Y': 39, 'Zr': 40, 'Nb': 41, 'Mo': 42, 'Tc': 43,
+ 'Ru': 44, 'Rh': 45, 'Pd': 46, 'Ag': 47, 'Cd': 48, 'In': 49,
+ 'Sn': 50, 'Sb': 51, 'Te': 52, 'I': 53, 'Xe': 54, 'Cs': 55,
+ 'Ba': 56, #'La': 57, 'Ce': 58, 'Pr': 59, 'Nd': 60, 'Pm': 61,
+ ## 'Sm': 62, 'Eu': 63, 'Gd': 64, 'Tb': 65, 'Dy': 66, 'Ho': 67,
+ ## 'Er': 68, 'Tm': 69, 'Yb': 70,
+ ## 'Lu': 71,
+ 'Hf': 72, 'Ta': 73,
+ 'W': 74, 'Re': 75, 'Os': 76, 'Ir': 77, 'Pt': 78, 'Au': 79,
+ 'Hg': 80, 'Tl': 81, 'Pb': 82, 'Bi': 83,
+ ## 'Po': 84, #'At': 85,
+ ## 'Rn': 86, #'Fr': 87, 'Ra': 88, 'Ac': 89, 'Th': 90, 'Pa': 91,
+ ## 'U': 92, 'Np': 93, 'Pu': 94, 'Am': 95, 'Cm': 96, 'Bk': 97,
+ ## 'Cf': 98, 'Es': 99, 'Fm': 100, 'Md': 101, 'No': 102, 'Lr': 103,
+ ## 'Rf': 104, 'Db': 105, 'Sg': 106, 'Bh': 107, 'Hs': 108,
+ ## 'Mt': 109, 'Ds': 110, 'Rg': 111, 'Cn': 112, 'Uut': 113,
+ ## 'Fl': 114, 'Uup': 115, 'Lv': 116, 'Uus': 117, 'Uuo': 118
+ }
+
+def print_orbital(orb,info_element):
+ """ orb[it][il][iu][r] """
+ for it,orb_t in orb.items():
+ #with open("orb_{0}.dat".format(it),"w") as file:
+ with open("ORBITAL_{0}U.dat".format( periodtable[it] ),"w") as file:
+ print_orbital_head(file,info_element,it)
+ for il,orb_tl in enumerate(orb_t):
+ for iu,orb_tlu in enumerate(orb_tl):
+ print(""" Type L N""",file=file)
+ print(""" 0 {0} {1}""".format(il,iu),file=file)
+ for ir,orb_tlur in enumerate(orb_tlu):
+ print( '%.14e'%orb_tlur, end=" ",file=file)
+ if ir%4==3: print(file=file)
+ print(file=file)
+
+
+def plot_orbital(orb,Rcut,dr):
+ for it,orb_t in orb.items():
+ #with open("orb_{0}_plot.dat".format(it),"w") as file:
+ with open("ORBITAL_PLOTU.dat", "w") as file:
+ Nr = int(Rcut[it]/dr[it])+1
+ for ir in range(Nr):
+ print( '%10.6f'%(ir*dr[it]),end=" ",file=file)
+ for il,orb_tl in enumerate(orb_t):
+ for orb_tlu in orb_tl:
+ print( '%18.14f'%orb_tlu[ir],end=" ",file=file)
+ print(file=file)
+
+
+def print_orbital_head(file,info_element,it):
+ print( "---------------------------------------------------------------------------", file=file )
+ print( "Element {0}".format(it), file=file )
+ print( "Energy Cutoff(Ry) {0}".format(info_element[it].Ecut), file=file )
+ print( "Radius Cutoff(a.u.) {0}".format(info_element[it].Rcut), file=file )
+ print( "Lmax {0}".format(info_element[it].Nl-1), file=file )
+ l_name = ["S","P","D"]+list(map(chr,range(ord('F'),ord('Z')+1)))
+ for il,iu in enumerate(info_element[it].Nu):
+ print( "Number of {0}orbital--> {1}".format(l_name[il],iu), file=file )
+ print( "---------------------------------------------------------------------------", file=file )
+ print( "SUMMARY END", file=file )
+ print( file=file )
+ print( "Mesh {0}".format(int(info_element[it].Rcut/info_element[it].dr)+1), file=file )
+ print( "dr {0}".format(info_element[it].dr), file=file )
+
+
diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/read_QSV.py b/SIAB/opt_orb_pytorch_dpsi/IO/read_QSV.py
new file mode 100644
index 00000000..b959d9ce
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/IO/read_QSV.py
@@ -0,0 +1,160 @@
+import util
+import torch
+import itertools
+import numpy as np
+import re
+import copy
+
+def read_file_head(info,file_list):
+ """ QI[ist][it][il][ib*ia*im,ie] <\psi|jY> """
+ """ SI[ist][it1][it2][il1][il2][ie1,ia1,im1,ia2,im2,ie2] """
+ """ VI[ist][ib] <\psi|\psi> """
+ info_true = copy.deepcopy(info)
+ info_true.Nst = len(file_list)
+ info_true.Nt = util.ND_list(info_true.Nst,element="list()")
+ info_true.Na = util.ND_list(info_true.Nst,element="dict()")
+ info_true.Nb = util.ND_list(info_true.Nst)
+ info_true.Nk = util.ND_list(info_true.Nst)
+ info_true.Ne = dict()
+
+ for ist_true,file_name in enumerate(file_list):
+ print(file_name)
+ with open(file_name,"r") as file:
+
+ util.ignore_line(file,4)
+ Nt_tmp = int(file.readline().split()[0])
+ for it in range(Nt_tmp):
+ t_tmp = file.readline().split()[0]
+ assert t_tmp in info.Nt_all
+ info_true.Nt[ist_true].append( t_tmp )
+ info_true.Na[ist_true][t_tmp] = int(file.readline().split()[0])
+ util.ignore_line( file, info_true.Na[ist_true][t_tmp] )
+ util.ignore_line(file,6)
+ Nl_ist = int(file.readline().split()[0])+1
+ for it,Nl_C in info.Nl.items():
+ print(it,Nl_ist,Nl_C)
+ assert Nl_ist>=Nl_C
+ info_true.Nl[it] = Nl_ist
+ info_true.Nk[ist_true] = int(file.readline().split()[0])
+ info_true.Nb[ist_true] = int(file.readline().split()[0])
+ util.ignore_line(file,1)
+ #Ne_tmp = list(map(int,file.readline().split()[:Nt_tmp]))
+ #for it,Ne in zip(info_true.Nt[ist_true],Ne_tmp):
+ # assert info_true.Ne.setdefault(it,Ne)==Ne
+ Ne_tmp = int(file.readline().split()[0])
+ for it in info_true.Nt[ist_true]:
+ info_true.Ne[it] = Ne_tmp
+
+ info_all = copy.deepcopy(info)
+ info_all.Nst = sum(info_true.Nk,0)
+ repeat_Nk = lambda x: list( itertools.chain.from_iterable( map( lambda x:itertools.repeat(*x), zip(x,info_true.Nk) ) ) )
+ info_all.Nt = repeat_Nk(info_true.Nt)
+ info_all.Na = repeat_Nk(info_true.Na)
+ info_all.Nb = repeat_Nk(info_true.Nb)
+ info_all.Ne = info_true.Ne
+
+ return info_all
+
+
+def read_QSV(info_stru, info_element, file_list, V_info):
+ QI=[]; SI=[]; VI=[]
+ ist = 0
+ for ist_true,file_name in enumerate(file_list):
+ with open(file_name,"r") as file:
+ Nk = int(re.compile(r"(\d+)\s+nks").search(file.read()).group(1))
+ with open(file_name,"r") as file:
+ data = re.compile(r"(.+)", re.S).search(file.read())
+ data = map(float,data.group(1).split())
+ for ik in range(Nk):
+ print("read QI:",ist_true,ik)
+ qi = read_QI(info_stru[ist+ik], info_element, data)
+ QI.append( qi )
+ with open(file_name,"r") as file:
+ data = re.compile(r"(.+)", re.S).search(file.read())
+ data = map(float,data.group(1).split())
+ for ik in range(Nk):
+ print("read SI:",ist_true,ik)
+ si = read_SI(info_stru[ist+ik], info_element, data)
+ SI.append( si )
+ if V_info["init_from_file"]:
+ with open(file_name,"r") as file:
+ data = re.compile(r"(.+)", re.S).search(file.read())
+ data = map(float,data.group(1).split())
+ else:
+ data = ()
+ for ik in range(Nk):
+ print("read VI:",ist_true,ik)
+ vi = read_VI(info_stru[ist+ik], V_info, ist_true, data)
+ VI.append( vi )
+ ist += Nk
+ print()
+ return QI,SI,VI
+
+
+def read_QI(info_stru, info_element, data):
+ """ QI[it][il][ib*ia*im,ie] <\psi|jY> """
+ QI = dict()
+ for it in info_stru.Na.keys():
+ QI[it] = util.ND_list(info_element[it].Nl)
+ for il in range(info_element[it].Nl):
+ QI[it][il] = torch.zeros((info_stru.Nb, info_stru.Na[it], util.Nm(il), info_element[it].Ne), dtype=torch.complex128)
+ for ib in range(info_stru.Nb):
+ for it in info_stru.Na.keys():
+ for ia in range(info_stru.Na[it]):
+ for il in range(info_element[it].Nl):
+ for im in range(util.Nm(il)):
+ for ie in range(info_element[it].Ne):
+ QI[it][il][ib,ia,im,ie] = complex(next(data), next(data))
+ for it in info_stru.Na.keys():
+ for il in range(info_element[it].Nl):
+ QI[it][il] = QI[it][il][:info_stru.Nb_true,:,:,:].view(-1,info_element[it].Ne).conj()
+ return QI
+
+
+def read_SI(info_stru, info_element, data):
+ """ SI[it1,it2][il1][il2][ie1,ia1,im1,ia2,im2,ie2] """
+ SI = dict()
+ for it1,it2 in itertools.product( info_stru.Na.keys(), info_stru.Na.keys() ):
+ SI[it1,it2] = util.ND_list(info_element[it1].Nl, info_element[it2].Nl)
+ for il1,il2 in itertools.product( range(info_element[it1].Nl), range(info_element[it2].Nl) ):
+ SI[it1,it2][il1][il2] = torch.zeros((info_stru.Na[it1], util.Nm(il1), info_element[it1].Ne, info_stru.Na[it2], util.Nm(il2), info_element[it2].Ne), dtype=torch.complex128)
+ for it1 in info_stru.Na.keys():
+ for ia1 in range(info_stru.Na[it1]):
+ for il1 in range(info_element[it1].Nl):
+ for im1 in range(util.Nm(il1)):
+ for it2 in info_stru.Na.keys():
+ for ia2 in range(info_stru.Na[it2]):
+ for il2 in range(info_element[it2].Nl):
+ for im2 in range(util.Nm(il2)):
+ for ie1 in range(info_element[it1].Ne):
+ for ie2 in range(info_element[it2].Ne):
+ SI[it1,it2][il1][il2][ia1,im1,ie1,ia2,im2,ie2] = complex(next(data), next(data))
+# for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ):
+# for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ):
+# SI[it1,it2][il1][il2] = torch_complex.ComplexTensor(
+# torch.from_numpy(SI[it1,it2][il1][il2].real),
+# torch.from_numpy(SI[it1,it2][il1][il2].imag))
+ return SI
+
+
+
+def read_VI(info_stru,V_info,ist,data):
+ if V_info["same_band"]:
+ """ VI[ib] """
+ if V_info["init_from_file"]:
+ VI = np.empty(info_stru.Nb,dtype=np.float64)
+ for ib in range(info_stru.Nb):
+ VI.data[ib] = next(data)
+ VI = VI[:info_stru.Nb_true]
+ else:
+ VI = np.ones(info_stru.Nb_true, dtype=np.float64)
+ else:
+ """ VI[ib1,ib2] """
+ if V_info["init_from_file"]:
+ VI = np.empty((info_stru.Nb,info_stru.Nb),dtype=np.float64)
+ for ib1,ib2 in itertools.product( range(info_stru.Nb), range(info_stru.Nb) ):
+ VI[ib1,ib2] = next(data)
+ VI = VI[info_stru.Nb_true, info_stru.Nb_true]
+ else:
+ VI = np.eye(info_stru.Nb_true, info_stru.Nb_true, dtype=np.float64)
+ return torch.from_numpy(VI)
diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/read_istate.py b/SIAB/opt_orb_pytorch_dpsi/IO/read_istate.py
new file mode 100644
index 00000000..e1e7277e
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/IO/read_istate.py
@@ -0,0 +1,42 @@
+import re
+import torch
+import itertools
+
+# occ[ik][ib]
+def read_istate(file_name):
+ nspin0 = get_nspin0(file_name)
+ if nspin0==1: occ = [[]]
+ elif nspin0==2: occ = [[],[]]
+ with open(file_name,"r") as file:
+ content = file.read().split("BAND")
+ for content_k in content[1:]:
+ content_k = content_k.split("\n")
+ k = get_k(content_k[0])
+ for ispin in range(nspin0):
+ occ[ispin].append([])
+ for line in content_k[1:]:
+ line = line.strip()
+ if line:
+ line = line.split()
+ if nspin0==1:
+ occ[0][-1].append(float(line[2]))
+ elif nspin0==2:
+ occ[0][-1].append(float(line[2]))
+ occ[1][-1].append(float(line[4]))
+ for ispin in range(nspin0):
+ occ[ispin][-1] = torch.Tensor(occ[ispin][-1])
+ occ = list(itertools.chain(*occ))
+ return occ
+
+def get_k(line):
+ k = re.compile(r"Kpoint\s*=\s*(\d+)").search(line).group(1)
+ return int(k)
+
+def get_nspin0(file_name):
+ with open(file_name,"r") as file:
+ file.readline()
+ line = file.readline()
+ lens = len(line.split())
+ if lens == 3: return 1
+ elif lens == 5: return 2
+ else: raise
\ No newline at end of file
diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/read_json.py b/SIAB/opt_orb_pytorch_dpsi/IO/read_json.py
new file mode 100644
index 00000000..67705b4b
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/IO/read_json.py
@@ -0,0 +1,82 @@
+import json
+from util import Info
+
+def read_json(file_name):
+
+ with open(file_name,"r") as file:
+ input = file.read()
+ input = json.loads(input)
+
+ info = Info()
+ for info_attr,info_value in input["info"].items():
+ info.__dict__[info_attr] = info_value
+ info.Nl = { it:len(Nu) for it,Nu in info.Nu.items() }
+
+ return input["file_list"], info, input["weight"], input["C_init_info"], input["V_info"]
+
+ """ file_name
+ {
+ "file_list": {
+ "origin": [
+ "~/C_bulk/orb_matrix/test.0.dat",
+ "~/CO2/orb_matrix/test.0.dat"
+ ],
+ "linear": [
+ [
+ "~/C_bulk/orb_matrix/test.1.dat",
+ "~/CO2/orb_matrix/test.1.dat"
+ ],
+ [
+ "~/C_bulk/orb_matrix/test.2.dat",
+ "~/CO2/orb_matrix/test.2.dat"
+ ],
+ ]
+ },
+ "info": {
+ "Nt_all": [ "C", "O" ],
+ "Nu": { "C":[2,2,1], "O":[3,2,1] },
+ "Rcut": { "C":6, "O":6 },
+ "dr": { "C":0.01, "O":0.01 },
+ "Ecut": { "C":200, "O":200 },
+ "lr": 0.01,
+ "cal_T": false,
+ "cal_smooth": false
+ },
+ "weight":
+ {
+ "stru": [1, 2.3],
+ "bands_range": [10, 15], # "bands_range" and "bands_file" only once
+ "bands_file":
+ [
+ "~/C_bulk/OUT.ABACUS/istate.info",
+ "~/CO2/OUT.ABACUS/istate.info"
+ ]
+ },
+ "C_init_info": {
+ "init_from_file": false,
+ "C_init_file": "~/CO/ORBITAL_RESULTS.txt",
+ "opt_C_read": false
+ },
+ "V_info": {
+ "init_from_file": true,
+ "same_band": true
+ }
+ }
+ """
+
+ """ info
+ Nt_all ['C', 'O']
+ Nu {'C': [2, 2, 1], 'O': [3, 2, 1]}
+ Rcut {'C': 6, 'O': 6}
+ dr {'C': 0.01, 'O': 0.01}
+ Ecut {'C': 200, 'O': 200}
+ lr 0.01
+ cal_T False
+ cal_smooth False
+ Nl {'C': 3, 'O': 3}
+ Nst 3
+ Nt [['C'], ['C'], ['C', 'O']]
+ Na [{'C': 1}, {'C': 1}, {'C': 1, 'O': 2}]
+ Nb [6, 6, 10]
+ Ne {'C': 19, 'O': 19}
+ """
\ No newline at end of file
diff --git a/SIAB/opt_orb_pytorch_dpsi/main.py b/SIAB/opt_orb_pytorch_dpsi/main.py
new file mode 100755
index 00000000..a3e17e91
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/main.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import IO.read_QSV
+import IO.print_QSV
+import IO.func_C
+import IO.read_json
+import IO.print_orbital
+import opt_orbital
+import orbital
+import torch
+import numpy as np
+import time
+import torch_optimizer
+import IO.cal_weight
+import util
+import IO.change_info
+import pprint
+
+def main():
+ seed = int(1000*time.time())%(2**32)
+ np.random.seed(seed)
+ print("seed:",seed)
+ time_start = time.time()
+
+ file_list, info_true, weight_info, C_init_info, V_info = IO.read_json.read_json("INPUT")
+
+ weight = IO.cal_weight.cal_weight(weight_info, V_info["same_band"], file_list["origin"])
+
+ info_kst = IO.read_QSV.read_file_head(info_true,file_list["origin"])
+
+ info_stru, info_element, info_opt = IO.change_info.change_info(info_kst,weight)
+ info_max = IO.change_info.get_info_max(info_stru, info_element)
+
+ print("info_kst:", info_kst, sep="\n", end="\n"*2, flush=True)
+ print("info_stru:", pprint.pformat(info_stru), sep="\n", end="\n"*2, flush=True)
+ print("info_element:", pprint.pformat(info_element,width=40), sep="\n", end="\n"*2, flush=True)
+ print("info_opt:", pprint.pformat(info_opt,width=40), sep="\n", end="\n"*2, flush=True)
+ print("info_max:", pprint.pformat(info_max), sep="\n", end="\n"*2, flush=True)
+
+ QI,SI,VI_origin = IO.read_QSV.read_QSV(info_stru, info_element, file_list["origin"], V_info)
+ if "linear" in file_list.keys():
+ QI_linear, SI_linear, VI_linear = list(zip(*( IO.read_QSV.read_QSV(info_stru, info_element, file, V_info) for file in file_list["linear"] )))
+
+ if C_init_info["init_from_file"]:
+ C, C_read_index = IO.func_C.read_C_init( C_init_info["C_init_file"], info_element )
+ else:
+ C = IO.func_C.random_C_init(info_element)
+ E = orbital.set_E(info_element)
+ orbital.normalize(
+ orbital.generate_orbital(info_element,C,E),
+ {it:info_element[it].dr for it in info_element},
+ C, flag_norm_C=True)
+
+ opt_orb = opt_orbital.Opt_Orbital()
+
+ #opt = torch.optim.Adam(sum( ([c.real,c.imag] for c in sum(C,[])), []), lr=info_opt.lr, eps=1e-8)
+ #opt = torch.optim.Adam( sum(C.values(),[]), lr=info_opt.lr, eps=1e-20, weight_decay=info_opt.weight_decay)
+ #opt = radam.RAdam( sum(C.values(),[]), lr=info_opt.lr, eps=1e-20 )
+ opt = torch_optimizer.SWATS( sum(C.values(),[]), lr=info_opt.lr, eps=1e-20 )
+
+ flag_finish = 0
+ C_old = 0
+ with open("Spillage.dat","w") as S_file:
+
+ print( "\nSee \"Spillage.dat\" for detail status: " , flush=True )
+ if info_opt.cal_T:
+ print( '%5s'%"istep", "%20s"%"Spillage", "%20s"%"T.item()", "%20s"%"Loss", flush=True )
+ else:
+ print( '%5s'%"istep", "%20s"%"Spillage", flush=True )
+
+ loss_old = np.inf
+ for istep in range(10000):
+
+ Spillage = 0
+ for ist in range(len(info_stru)):
+
+ Q = opt_orb.change_index_Q(opt_orb.cal_Q(QI[ist],C,info_stru[ist],info_element),info_stru[ist])
+ S = opt_orb.change_index_S(opt_orb.cal_S(SI[ist],C,info_stru[ist],info_element),info_stru[ist],info_element)
+ coef = opt_orb.cal_coef(Q,S)
+ V = opt_orb.cal_V(coef,Q)
+ V_origin = opt_orb.cal_V_origin(V,V_info)
+
+ if "linear" in file_list.keys():
+ V_linear = [None] * len(file_list["linear"])
+ for i in range(len(file_list["linear"])):
+ Q_linear = opt_orb.change_index_Q(opt_orb.cal_Q(QI_linear[i][ist],C,info_stru[ist],info_element),info_stru[ist])
+ S_linear = opt_orb.change_index_S(opt_orb.cal_S(SI_linear[i][ist],C,info_stru[ist],info_element),info_stru[ist],info_element)
+ V_linear[i] = opt_orb.cal_V_linear(coef,Q_linear,S_linear,V,V_info)
+
+ def cal_Spillage(V_delta):
+ Spillage = (V_delta * weight[ist][:info_stru[ist].Nb_true]).sum()
+ return Spillage
+
+ def cal_delta(VI, V):
+ return ((VI[ist]-V)/util.update0(VI[ist])).abs() # abs or **2?
+
+ Spillage += 2*cal_Spillage(cal_delta(VI_origin,V_origin))
+ if "linear" in file_list.keys():
+ for i in range(len(file_list["linear"])):
+ Spillage += cal_Spillage(cal_delta(VI_linear[i],V_linear[i]))
+
+ if info_opt.cal_T:
+ T = opt_orb.cal_T(C,E)
+ if not "TSrate" in vars(): TSrate = torch.abs(0.002*Spillage/T).data[0]
+ Loss = Spillage + TSrate*T
+ else:
+ Loss = Spillage
+
+ if info_opt.cal_T:
+ print_content = [istep, Spillage.item(), T.item(), Loss.item()]
+ else:
+ print_content = [istep, Spillage.item()]
+ print(*print_content, sep="\t", file=S_file, flush=True)
+ if not istep%100:
+ print(*print_content, sep="\t", flush=True)
+
+ if Loss.item() < loss_old:
+ loss_old = Loss.item()
+ C_old = IO.func_C.copy_C(C,info_element)
+ flag_finish = 0
+ else:
+ flag_finish += 1
+ if flag_finish > 50:
+ break
+
+ opt.zero_grad()
+ Loss.backward()
+ if C_init_info["init_from_file"] and not C_init_info["opt_C_read"]:
+ for it,il,iu in C_read_index:
+ C[it][il].grad[:,iu] = 0
+ opt.step()
+ #orbital.normalize(
+ # orbital.generate_orbital(info_element,C,E),
+ # {it:info_element[it].dr for it in info_element},
+ # C, flag_norm_C=True)
+
+ orb = orbital.generate_orbital(info_element,C_old,E)
+ if info_opt.cal_smooth:
+ orbital.smooth_orbital(
+ orb,
+ {it:info_element[it].Rcut for it in info_element}, {it:info_element[it].dr for it in info_element},
+ 0.1)
+ orbital.orth(
+ orb,
+ {it:info_element[it].dr for it in info_element})
+ IO.print_orbital.print_orbital(orb,info_element)
+ IO.print_orbital.plot_orbital(
+ orb,
+ {it:info_element[it].Rcut for it in info_element},
+ {it:info_element[it].dr for it in info_element})
+
+ IO.func_C.write_C("ORBITAL_RESULTS.txt",C_old,Spillage)
+
+ print("Time (PyTorch): %s\n"%(time.time()-time_start), flush=True )
+
+
+if __name__=="__main__":
+ import sys
+ np.set_printoptions(threshold=sys.maxsize, linewidth=10000)
+ print( sys.version, flush=True )
+ main()
diff --git a/SIAB/opt_orb_pytorch_dpsi/opt_orbital.py b/SIAB/opt_orb_pytorch_dpsi/opt_orbital.py
new file mode 100644
index 00000000..342a634a
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/opt_orbital.py
@@ -0,0 +1,178 @@
+from util import ND_list
+import util
+import functools
+import itertools
+import torch
+
+class Opt_Orbital:
+
+ def cal_Q(self,QI,C,info_stru,info_element):
+ """
+ <\psi|\phi> = <\psi|jY> *
+ Q[it][il][ib,ia*im*iu]
+ = sum_{q} QI[it][il][ib*ia*im,ie] * C[it][il][ie,iu]
+ """
+ Q = dict()
+ for it in info_stru.Na.keys():
+ Q[it] = ND_list(info_element[it].Nl)
+
+ for it in info_stru.Na.keys():
+ for il in range(info_element[it].Nl):
+ Q[it][il] = torch.mm( QI[it][il], C[it][il].to(torch.complex128) ).view(info_stru.Nb_true,-1)
+ return Q
+
+
+
+ def cal_S(self,SI,C,info_stru,info_element):
+ """
+ <\phi|\phi> = <\phi|jY> * *
+ S[it1,it2][il1][il2][ia1*im1*iu1,ia2*im2*iu2]
+ = sum_{ie1 ie2} C^*[it1][il1][ie1,iu1] * SI[it1,it2][il1][il2][ia1,im1,ie1,ia2,im2,ie2] * C[it2][[il2][ie2,iu2]
+ """
+ S = dict()
+ for it1,it2 in itertools.product( info_stru.Na.keys(), info_stru.Na.keys() ):
+ S[it1,it2] = ND_list(info_element[it1].Nl, info_element[it2].Nl)
+
+ for it1,it2 in itertools.product( info_stru.Na.keys(), info_stru.Na.keys() ):
+ for il1,il2 in itertools.product( range(info_element[it1].Nl), range(info_element[it2].Nl) ):
+ # SI_C[ia1*im1*ie1*ia2*im2,iu2]
+ SI_C = torch.mm(
+ SI[it1,it2][il1][il2].view(-1,info_element[it2].Ne),
+ C[it2][il2].to(torch.complex128) )
+ # SI_C[ia1*im1,ie1,ia2*im2*iu2]
+ SI_C = SI_C.view( info_stru.Na[it1]*util.Nm(il1), info_element[it1].Ne, -1 )
+ # Ct[iu1,ie1]
+ Ct = C[it1][il1].t().to(torch.complex128)
+ C_mm = functools.partial(torch.mm,Ct)
+ # C_SI_C[ia1*im1][iu1,ia2*im2*iu2]
+ C_SI_C = list(map( C_mm, SI_C ))
+ # C_SI_C[ia1*im1*iu1,ia2*im2*iu2]
+ C_SI_C = torch.cat( C_SI_C, dim=0 )
+#??? C_SI_C = C_SI_C.view(info_stru.Na[it1]*util.Nm(il1)*info_element[it1].Nu[il1],-1)
+ S[it1,it2][il1][il2] = C_SI_C
+ return S
+
+
+
+ def change_index_S(self,S,info_stru,info_element): # S[it1,it2][il1][il2][ia1*im1*iu1,ia2*im2*iu2]
+ """
+ <\phi|\phi>
+ S_cat[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2]
+ """
+ # S_[it1][il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]
+ S_ = dict()
+ for it1 in info_stru.Na.keys():
+ # S_t[it2][il1*ia1*im1*iu1,il2*ia2*im2*iu2]
+ S_t = dict()
+ for it2 in info_stru.Na.keys():
+ # S_tt[il1][ia1*im1*iu1,il2*ia2*im2*iu2]
+ S_tt = ND_list(info_element[it1].Nl)
+ for il1 in range(info_element[it1].Nl):
+ S_tt[il1] = torch.cat( S[it1,it2][il1], dim=1 )
+ S_t[it2] = torch.cat( S_tt, dim=0 )
+ S_[it1] = torch.cat( list(S_t.values()), dim=1 )
+ # S_cat[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]
+ S_cat = torch.cat( list(S_.values()), dim=0 )
+ return S_cat
+
+
+
+ def change_index_Q(self,Q,info_stru): # Q[it][il][ib,ia*im*iu]
+ """
+ <\psi|\phi>
+ Q_cat[ib,it*il*ia*im*iu]
+ """
+ # Q_b[ib][0,it*il*ia*im*iu]
+ Q_b = ND_list(info_stru.Nb_true)
+ for ib in range(info_stru.Nb_true):
+ # Q_[it][il*ia*im*iu]
+ Q_ = dict()
+ for it in info_stru.Na.keys():
+ # Q_ts[il][ia*im*iu]
+ Q_ts = [ Q_tl[ib] for Q_tl in Q[it] ]
+ Q_[it] = torch.cat(Q_ts)
+ Q_b[ib] = torch.cat(list(Q_.values())).view(1,-1)
+ # Q_cat[ib,it*il*ia*im*iu]
+ Q_cat = torch.cat( Q_b, dim=0 )
+ return Q_cat
+
+
+
+ def cal_coef(self,Q,S):
+ # Q[ib,it*il*ia*im*iu]
+ # S[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]
+ """
+ <\psi|\phi> * <\phi|\phi>^{-1}
+ coef[ib,it*il*ia*im*iu]
+ = Q[ib,it1*il1*ia1*im1*iu1] * S{[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]}^{-1}
+ """
+ S_I = torch.inverse(S)
+ coef = torch.mm(Q, S_I)
+ return coef
+
+
+
+ def cal_V(self,coef,Q):
+ # coef[ib,it*il*ia*im*iu]
+ # Q[ib,it*il*ia*im*iu]
+ """
+ <\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi>
+ V[ib1,ib2]
+ = sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2}
+ Q[ib1,it1*il1*ia1*im1*iu1] * S{[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]}^{-1} * Q[ib2,it2*il2*ia2*im2*iu2]
+ """
+ V = torch.mm( coef, Q.t().conj() ).real
+ return V
+
+
+ def cal_V_origin(self,V,V_info):
+ # V[ib1,ib2]
+ """
+ <\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi>
+ V_origin[ib]
+ V_origin[ib1,ib2]
+ """
+ if V_info["same_band"]: V_origin = V.diag().sqrt()
+ else: V_origin = V.sqrt()
+ return V_origin
+
+
+ def cal_V_linear(self,coef,Q_linear,S_linear,V,V_info):
+ # coef[ib,it*il*ia*im*iu]
+ # Q_linear[ib,it*il*ia*im*iu]
+ # S_linear[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]
+ # V[ib1,ib2]
+ """
+ V_linear[ib]
+ V_linear[ib1,ib2]
+ """
+ V_linear_1 = coef.mm(S_linear).mm(coef.t().conj()).real
+ V_linear_2 = Q_linear.mm(coef.t().conj()).real
+ V_linear_3 = coef.mm(Q_linear.t().conj()).real
+ if V_info["same_band"]:
+ V_linear_1 = V_linear_1.diag()
+ V_linear_2 = V_linear_2.diag()
+ V_linear_3 = V_linear_3.diag()
+ if V_info["same_band"]: Z = V.diag().sqrt()
+ else: Z = V.sqrt()
+ Z = util.update0(Z)
+ V_linear = (-V_linear_1/Z + V_linear_2 + V_linear_3) / Z
+ return V_linear
+
+
+ def cal_T(self,C,E):
+ """ T = 0.5* sum_{it,il,iu} sum_{ie} ( E[it][il,ie] * C[it][il][ie,iu] )**2 """
+ T = torch.zeros(1)
+ num = 0
+ for it,C_t in C.items():
+ for il,C_tl in enumerate(C_t):
+ for iu in range(C_tl.size()[1]):
+ T_tlu = torch.zeros(1)
+ Z_tlu = 0
+ for ie in range(C_tl.size()[0]):
+ T_tlu = T_tlu + ( E[it][il,ie] * C_tl[ie,iu] )**2
+ Z_tlu = Z_tlu + E[it][il,ie].item()**2
+ T = T + T_tlu/Z_tlu
+ num += C_tl.size()[1]
+ T = 0.5 * T / num
+ return T
\ No newline at end of file
diff --git a/SIAB/opt_orb_pytorch_dpsi/orbital.py b/SIAB/opt_orb_pytorch_dpsi/orbital.py
new file mode 100644
index 00000000..7600e778
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/orbital.py
@@ -0,0 +1,80 @@
+from util import ND_list
+import numpy as np
+from scipy.special import spherical_jn
+from scipy.integrate import simps
+from scipy.optimize import fsolve
+import functools
+import torch
+
+def generate_orbital(info_element,C,E):
+ """ C[it][il][ie,iu] """
+ """ orb[it][il][iu][r] = \suml_{ie} C[it][il][ie,iu] * jn(il,ie*r) """
+ orb = dict()
+ for it in info_element:
+ Nr = int(info_element[it].Rcut/info_element[it].dr)+1
+ orb[it] = ND_list(info_element[it].Nl)
+ for il in range(info_element[it].Nl):
+ orb[it][il] = ND_list(info_element[it].Nu[il])
+ for iu in range(info_element[it].Nu[il]):
+ orb[it][il][iu] = np.zeros(Nr)
+ for ir in range(Nr):
+ r = ir * info_element[it].dr
+ for ie in range(info_element[it].Ne):
+ orb[it][il][iu][ir] += C[it][il][ie,iu].item() * spherical_jn(il,E[it][il,ie].item()*r)
+ return orb
+
+
+def smooth_orbital(orb,Rcut,dr,smearing_sigma):
+ for it,orb_t in orb.items():
+ for orb_tl in orb_t:
+ for orb_tlu in orb_tl:
+ for ir in range(orb_tlu.shape[0]):
+ assert orb_tlu.shape[0] == int(Rcut[it]/dr[it])+1
+ r = ir * dr[it]
+ orb_tlu[ir] *= 1-np.exp( -(r-Rcut[it])**2/(2*smearing_sigma**2) )
+
+
+
+def inner_product( orb1, orb2, dr ):
+ assert orb1.shape == orb2.shape
+ r = np.array(range(orb1.shape[0]))*dr
+ return simps( orb1 * orb2 * r * r, dx=dr )
+
+def normalize(orb,dr,C=None,flag_norm_orb=False,flag_norm_C=False):
+ """ C[it][il][ie,iu] """
+ """ orb[it][il][iu][r] = \suml_{ie} C[it][il][ie,iu] * jn(il,ie*r) """
+ for it,orb_t in orb.items():
+ for il,orb_tl in enumerate(orb_t):
+ for iu,orb_tlu in enumerate(orb_tl):
+ norm = np.sqrt(inner_product(orb_tlu,orb_tlu,dr[it]))
+ if flag_norm_orb: orb_tlu[:] = orb_tlu / norm
+ if flag_norm_C: C[it][il].data[:,iu] = C[it][il].data[:,iu] / norm
+
+def orth(orb,dr):
+ """ |n'> = 1/Z ( |n> - \sum_{i=0}^{n-1} |i> ) """
+ """ orb[it][il][iu,r] """
+ for it,orb_t in orb.items():
+ for il,orb_tl in enumerate(orb_t):
+ for iu1,orb_tlu1 in enumerate(orb_tl):
+ for iu2 in range(iu1):
+ orb_tlu1[:] -= orb_tl[iu2] * inner_product(orb_tlu1,orb_tl[iu2],dr[it])
+ orb_tlu1[:] = orb_tlu1 / np.sqrt(inner_product(orb_tlu1,orb_tlu1,dr[it]))
+
+def find_eigenvalue(Nl,Ne):
+ """ E[il,ie] """
+ E = np.zeros((Nl,Ne+Nl+1))
+ for ie in range(1,Ne+Nl+1):
+ E[0,ie] = ie*np.pi
+ for il in range(1,Nl):
+ jl = functools.partial(spherical_jn,il)
+ for ie in range(1,Ne+Nl+1-il):
+ E[il,ie] = fsolve( jl, (E[il-1,ie]+E[il-1,ie+1])/2 )
+ return E[:,1:Ne+1]
+
+def set_E(info_element):
+ """ E[it][il,ie] """
+ eigenvalue = { it:find_eigenvalue(info_element[it].Nl,info_element[it].Ne) for it in info_element }
+ E = dict()
+ for it in info_element:
+ E[it] = torch.from_numpy(( eigenvalue[it]/info_element[it].Rcut ).astype("float64"))
+ return E
\ No newline at end of file
diff --git a/SIAB/opt_orb_pytorch_dpsi/torch_complex_bak.py b/SIAB/opt_orb_pytorch_dpsi/torch_complex_bak.py
new file mode 100644
index 00000000..f08be745
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/torch_complex_bak.py
@@ -0,0 +1,84 @@
+import torch
+
+class ComplexTensor:
+ def __init__(self,real,imag):
+ self.real = real
+ self.imag = imag
+
+ def view(self,*args,**kwargs):
+ return ComplexTensor( self.real.view(*args,**kwargs), self.imag.view(*args,**kwargs) )
+ def t(self,*args,**kwargs):
+ return ComplexTensor( self.real.t(*args,**kwargs), self.imag.t(*args,**kwargs) )
+# def transpose(self,*args,**kwargs):
+# return ComplexTensor( self.real.transpose(*args,**kwargs), self.imag.transpose(*args,**kwargs) )
+ def __getitem__(self,*args,**kwargs):
+ return ComplexTensor( self.real.__getitem__(*args,**kwargs), self.imag.__getitem__(*args,**kwargs) )
+ def __str__(self):
+ return "<{0};{1}>".format(self.real, self.imag)
+ __repr__=__str__
+# def size(self,*args,**kwargs):
+# return ComplexTensor( self.real.size(*args,**kwargs), self.imag.size(*args,**kwargs) )
+
+ def conj(self):
+ return ComplexTensor( self.real, -self.imag )
+
+ def mm( self,x2, *args,**kwargs ):
+ return mm( self,x2, *args,**kwargs )
+
+
+def dot( x1,x2, *args,**kwargs ):
+ if isinstance(x1,ComplexTensor):
+ if isinstance(x2,ComplexTensor):
+ return ComplexTensor( torch.dot( x1.real,x2.real, *args,**kwargs ) - torch.dot( x1.imag,x2.imag, *args,**kwargs ), torch.dot( x1.real,x2.imag, *args,**kwargs ) + torch.dot( x1.imag,x2.real, *args,**kwargs ) )
+ else:
+ return ComplexTensor( torch.dot( x1.real,x2, *args,**kwargs ), torch.dot( x1.imag,x2, *args,**kwargs ) )
+ else:
+ if isinstance(x2,ComplexTensor):
+ return ComplexTensor( torch.dot( x1,x2.real, *args,**kwargs ), torch.dot( x1,x2.imag, *args,**kwargs ) )
+ else:
+ return torch.dot( x1,x2, *args,**kwargs )
+def mv( x1,x2, *args,**kwargs ):
+ if isinstance(x1,ComplexTensor):
+ if isinstance(x2,ComplexTensor):
+ return ComplexTensor( torch.mv( x1.real,x2.real, *args,**kwargs ) - torch.mv( x1.imag,x2.imag, *args,**kwargs ), torch.mv( x1.real,x2.imag, *args,**kwargs ) + torch.mv( x1.imag,x2.real, *args,**kwargs ) )
+ else:
+ return ComplexTensor( torch.mv( x1.real,x2, *args,**kwargs ), torch.mv( x1.imag,x2, *args,**kwargs ) )
+ else:
+ if isinstance(x2,ComplexTensor):
+ return ComplexTensor( torch.mv( x1,x2.real, *args,**kwargs ), torch.mv( x1,x2.imag, *args,**kwargs ) )
+ else:
+ return torch.mv( x1,x2, *args,**kwargs )
+def mm( x1,x2, *args,**kwargs ):
+ if isinstance(x1,ComplexTensor):
+ if isinstance(x2,ComplexTensor):
+ return ComplexTensor( torch.mm( x1.real,x2.real, *args,**kwargs ) - torch.mm( x1.imag,x2.imag, *args,**kwargs ), torch.mm( x1.real,x2.imag, *args,**kwargs ) + torch.mm( x1.imag,x2.real, *args,**kwargs ) )
+ else:
+ return ComplexTensor( torch.mm( x1.real,x2, *args,**kwargs ), torch.mm( x1.imag,x2, *args,**kwargs ) )
+ else:
+ if isinstance(x2,ComplexTensor):
+ return ComplexTensor( torch.mm( x1,x2.real, *args,**kwargs ), torch.mm( x1,x2.imag, *args,**kwargs ) )
+ else:
+ return torch.mm( x1,x2, *args,**kwargs )
+
+
+def cat( xs, *args,**kwargs ):
+ if isinstance(xs[0],ComplexTensor):
+ xs_real = []; xs_imag = []
+ for x in xs:
+ xs_real.append(x.real)
+ xs_imag.append(x.imag)
+ return ComplexTensor( torch.cat(xs_real,*args,**kwargs), torch.cat(xs_imag,*args,**kwargs) )
+ else:
+ return torch.cat(xs,*args,**kwargs)
+
+
+
+def inverse(M):
+ if isinstance(M,ComplexTensor):
+ A=M.real
+ B=M.imag
+ tmp_AB = torch.mm(A.inverse(),B) # A^{-1} B
+ tmp_X = (A+torch.mm(B,tmp_AB)).inverse() # ( A + B A^{-1} B )^{-1}
+ return ComplexTensor( tmp_X, -torch.mm(tmp_AB,tmp_X) )
+ else:
+ return M.inverse()
\ No newline at end of file
diff --git a/SIAB/opt_orb_pytorch_dpsi/util.py b/SIAB/opt_orb_pytorch_dpsi/util.py
new file mode 100644
index 00000000..9b64a8e4
--- /dev/null
+++ b/SIAB/opt_orb_pytorch_dpsi/util.py
@@ -0,0 +1,42 @@
+def ND_list(*sizes,element=None):
+ size_1,*size_other = sizes
+ l = [element] * size_1
+ if size_other:
+ for i in range(len(l)):
+ l[i] = ND_list(*size_other,element=element)
+ else:
+ if element in ["dict()","list()"]:
+ for i in range(size_1):
+ l[i] = eval(element)
+ return l
+
+
+def ignore_line(file,N):
+ for _ in range(N):
+ file.readline()
+
+
+class Info:
+ def Nm(self,il): return 2*il+1
+ def __str__(self):
+ return "\n".join([name+"\t"+str(value) for name,value in self.__dict__.items()])
+ __repr__=__str__
+
+def change_to_cuda(s):
+ if isinstance(s,list):
+ return [change_to_cuda(x) for x in s]
+ elif isinstance(s,dict):
+ return {i:change_to_cuda(x) for i,x in s.items()}
+ elif isinstance(s,torch.Tensor):
+ return s.cuda()
+ elif isinstance(s,torch_complex.ComplexTensor):
+ return torch_complex.ComplexTensor( change_to_cuda(s.real), change_to_cuda(s.imag) )
+ else:
+ print(s)
+ raise TypeError("change_to_cuda")
+
+def update0(t):
+ return t.masked_fill(mask=(t==0), value=1E-10)
+
+def Nm(il):
+ return 2*il+1
\ No newline at end of file