diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/cal_weight.py b/SIAB/opt_orb_pytorch_dpsi/IO/cal_weight.py new file mode 100644 index 00000000..7580d671 --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/IO/cal_weight.py @@ -0,0 +1,70 @@ +import IO.read_istate +import torch +import re +import functools +import operator + +def cal_weight(info_weight, flag_same_band, stru_file_list=None): + """ weight[ist][ib] """ + + if "bands_file" in info_weight.keys(): + if "bands_range" in info_weight.keys(): + raise IOError('"bands_file" and "bands_range" only once') + + weight = [] # weight[ist][ib] + for weight_stru, file_name in zip(info_weight["stru"], info_weight["bands_file"]): + occ = IO.read_istate.read_istate(file_name) + weight += [occ_k * weight_stru for occ_k in occ] + + elif "bands_range" in info_weight.keys(): + k_weight = read_k_weight(stru_file_list) # k_weight[ist][ik] + nbands = read_nbands(stru_file_list) # nbands[ist] + + st_weight = [] # st_weight[ist][ib] + for weight_stru, bands_range, nbands_ist in zip(info_weight["stru"], info_weight["bands_range"], nbands): + st_weight_tmp = torch.zeros((nbands_ist,)) + st_weight_tmp[:bands_range] = weight_stru + st_weight.append( st_weight_tmp ) + + weight = [] # weight[ist][ib] + for ist,_ in enumerate(k_weight): + for ik,_ in enumerate(k_weight[ist]): + weight.append(st_weight[ist] * k_weight[ist][ik]) + + else: + raise IOError('"bands_file" and "bands_range" must once') + + + if not flag_same_band: + for ist,_ in enumerate(weight): + weight[ist] = torch.tensordot(weight[ist], weight[ist], dims=0) + + + normalization = functools.reduce(operator.add, map(torch.sum, weight), 0) + weight = list(map(lambda x:x/normalization, weight)) + + return weight + + +def read_k_weight(stru_file_list): + """ weight[ist][ik] """ + weight = [] # weight[ist][ik] + for file_name in stru_file_list: + weight_k = [] # weight_k[ik] + with open(file_name,"r") as file: + data = re.compile(r"(.+)", re.S).search(file.read()).group(1).split("\n") + for line in data: + line = line.strip() + if line: + weight_k.append(float(line.split()[-1])) + weight.append(weight_k) + return weight + + +def read_nbands(stru_file_list): + """ nbands[ib] """ + nbands = [] + for file_name in stru_file_list: + with open(file_name,"r") as file: + nbands.append(int(re.compile(r"(\d+)\s+nbands").search(file.read()).group(1))) + return nbands diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/change_info.py b/SIAB/opt_orb_pytorch_dpsi/IO/change_info.py new file mode 100644 index 00000000..e9012042 --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/IO/change_info.py @@ -0,0 +1,98 @@ +import addict +import util +import itertools + +def change_info(info_old, weight_old): + info_stru = [None] * info_old.Nst + for ist in range(len(info_stru)): + info_stru[ist] = addict.Dict() + for ist,Na in enumerate(info_old.Na): + info_stru[ist].Na = Na + for ist,weight in enumerate(weight_old): + info_stru[ist].weight = weight + info_stru[ist].Nb = weight.shape[0] + for ib in range(weight.shape[0], 0, -1): + if weight[ib-1]>0: + info_stru[ist].Nb_true = ib + break + + info_element = addict.Dict() + for it_index,it in enumerate(info_old.Nt_all): + info_element[it].index = it_index + for it,Nu in info_old.Nu.items(): + info_element[it].Nu = Nu + info_element[it].Nl = len(Nu) + for it,Rcut in info_old.Rcut.items(): + info_element[it].Rcut = Rcut + for it,dr in info_old.dr.items(): + info_element[it].dr = dr + for it,Ecut in info_old.Ecut.items(): + info_element[it].Ecut = Ecut + for it,Ne in info_old.Ne.items(): + info_element[it].Ne = Ne + + info_opt = addict.Dict() + info_opt.lr = info_old.lr + info_opt.cal_T = info_old.cal_T + info_opt.cal_smooth = info_old.cal_smooth + + return info_stru, info_element, info_opt + + """ + info_stru = + [{'Na': {'C': 1}, + 'Nb': 6, + 'Nb_true': 4, + 'weight': tensor([0.0333, 0.0111, 0.0111, 0.0111, 0.0000, 0.0000])}, + {'Na': {'C': 1}, + 'Nb': 6, + 'Nb_true': 2, + 'weight': tensor([0.0667, 0.0667, 0.0000, 0.0000, 0.0000, 0.0000])}, + {'Na': {'C': 1, 'O': 2}, + 'Nb': 10, + 'Nb_true': 8, + 'weight': tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.0000, 0.0000])}] + + info_element = + {'C': { + 'Ecut': 200, + 'Ne': 19, + 'Nl': 3, + 'Nu': [2, 2, 1], + 'Rcut': 6, + 'dr': 0.01, + 'index': 0}, + 'O': { + 'Ecut': 200, + 'Ne': 19, + 'Nl': 3, + 'Nu': [3, 2, 1], + 'Rcut': 6, + 'dr': 0.01, + 'index': 1}} + + info_opt = + {'cal_T': False, + 'cal_smooth': False, + 'lr': 0.01} + """ + + +def get_info_max(info_stru, info_element): + info_max = [None] * len(info_stru) + for ist in range(len(info_stru)): + Nt = info_stru[ist].Na.keys() + info_max[ist] = addict.Dict() + info_max[ist].Nt = len(Nt) + info_max[ist].Na = max((info_stru[ist].Na[it] for it in Nt)) + info_max[ist].Nl = max([info_element[it].Nl for it in Nt]) + info_max[ist].Nm = max((util.Nm(info_element[it].Nl-1) for it in Nt)) + info_max[ist].Nu = max(itertools.chain.from_iterable([info_element[it].Nu for it in Nt])) + info_max[ist].Ne = max((info_element[it].Ne for it in Nt)) + info_max[ist].Nb = info_stru[ist].Nb + return info_max + + """ + [{'Na': 2, 'Nb': 6, 'Ne': 19, 'Nl': 3, 'Nm': 5, 'Nt': 1, 'Nu': 2}, + {'Na': 2, 'Nb': 6, 'Ne': 19, 'Nl': 3, 'Nm': 5, 'Nt': 1, 'Nu': 2}] + """ diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/func_C.py b/SIAB/opt_orb_pytorch_dpsi/IO/func_C.py new file mode 100644 index 00000000..95a76035 --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/IO/func_C.py @@ -0,0 +1,99 @@ +from util import * +import torch +import numpy as np + +def random_C_init(info_element): + """ C[it][il][ie,iu] """ + C = dict() + for it in info_element.keys(): + C[it] = ND_list(info_element[it].Nl) + for il in range(info_element[it].Nl): + C[it][il] = torch.tensor(np.random.uniform(-1,1, (info_element[it].Ne, info_element[it].Nu[il])), dtype=torch.float64, requires_grad=True) + return C + + + +def read_C_init(file_name,info_element): + """ C[it][il][ie,iu] """ + C = random_C_init(info_element) + + with open(file_name,"r") as file: + + for line in file: + if line.strip() == "": + line=None + break + ignore_line(file,1) + + C_read_index = set() + while True: + line = file.readline().strip() + if line.startswith("Type"): + it,il,iu = file.readline().split(); + il = int(il) + iu = int(iu)-1 + C_read_index.add((it,il,iu)) + line = file.readline().split() + for ie in range(info_element[it].Ne): + if not line: line = file.readline().split() + C[it][il].data[ie,iu] = float(line.pop(0)) + elif line.startswith(""): + break; + else: + raise IOError("unknown line in read_C_init "+file_name+"\n"+line) + return C, C_read_index + + + +def copy_C(C,info_element): + C_copy = dict() + for it in info_element.keys(): + C_copy[it] = ND_list(info_element[it].Nl) + for il in range(info_element[it].Nl): + C_copy[it][il] = C[it][il].clone() + return C_copy + + + +def write_C(file_name,C,Spillage): + with open(file_name,"w") as file: + print("", file=file) + #print("\tTotal number of radial orbitals.", file=file) + nTotal = 0 + for it,C_t in C.items(): + for il,C_tl in enumerate(C_t): + for iu in range(C_tl.size()[1]): + nTotal += 1 + #nTotal = sum(info["Nu"][it]) + print("\t %s Total number of radial orbitals."%nTotal , file=file) + #print("\tTotal number of radial orbitals.", file=file) + for it,C_t in C.items(): + for il,C_tl in enumerate(C_t): + for iu in range(C_tl.size()[1]): + print("\tType\tL\tZeta-Orbital", file=file) + print(f"\t {it} \t{il}\t {iu+1}", file=file) + for ie in range(C_tl.size()[0]): + print("\t", '%18.14f'%C_tl[ie,iu].item(), file=file) + print("", file=file) + print("", file=file) + print("Left spillage = %.10e"%Spillage.item(), file=file) + print("", file=file) + + +#def init_C(info): +# """ C[it][il][ie,iu] """ +# C = ND_list(max(info.Nt)) +# for it in range(len(C)): +# C[it] = ND_list(info.Nl[it]) +# for il in range(info.Nl[it]): +# C[it][il] = torch.autograd.Variable( torch.Tensor( info.Ne, info.Nu[it][il] ), requires_grad = True ) +# +# with open("C_init.dat","r") as file: +# line = [] +# for it in range(len(C)): +# for il in range(info.Nl[it]): +# for i_n in range(info.Nu[it][il]): +# for ie in range(info.Ne[it]): +# if not line: line=file.readline().split() +# C[it][il].data[ie,i_n] = float(line.pop(0)) +# return C diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/print_QSV.py b/SIAB/opt_orb_pytorch_dpsi/IO/print_QSV.py new file mode 100644 index 00000000..1b0f88c5 --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/IO/print_QSV.py @@ -0,0 +1,29 @@ +def print_V(V,file_name): + """ V[ist][ib] """ + with open(file_name,"w") as file: + for V_s in V: + for V_sb in V_s: + print(1-V_sb.item(),end="\t",file=file) + print(file=file) + +def print_S(S,file_name): + """ S[ist][it1,it2][il1][il2][ia1*im1*in1,ia2*im2*in2] """ + with open(file_name,"w") as file: + for ist,S_s in enumerate(S): + for (it1,it2),S_tt in S_s.items(): + for il1,S_ttl in enumerate(S_tt): + for il2,S_ttll in enumerate(S_ttl): + print(ist,it1,it2,il1,il2,file=file) + print(S_ttll.real.numpy(),file=file) + print(S_ttll.imag.numpy(),"\n",file=file) + +def print_Q(Q,file_name): + """ Q[ist][it][il][ib,ia*im*iu] """ + with open(file_name,"w") as file: + for ist,Q_s in enumerate(Q): + for it,Q_st in Q_s.items(): + for il,Q_stl in enumerate(Q_st): + print(ist,it,il,file=file) + print(Q_stl.real.numpy(),file=file) + print(Q_stl.imag.numpy(),"\n",file=file) + \ No newline at end of file diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/print_orbital.py b/SIAB/opt_orb_pytorch_dpsi/IO/print_orbital.py new file mode 100644 index 00000000..f3d0aade --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/IO/print_orbital.py @@ -0,0 +1,70 @@ +periodtable = { 'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, + 'O': 8, 'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13, + 'Si': 14, 'P': 15, 'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, + 'Ca': 20, 'Sc': 21, 'Ti': 22, 'V': 23, 'Cr': 24, 'Mn': 25, + 'Fe': 26, 'Co': 27, 'Ni': 28, 'Cu': 29, 'Zn': 30, 'Ga': 31, + 'Ge': 32, 'As': 33, 'Se': 34, 'Br': 35, 'Kr': 36, 'Rb': 37, + 'Sr': 38, 'Y': 39, 'Zr': 40, 'Nb': 41, 'Mo': 42, 'Tc': 43, + 'Ru': 44, 'Rh': 45, 'Pd': 46, 'Ag': 47, 'Cd': 48, 'In': 49, + 'Sn': 50, 'Sb': 51, 'Te': 52, 'I': 53, 'Xe': 54, 'Cs': 55, + 'Ba': 56, #'La': 57, 'Ce': 58, 'Pr': 59, 'Nd': 60, 'Pm': 61, + ## 'Sm': 62, 'Eu': 63, 'Gd': 64, 'Tb': 65, 'Dy': 66, 'Ho': 67, + ## 'Er': 68, 'Tm': 69, 'Yb': 70, + ## 'Lu': 71, + 'Hf': 72, 'Ta': 73, + 'W': 74, 'Re': 75, 'Os': 76, 'Ir': 77, 'Pt': 78, 'Au': 79, + 'Hg': 80, 'Tl': 81, 'Pb': 82, 'Bi': 83, + ## 'Po': 84, #'At': 85, + ## 'Rn': 86, #'Fr': 87, 'Ra': 88, 'Ac': 89, 'Th': 90, 'Pa': 91, + ## 'U': 92, 'Np': 93, 'Pu': 94, 'Am': 95, 'Cm': 96, 'Bk': 97, + ## 'Cf': 98, 'Es': 99, 'Fm': 100, 'Md': 101, 'No': 102, 'Lr': 103, + ## 'Rf': 104, 'Db': 105, 'Sg': 106, 'Bh': 107, 'Hs': 108, + ## 'Mt': 109, 'Ds': 110, 'Rg': 111, 'Cn': 112, 'Uut': 113, + ## 'Fl': 114, 'Uup': 115, 'Lv': 116, 'Uus': 117, 'Uuo': 118 + } + +def print_orbital(orb,info_element): + """ orb[it][il][iu][r] """ + for it,orb_t in orb.items(): + #with open("orb_{0}.dat".format(it),"w") as file: + with open("ORBITAL_{0}U.dat".format( periodtable[it] ),"w") as file: + print_orbital_head(file,info_element,it) + for il,orb_tl in enumerate(orb_t): + for iu,orb_tlu in enumerate(orb_tl): + print(""" Type L N""",file=file) + print(""" 0 {0} {1}""".format(il,iu),file=file) + for ir,orb_tlur in enumerate(orb_tlu): + print( '%.14e'%orb_tlur, end=" ",file=file) + if ir%4==3: print(file=file) + print(file=file) + + +def plot_orbital(orb,Rcut,dr): + for it,orb_t in orb.items(): + #with open("orb_{0}_plot.dat".format(it),"w") as file: + with open("ORBITAL_PLOTU.dat", "w") as file: + Nr = int(Rcut[it]/dr[it])+1 + for ir in range(Nr): + print( '%10.6f'%(ir*dr[it]),end=" ",file=file) + for il,orb_tl in enumerate(orb_t): + for orb_tlu in orb_tl: + print( '%18.14f'%orb_tlu[ir],end=" ",file=file) + print(file=file) + + +def print_orbital_head(file,info_element,it): + print( "---------------------------------------------------------------------------", file=file ) + print( "Element {0}".format(it), file=file ) + print( "Energy Cutoff(Ry) {0}".format(info_element[it].Ecut), file=file ) + print( "Radius Cutoff(a.u.) {0}".format(info_element[it].Rcut), file=file ) + print( "Lmax {0}".format(info_element[it].Nl-1), file=file ) + l_name = ["S","P","D"]+list(map(chr,range(ord('F'),ord('Z')+1))) + for il,iu in enumerate(info_element[it].Nu): + print( "Number of {0}orbital--> {1}".format(l_name[il],iu), file=file ) + print( "---------------------------------------------------------------------------", file=file ) + print( "SUMMARY END", file=file ) + print( file=file ) + print( "Mesh {0}".format(int(info_element[it].Rcut/info_element[it].dr)+1), file=file ) + print( "dr {0}".format(info_element[it].dr), file=file ) + + diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/read_QSV.py b/SIAB/opt_orb_pytorch_dpsi/IO/read_QSV.py new file mode 100644 index 00000000..b959d9ce --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/IO/read_QSV.py @@ -0,0 +1,160 @@ +import util +import torch +import itertools +import numpy as np +import re +import copy + +def read_file_head(info,file_list): + """ QI[ist][it][il][ib*ia*im,ie] <\psi|jY> """ + """ SI[ist][it1][it2][il1][il2][ie1,ia1,im1,ia2,im2,ie2] """ + """ VI[ist][ib] <\psi|\psi> """ + info_true = copy.deepcopy(info) + info_true.Nst = len(file_list) + info_true.Nt = util.ND_list(info_true.Nst,element="list()") + info_true.Na = util.ND_list(info_true.Nst,element="dict()") + info_true.Nb = util.ND_list(info_true.Nst) + info_true.Nk = util.ND_list(info_true.Nst) + info_true.Ne = dict() + + for ist_true,file_name in enumerate(file_list): + print(file_name) + with open(file_name,"r") as file: + + util.ignore_line(file,4) + Nt_tmp = int(file.readline().split()[0]) + for it in range(Nt_tmp): + t_tmp = file.readline().split()[0] + assert t_tmp in info.Nt_all + info_true.Nt[ist_true].append( t_tmp ) + info_true.Na[ist_true][t_tmp] = int(file.readline().split()[0]) + util.ignore_line( file, info_true.Na[ist_true][t_tmp] ) + util.ignore_line(file,6) + Nl_ist = int(file.readline().split()[0])+1 + for it,Nl_C in info.Nl.items(): + print(it,Nl_ist,Nl_C) + assert Nl_ist>=Nl_C + info_true.Nl[it] = Nl_ist + info_true.Nk[ist_true] = int(file.readline().split()[0]) + info_true.Nb[ist_true] = int(file.readline().split()[0]) + util.ignore_line(file,1) + #Ne_tmp = list(map(int,file.readline().split()[:Nt_tmp])) + #for it,Ne in zip(info_true.Nt[ist_true],Ne_tmp): + # assert info_true.Ne.setdefault(it,Ne)==Ne + Ne_tmp = int(file.readline().split()[0]) + for it in info_true.Nt[ist_true]: + info_true.Ne[it] = Ne_tmp + + info_all = copy.deepcopy(info) + info_all.Nst = sum(info_true.Nk,0) + repeat_Nk = lambda x: list( itertools.chain.from_iterable( map( lambda x:itertools.repeat(*x), zip(x,info_true.Nk) ) ) ) + info_all.Nt = repeat_Nk(info_true.Nt) + info_all.Na = repeat_Nk(info_true.Na) + info_all.Nb = repeat_Nk(info_true.Nb) + info_all.Ne = info_true.Ne + + return info_all + + +def read_QSV(info_stru, info_element, file_list, V_info): + QI=[]; SI=[]; VI=[] + ist = 0 + for ist_true,file_name in enumerate(file_list): + with open(file_name,"r") as file: + Nk = int(re.compile(r"(\d+)\s+nks").search(file.read()).group(1)) + with open(file_name,"r") as file: + data = re.compile(r"(.+)", re.S).search(file.read()) + data = map(float,data.group(1).split()) + for ik in range(Nk): + print("read QI:",ist_true,ik) + qi = read_QI(info_stru[ist+ik], info_element, data) + QI.append( qi ) + with open(file_name,"r") as file: + data = re.compile(r"(.+)", re.S).search(file.read()) + data = map(float,data.group(1).split()) + for ik in range(Nk): + print("read SI:",ist_true,ik) + si = read_SI(info_stru[ist+ik], info_element, data) + SI.append( si ) + if V_info["init_from_file"]: + with open(file_name,"r") as file: + data = re.compile(r"(.+)", re.S).search(file.read()) + data = map(float,data.group(1).split()) + else: + data = () + for ik in range(Nk): + print("read VI:",ist_true,ik) + vi = read_VI(info_stru[ist+ik], V_info, ist_true, data) + VI.append( vi ) + ist += Nk + print() + return QI,SI,VI + + +def read_QI(info_stru, info_element, data): + """ QI[it][il][ib*ia*im,ie] <\psi|jY> """ + QI = dict() + for it in info_stru.Na.keys(): + QI[it] = util.ND_list(info_element[it].Nl) + for il in range(info_element[it].Nl): + QI[it][il] = torch.zeros((info_stru.Nb, info_stru.Na[it], util.Nm(il), info_element[it].Ne), dtype=torch.complex128) + for ib in range(info_stru.Nb): + for it in info_stru.Na.keys(): + for ia in range(info_stru.Na[it]): + for il in range(info_element[it].Nl): + for im in range(util.Nm(il)): + for ie in range(info_element[it].Ne): + QI[it][il][ib,ia,im,ie] = complex(next(data), next(data)) + for it in info_stru.Na.keys(): + for il in range(info_element[it].Nl): + QI[it][il] = QI[it][il][:info_stru.Nb_true,:,:,:].view(-1,info_element[it].Ne).conj() + return QI + + +def read_SI(info_stru, info_element, data): + """ SI[it1,it2][il1][il2][ie1,ia1,im1,ia2,im2,ie2] """ + SI = dict() + for it1,it2 in itertools.product( info_stru.Na.keys(), info_stru.Na.keys() ): + SI[it1,it2] = util.ND_list(info_element[it1].Nl, info_element[it2].Nl) + for il1,il2 in itertools.product( range(info_element[it1].Nl), range(info_element[it2].Nl) ): + SI[it1,it2][il1][il2] = torch.zeros((info_stru.Na[it1], util.Nm(il1), info_element[it1].Ne, info_stru.Na[it2], util.Nm(il2), info_element[it2].Ne), dtype=torch.complex128) + for it1 in info_stru.Na.keys(): + for ia1 in range(info_stru.Na[it1]): + for il1 in range(info_element[it1].Nl): + for im1 in range(util.Nm(il1)): + for it2 in info_stru.Na.keys(): + for ia2 in range(info_stru.Na[it2]): + for il2 in range(info_element[it2].Nl): + for im2 in range(util.Nm(il2)): + for ie1 in range(info_element[it1].Ne): + for ie2 in range(info_element[it2].Ne): + SI[it1,it2][il1][il2][ia1,im1,ie1,ia2,im2,ie2] = complex(next(data), next(data)) +# for it1,it2 in itertools.product( info.Nt[ist], info.Nt[ist] ): +# for il1,il2 in itertools.product( range(info.Nl[it1]), range(info.Nl[it2]) ): +# SI[it1,it2][il1][il2] = torch_complex.ComplexTensor( +# torch.from_numpy(SI[it1,it2][il1][il2].real), +# torch.from_numpy(SI[it1,it2][il1][il2].imag)) + return SI + + + +def read_VI(info_stru,V_info,ist,data): + if V_info["same_band"]: + """ VI[ib] """ + if V_info["init_from_file"]: + VI = np.empty(info_stru.Nb,dtype=np.float64) + for ib in range(info_stru.Nb): + VI.data[ib] = next(data) + VI = VI[:info_stru.Nb_true] + else: + VI = np.ones(info_stru.Nb_true, dtype=np.float64) + else: + """ VI[ib1,ib2] """ + if V_info["init_from_file"]: + VI = np.empty((info_stru.Nb,info_stru.Nb),dtype=np.float64) + for ib1,ib2 in itertools.product( range(info_stru.Nb), range(info_stru.Nb) ): + VI[ib1,ib2] = next(data) + VI = VI[info_stru.Nb_true, info_stru.Nb_true] + else: + VI = np.eye(info_stru.Nb_true, info_stru.Nb_true, dtype=np.float64) + return torch.from_numpy(VI) diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/read_istate.py b/SIAB/opt_orb_pytorch_dpsi/IO/read_istate.py new file mode 100644 index 00000000..e1e7277e --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/IO/read_istate.py @@ -0,0 +1,42 @@ +import re +import torch +import itertools + +# occ[ik][ib] +def read_istate(file_name): + nspin0 = get_nspin0(file_name) + if nspin0==1: occ = [[]] + elif nspin0==2: occ = [[],[]] + with open(file_name,"r") as file: + content = file.read().split("BAND") + for content_k in content[1:]: + content_k = content_k.split("\n") + k = get_k(content_k[0]) + for ispin in range(nspin0): + occ[ispin].append([]) + for line in content_k[1:]: + line = line.strip() + if line: + line = line.split() + if nspin0==1: + occ[0][-1].append(float(line[2])) + elif nspin0==2: + occ[0][-1].append(float(line[2])) + occ[1][-1].append(float(line[4])) + for ispin in range(nspin0): + occ[ispin][-1] = torch.Tensor(occ[ispin][-1]) + occ = list(itertools.chain(*occ)) + return occ + +def get_k(line): + k = re.compile(r"Kpoint\s*=\s*(\d+)").search(line).group(1) + return int(k) + +def get_nspin0(file_name): + with open(file_name,"r") as file: + file.readline() + line = file.readline() + lens = len(line.split()) + if lens == 3: return 1 + elif lens == 5: return 2 + else: raise \ No newline at end of file diff --git a/SIAB/opt_orb_pytorch_dpsi/IO/read_json.py b/SIAB/opt_orb_pytorch_dpsi/IO/read_json.py new file mode 100644 index 00000000..67705b4b --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/IO/read_json.py @@ -0,0 +1,82 @@ +import json +from util import Info + +def read_json(file_name): + + with open(file_name,"r") as file: + input = file.read() + input = json.loads(input) + + info = Info() + for info_attr,info_value in input["info"].items(): + info.__dict__[info_attr] = info_value + info.Nl = { it:len(Nu) for it,Nu in info.Nu.items() } + + return input["file_list"], info, input["weight"], input["C_init_info"], input["V_info"] + + """ file_name + { + "file_list": { + "origin": [ + "~/C_bulk/orb_matrix/test.0.dat", + "~/CO2/orb_matrix/test.0.dat" + ], + "linear": [ + [ + "~/C_bulk/orb_matrix/test.1.dat", + "~/CO2/orb_matrix/test.1.dat" + ], + [ + "~/C_bulk/orb_matrix/test.2.dat", + "~/CO2/orb_matrix/test.2.dat" + ], + ] + }, + "info": { + "Nt_all": [ "C", "O" ], + "Nu": { "C":[2,2,1], "O":[3,2,1] }, + "Rcut": { "C":6, "O":6 }, + "dr": { "C":0.01, "O":0.01 }, + "Ecut": { "C":200, "O":200 }, + "lr": 0.01, + "cal_T": false, + "cal_smooth": false + }, + "weight": + { + "stru": [1, 2.3], + "bands_range": [10, 15], # "bands_range" and "bands_file" only once + "bands_file": + [ + "~/C_bulk/OUT.ABACUS/istate.info", + "~/CO2/OUT.ABACUS/istate.info" + ] + }, + "C_init_info": { + "init_from_file": false, + "C_init_file": "~/CO/ORBITAL_RESULTS.txt", + "opt_C_read": false + }, + "V_info": { + "init_from_file": true, + "same_band": true + } + } + """ + + """ info + Nt_all ['C', 'O'] + Nu {'C': [2, 2, 1], 'O': [3, 2, 1]} + Rcut {'C': 6, 'O': 6} + dr {'C': 0.01, 'O': 0.01} + Ecut {'C': 200, 'O': 200} + lr 0.01 + cal_T False + cal_smooth False + Nl {'C': 3, 'O': 3} + Nst 3 + Nt [['C'], ['C'], ['C', 'O']] + Na [{'C': 1}, {'C': 1}, {'C': 1, 'O': 2}] + Nb [6, 6, 10] + Ne {'C': 19, 'O': 19} + """ \ No newline at end of file diff --git a/SIAB/opt_orb_pytorch_dpsi/main.py b/SIAB/opt_orb_pytorch_dpsi/main.py new file mode 100755 index 00000000..a3e17e91 --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/main.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import IO.read_QSV +import IO.print_QSV +import IO.func_C +import IO.read_json +import IO.print_orbital +import opt_orbital +import orbital +import torch +import numpy as np +import time +import torch_optimizer +import IO.cal_weight +import util +import IO.change_info +import pprint + +def main(): + seed = int(1000*time.time())%(2**32) + np.random.seed(seed) + print("seed:",seed) + time_start = time.time() + + file_list, info_true, weight_info, C_init_info, V_info = IO.read_json.read_json("INPUT") + + weight = IO.cal_weight.cal_weight(weight_info, V_info["same_band"], file_list["origin"]) + + info_kst = IO.read_QSV.read_file_head(info_true,file_list["origin"]) + + info_stru, info_element, info_opt = IO.change_info.change_info(info_kst,weight) + info_max = IO.change_info.get_info_max(info_stru, info_element) + + print("info_kst:", info_kst, sep="\n", end="\n"*2, flush=True) + print("info_stru:", pprint.pformat(info_stru), sep="\n", end="\n"*2, flush=True) + print("info_element:", pprint.pformat(info_element,width=40), sep="\n", end="\n"*2, flush=True) + print("info_opt:", pprint.pformat(info_opt,width=40), sep="\n", end="\n"*2, flush=True) + print("info_max:", pprint.pformat(info_max), sep="\n", end="\n"*2, flush=True) + + QI,SI,VI_origin = IO.read_QSV.read_QSV(info_stru, info_element, file_list["origin"], V_info) + if "linear" in file_list.keys(): + QI_linear, SI_linear, VI_linear = list(zip(*( IO.read_QSV.read_QSV(info_stru, info_element, file, V_info) for file in file_list["linear"] ))) + + if C_init_info["init_from_file"]: + C, C_read_index = IO.func_C.read_C_init( C_init_info["C_init_file"], info_element ) + else: + C = IO.func_C.random_C_init(info_element) + E = orbital.set_E(info_element) + orbital.normalize( + orbital.generate_orbital(info_element,C,E), + {it:info_element[it].dr for it in info_element}, + C, flag_norm_C=True) + + opt_orb = opt_orbital.Opt_Orbital() + + #opt = torch.optim.Adam(sum( ([c.real,c.imag] for c in sum(C,[])), []), lr=info_opt.lr, eps=1e-8) + #opt = torch.optim.Adam( sum(C.values(),[]), lr=info_opt.lr, eps=1e-20, weight_decay=info_opt.weight_decay) + #opt = radam.RAdam( sum(C.values(),[]), lr=info_opt.lr, eps=1e-20 ) + opt = torch_optimizer.SWATS( sum(C.values(),[]), lr=info_opt.lr, eps=1e-20 ) + + flag_finish = 0 + C_old = 0 + with open("Spillage.dat","w") as S_file: + + print( "\nSee \"Spillage.dat\" for detail status: " , flush=True ) + if info_opt.cal_T: + print( '%5s'%"istep", "%20s"%"Spillage", "%20s"%"T.item()", "%20s"%"Loss", flush=True ) + else: + print( '%5s'%"istep", "%20s"%"Spillage", flush=True ) + + loss_old = np.inf + for istep in range(10000): + + Spillage = 0 + for ist in range(len(info_stru)): + + Q = opt_orb.change_index_Q(opt_orb.cal_Q(QI[ist],C,info_stru[ist],info_element),info_stru[ist]) + S = opt_orb.change_index_S(opt_orb.cal_S(SI[ist],C,info_stru[ist],info_element),info_stru[ist],info_element) + coef = opt_orb.cal_coef(Q,S) + V = opt_orb.cal_V(coef,Q) + V_origin = opt_orb.cal_V_origin(V,V_info) + + if "linear" in file_list.keys(): + V_linear = [None] * len(file_list["linear"]) + for i in range(len(file_list["linear"])): + Q_linear = opt_orb.change_index_Q(opt_orb.cal_Q(QI_linear[i][ist],C,info_stru[ist],info_element),info_stru[ist]) + S_linear = opt_orb.change_index_S(opt_orb.cal_S(SI_linear[i][ist],C,info_stru[ist],info_element),info_stru[ist],info_element) + V_linear[i] = opt_orb.cal_V_linear(coef,Q_linear,S_linear,V,V_info) + + def cal_Spillage(V_delta): + Spillage = (V_delta * weight[ist][:info_stru[ist].Nb_true]).sum() + return Spillage + + def cal_delta(VI, V): + return ((VI[ist]-V)/util.update0(VI[ist])).abs() # abs or **2? + + Spillage += 2*cal_Spillage(cal_delta(VI_origin,V_origin)) + if "linear" in file_list.keys(): + for i in range(len(file_list["linear"])): + Spillage += cal_Spillage(cal_delta(VI_linear[i],V_linear[i])) + + if info_opt.cal_T: + T = opt_orb.cal_T(C,E) + if not "TSrate" in vars(): TSrate = torch.abs(0.002*Spillage/T).data[0] + Loss = Spillage + TSrate*T + else: + Loss = Spillage + + if info_opt.cal_T: + print_content = [istep, Spillage.item(), T.item(), Loss.item()] + else: + print_content = [istep, Spillage.item()] + print(*print_content, sep="\t", file=S_file, flush=True) + if not istep%100: + print(*print_content, sep="\t", flush=True) + + if Loss.item() < loss_old: + loss_old = Loss.item() + C_old = IO.func_C.copy_C(C,info_element) + flag_finish = 0 + else: + flag_finish += 1 + if flag_finish > 50: + break + + opt.zero_grad() + Loss.backward() + if C_init_info["init_from_file"] and not C_init_info["opt_C_read"]: + for it,il,iu in C_read_index: + C[it][il].grad[:,iu] = 0 + opt.step() + #orbital.normalize( + # orbital.generate_orbital(info_element,C,E), + # {it:info_element[it].dr for it in info_element}, + # C, flag_norm_C=True) + + orb = orbital.generate_orbital(info_element,C_old,E) + if info_opt.cal_smooth: + orbital.smooth_orbital( + orb, + {it:info_element[it].Rcut for it in info_element}, {it:info_element[it].dr for it in info_element}, + 0.1) + orbital.orth( + orb, + {it:info_element[it].dr for it in info_element}) + IO.print_orbital.print_orbital(orb,info_element) + IO.print_orbital.plot_orbital( + orb, + {it:info_element[it].Rcut for it in info_element}, + {it:info_element[it].dr for it in info_element}) + + IO.func_C.write_C("ORBITAL_RESULTS.txt",C_old,Spillage) + + print("Time (PyTorch): %s\n"%(time.time()-time_start), flush=True ) + + +if __name__=="__main__": + import sys + np.set_printoptions(threshold=sys.maxsize, linewidth=10000) + print( sys.version, flush=True ) + main() diff --git a/SIAB/opt_orb_pytorch_dpsi/opt_orbital.py b/SIAB/opt_orb_pytorch_dpsi/opt_orbital.py new file mode 100644 index 00000000..342a634a --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/opt_orbital.py @@ -0,0 +1,178 @@ +from util import ND_list +import util +import functools +import itertools +import torch + +class Opt_Orbital: + + def cal_Q(self,QI,C,info_stru,info_element): + """ + <\psi|\phi> = <\psi|jY> * + Q[it][il][ib,ia*im*iu] + = sum_{q} QI[it][il][ib*ia*im,ie] * C[it][il][ie,iu] + """ + Q = dict() + for it in info_stru.Na.keys(): + Q[it] = ND_list(info_element[it].Nl) + + for it in info_stru.Na.keys(): + for il in range(info_element[it].Nl): + Q[it][il] = torch.mm( QI[it][il], C[it][il].to(torch.complex128) ).view(info_stru.Nb_true,-1) + return Q + + + + def cal_S(self,SI,C,info_stru,info_element): + """ + <\phi|\phi> = <\phi|jY> * * + S[it1,it2][il1][il2][ia1*im1*iu1,ia2*im2*iu2] + = sum_{ie1 ie2} C^*[it1][il1][ie1,iu1] * SI[it1,it2][il1][il2][ia1,im1,ie1,ia2,im2,ie2] * C[it2][[il2][ie2,iu2] + """ + S = dict() + for it1,it2 in itertools.product( info_stru.Na.keys(), info_stru.Na.keys() ): + S[it1,it2] = ND_list(info_element[it1].Nl, info_element[it2].Nl) + + for it1,it2 in itertools.product( info_stru.Na.keys(), info_stru.Na.keys() ): + for il1,il2 in itertools.product( range(info_element[it1].Nl), range(info_element[it2].Nl) ): + # SI_C[ia1*im1*ie1*ia2*im2,iu2] + SI_C = torch.mm( + SI[it1,it2][il1][il2].view(-1,info_element[it2].Ne), + C[it2][il2].to(torch.complex128) ) + # SI_C[ia1*im1,ie1,ia2*im2*iu2] + SI_C = SI_C.view( info_stru.Na[it1]*util.Nm(il1), info_element[it1].Ne, -1 ) + # Ct[iu1,ie1] + Ct = C[it1][il1].t().to(torch.complex128) + C_mm = functools.partial(torch.mm,Ct) + # C_SI_C[ia1*im1][iu1,ia2*im2*iu2] + C_SI_C = list(map( C_mm, SI_C )) + # C_SI_C[ia1*im1*iu1,ia2*im2*iu2] + C_SI_C = torch.cat( C_SI_C, dim=0 ) +#??? C_SI_C = C_SI_C.view(info_stru.Na[it1]*util.Nm(il1)*info_element[it1].Nu[il1],-1) + S[it1,it2][il1][il2] = C_SI_C + return S + + + + def change_index_S(self,S,info_stru,info_element): # S[it1,it2][il1][il2][ia1*im1*iu1,ia2*im2*iu2] + """ + <\phi|\phi> + S_cat[it1*il1*iat*im1*iu1,iat2*il2*ia2*im2*iu2] + """ + # S_[it1][il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2] + S_ = dict() + for it1 in info_stru.Na.keys(): + # S_t[it2][il1*ia1*im1*iu1,il2*ia2*im2*iu2] + S_t = dict() + for it2 in info_stru.Na.keys(): + # S_tt[il1][ia1*im1*iu1,il2*ia2*im2*iu2] + S_tt = ND_list(info_element[it1].Nl) + for il1 in range(info_element[it1].Nl): + S_tt[il1] = torch.cat( S[it1,it2][il1], dim=1 ) + S_t[it2] = torch.cat( S_tt, dim=0 ) + S_[it1] = torch.cat( list(S_t.values()), dim=1 ) + # S_cat[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2] + S_cat = torch.cat( list(S_.values()), dim=0 ) + return S_cat + + + + def change_index_Q(self,Q,info_stru): # Q[it][il][ib,ia*im*iu] + """ + <\psi|\phi> + Q_cat[ib,it*il*ia*im*iu] + """ + # Q_b[ib][0,it*il*ia*im*iu] + Q_b = ND_list(info_stru.Nb_true) + for ib in range(info_stru.Nb_true): + # Q_[it][il*ia*im*iu] + Q_ = dict() + for it in info_stru.Na.keys(): + # Q_ts[il][ia*im*iu] + Q_ts = [ Q_tl[ib] for Q_tl in Q[it] ] + Q_[it] = torch.cat(Q_ts) + Q_b[ib] = torch.cat(list(Q_.values())).view(1,-1) + # Q_cat[ib,it*il*ia*im*iu] + Q_cat = torch.cat( Q_b, dim=0 ) + return Q_cat + + + + def cal_coef(self,Q,S): + # Q[ib,it*il*ia*im*iu] + # S[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2] + """ + <\psi|\phi> * <\phi|\phi>^{-1} + coef[ib,it*il*ia*im*iu] + = Q[ib,it1*il1*ia1*im1*iu1] * S{[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]}^{-1} + """ + S_I = torch.inverse(S) + coef = torch.mm(Q, S_I) + return coef + + + + def cal_V(self,coef,Q): + # coef[ib,it*il*ia*im*iu] + # Q[ib,it*il*ia*im*iu] + """ + <\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi> + V[ib1,ib2] + = sum_{it1,ia1,il1,im1,iu1} sum_{it2,ia2,il2,im2,iu2} + Q[ib1,it1*il1*ia1*im1*iu1] * S{[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2]}^{-1} * Q[ib2,it2*il2*ia2*im2*iu2] + """ + V = torch.mm( coef, Q.t().conj() ).real + return V + + + def cal_V_origin(self,V,V_info): + # V[ib1,ib2] + """ + <\psi|\psi> = <\psi|\phi> * <\phi|\phi>^{-1} * <\phi|psi> + V_origin[ib] + V_origin[ib1,ib2] + """ + if V_info["same_band"]: V_origin = V.diag().sqrt() + else: V_origin = V.sqrt() + return V_origin + + + def cal_V_linear(self,coef,Q_linear,S_linear,V,V_info): + # coef[ib,it*il*ia*im*iu] + # Q_linear[ib,it*il*ia*im*iu] + # S_linear[it1*il1*ia1*im1*iu1,it2*il2*ia2*im2*iu2] + # V[ib1,ib2] + """ + V_linear[ib] + V_linear[ib1,ib2] + """ + V_linear_1 = coef.mm(S_linear).mm(coef.t().conj()).real + V_linear_2 = Q_linear.mm(coef.t().conj()).real + V_linear_3 = coef.mm(Q_linear.t().conj()).real + if V_info["same_band"]: + V_linear_1 = V_linear_1.diag() + V_linear_2 = V_linear_2.diag() + V_linear_3 = V_linear_3.diag() + if V_info["same_band"]: Z = V.diag().sqrt() + else: Z = V.sqrt() + Z = util.update0(Z) + V_linear = (-V_linear_1/Z + V_linear_2 + V_linear_3) / Z + return V_linear + + + def cal_T(self,C,E): + """ T = 0.5* sum_{it,il,iu} sum_{ie} ( E[it][il,ie] * C[it][il][ie,iu] )**2 """ + T = torch.zeros(1) + num = 0 + for it,C_t in C.items(): + for il,C_tl in enumerate(C_t): + for iu in range(C_tl.size()[1]): + T_tlu = torch.zeros(1) + Z_tlu = 0 + for ie in range(C_tl.size()[0]): + T_tlu = T_tlu + ( E[it][il,ie] * C_tl[ie,iu] )**2 + Z_tlu = Z_tlu + E[it][il,ie].item()**2 + T = T + T_tlu/Z_tlu + num += C_tl.size()[1] + T = 0.5 * T / num + return T \ No newline at end of file diff --git a/SIAB/opt_orb_pytorch_dpsi/orbital.py b/SIAB/opt_orb_pytorch_dpsi/orbital.py new file mode 100644 index 00000000..7600e778 --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/orbital.py @@ -0,0 +1,80 @@ +from util import ND_list +import numpy as np +from scipy.special import spherical_jn +from scipy.integrate import simps +from scipy.optimize import fsolve +import functools +import torch + +def generate_orbital(info_element,C,E): + """ C[it][il][ie,iu] """ + """ orb[it][il][iu][r] = \suml_{ie} C[it][il][ie,iu] * jn(il,ie*r) """ + orb = dict() + for it in info_element: + Nr = int(info_element[it].Rcut/info_element[it].dr)+1 + orb[it] = ND_list(info_element[it].Nl) + for il in range(info_element[it].Nl): + orb[it][il] = ND_list(info_element[it].Nu[il]) + for iu in range(info_element[it].Nu[il]): + orb[it][il][iu] = np.zeros(Nr) + for ir in range(Nr): + r = ir * info_element[it].dr + for ie in range(info_element[it].Ne): + orb[it][il][iu][ir] += C[it][il][ie,iu].item() * spherical_jn(il,E[it][il,ie].item()*r) + return orb + + +def smooth_orbital(orb,Rcut,dr,smearing_sigma): + for it,orb_t in orb.items(): + for orb_tl in orb_t: + for orb_tlu in orb_tl: + for ir in range(orb_tlu.shape[0]): + assert orb_tlu.shape[0] == int(Rcut[it]/dr[it])+1 + r = ir * dr[it] + orb_tlu[ir] *= 1-np.exp( -(r-Rcut[it])**2/(2*smearing_sigma**2) ) + + + +def inner_product( orb1, orb2, dr ): + assert orb1.shape == orb2.shape + r = np.array(range(orb1.shape[0]))*dr + return simps( orb1 * orb2 * r * r, dx=dr ) + +def normalize(orb,dr,C=None,flag_norm_orb=False,flag_norm_C=False): + """ C[it][il][ie,iu] """ + """ orb[it][il][iu][r] = \suml_{ie} C[it][il][ie,iu] * jn(il,ie*r) """ + for it,orb_t in orb.items(): + for il,orb_tl in enumerate(orb_t): + for iu,orb_tlu in enumerate(orb_tl): + norm = np.sqrt(inner_product(orb_tlu,orb_tlu,dr[it])) + if flag_norm_orb: orb_tlu[:] = orb_tlu / norm + if flag_norm_C: C[it][il].data[:,iu] = C[it][il].data[:,iu] / norm + +def orth(orb,dr): + """ |n'> = 1/Z ( |n> - \sum_{i=0}^{n-1} |i> ) """ + """ orb[it][il][iu,r] """ + for it,orb_t in orb.items(): + for il,orb_tl in enumerate(orb_t): + for iu1,orb_tlu1 in enumerate(orb_tl): + for iu2 in range(iu1): + orb_tlu1[:] -= orb_tl[iu2] * inner_product(orb_tlu1,orb_tl[iu2],dr[it]) + orb_tlu1[:] = orb_tlu1 / np.sqrt(inner_product(orb_tlu1,orb_tlu1,dr[it])) + +def find_eigenvalue(Nl,Ne): + """ E[il,ie] """ + E = np.zeros((Nl,Ne+Nl+1)) + for ie in range(1,Ne+Nl+1): + E[0,ie] = ie*np.pi + for il in range(1,Nl): + jl = functools.partial(spherical_jn,il) + for ie in range(1,Ne+Nl+1-il): + E[il,ie] = fsolve( jl, (E[il-1,ie]+E[il-1,ie+1])/2 ) + return E[:,1:Ne+1] + +def set_E(info_element): + """ E[it][il,ie] """ + eigenvalue = { it:find_eigenvalue(info_element[it].Nl,info_element[it].Ne) for it in info_element } + E = dict() + for it in info_element: + E[it] = torch.from_numpy(( eigenvalue[it]/info_element[it].Rcut ).astype("float64")) + return E \ No newline at end of file diff --git a/SIAB/opt_orb_pytorch_dpsi/torch_complex_bak.py b/SIAB/opt_orb_pytorch_dpsi/torch_complex_bak.py new file mode 100644 index 00000000..f08be745 --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/torch_complex_bak.py @@ -0,0 +1,84 @@ +import torch + +class ComplexTensor: + def __init__(self,real,imag): + self.real = real + self.imag = imag + + def view(self,*args,**kwargs): + return ComplexTensor( self.real.view(*args,**kwargs), self.imag.view(*args,**kwargs) ) + def t(self,*args,**kwargs): + return ComplexTensor( self.real.t(*args,**kwargs), self.imag.t(*args,**kwargs) ) +# def transpose(self,*args,**kwargs): +# return ComplexTensor( self.real.transpose(*args,**kwargs), self.imag.transpose(*args,**kwargs) ) + def __getitem__(self,*args,**kwargs): + return ComplexTensor( self.real.__getitem__(*args,**kwargs), self.imag.__getitem__(*args,**kwargs) ) + def __str__(self): + return "<{0};{1}>".format(self.real, self.imag) + __repr__=__str__ +# def size(self,*args,**kwargs): +# return ComplexTensor( self.real.size(*args,**kwargs), self.imag.size(*args,**kwargs) ) + + def conj(self): + return ComplexTensor( self.real, -self.imag ) + + def mm( self,x2, *args,**kwargs ): + return mm( self,x2, *args,**kwargs ) + + +def dot( x1,x2, *args,**kwargs ): + if isinstance(x1,ComplexTensor): + if isinstance(x2,ComplexTensor): + return ComplexTensor( torch.dot( x1.real,x2.real, *args,**kwargs ) - torch.dot( x1.imag,x2.imag, *args,**kwargs ), torch.dot( x1.real,x2.imag, *args,**kwargs ) + torch.dot( x1.imag,x2.real, *args,**kwargs ) ) + else: + return ComplexTensor( torch.dot( x1.real,x2, *args,**kwargs ), torch.dot( x1.imag,x2, *args,**kwargs ) ) + else: + if isinstance(x2,ComplexTensor): + return ComplexTensor( torch.dot( x1,x2.real, *args,**kwargs ), torch.dot( x1,x2.imag, *args,**kwargs ) ) + else: + return torch.dot( x1,x2, *args,**kwargs ) +def mv( x1,x2, *args,**kwargs ): + if isinstance(x1,ComplexTensor): + if isinstance(x2,ComplexTensor): + return ComplexTensor( torch.mv( x1.real,x2.real, *args,**kwargs ) - torch.mv( x1.imag,x2.imag, *args,**kwargs ), torch.mv( x1.real,x2.imag, *args,**kwargs ) + torch.mv( x1.imag,x2.real, *args,**kwargs ) ) + else: + return ComplexTensor( torch.mv( x1.real,x2, *args,**kwargs ), torch.mv( x1.imag,x2, *args,**kwargs ) ) + else: + if isinstance(x2,ComplexTensor): + return ComplexTensor( torch.mv( x1,x2.real, *args,**kwargs ), torch.mv( x1,x2.imag, *args,**kwargs ) ) + else: + return torch.mv( x1,x2, *args,**kwargs ) +def mm( x1,x2, *args,**kwargs ): + if isinstance(x1,ComplexTensor): + if isinstance(x2,ComplexTensor): + return ComplexTensor( torch.mm( x1.real,x2.real, *args,**kwargs ) - torch.mm( x1.imag,x2.imag, *args,**kwargs ), torch.mm( x1.real,x2.imag, *args,**kwargs ) + torch.mm( x1.imag,x2.real, *args,**kwargs ) ) + else: + return ComplexTensor( torch.mm( x1.real,x2, *args,**kwargs ), torch.mm( x1.imag,x2, *args,**kwargs ) ) + else: + if isinstance(x2,ComplexTensor): + return ComplexTensor( torch.mm( x1,x2.real, *args,**kwargs ), torch.mm( x1,x2.imag, *args,**kwargs ) ) + else: + return torch.mm( x1,x2, *args,**kwargs ) + + +def cat( xs, *args,**kwargs ): + if isinstance(xs[0],ComplexTensor): + xs_real = []; xs_imag = [] + for x in xs: + xs_real.append(x.real) + xs_imag.append(x.imag) + return ComplexTensor( torch.cat(xs_real,*args,**kwargs), torch.cat(xs_imag,*args,**kwargs) ) + else: + return torch.cat(xs,*args,**kwargs) + + + +def inverse(M): + if isinstance(M,ComplexTensor): + A=M.real + B=M.imag + tmp_AB = torch.mm(A.inverse(),B) # A^{-1} B + tmp_X = (A+torch.mm(B,tmp_AB)).inverse() # ( A + B A^{-1} B )^{-1} + return ComplexTensor( tmp_X, -torch.mm(tmp_AB,tmp_X) ) + else: + return M.inverse() \ No newline at end of file diff --git a/SIAB/opt_orb_pytorch_dpsi/util.py b/SIAB/opt_orb_pytorch_dpsi/util.py new file mode 100644 index 00000000..9b64a8e4 --- /dev/null +++ b/SIAB/opt_orb_pytorch_dpsi/util.py @@ -0,0 +1,42 @@ +def ND_list(*sizes,element=None): + size_1,*size_other = sizes + l = [element] * size_1 + if size_other: + for i in range(len(l)): + l[i] = ND_list(*size_other,element=element) + else: + if element in ["dict()","list()"]: + for i in range(size_1): + l[i] = eval(element) + return l + + +def ignore_line(file,N): + for _ in range(N): + file.readline() + + +class Info: + def Nm(self,il): return 2*il+1 + def __str__(self): + return "\n".join([name+"\t"+str(value) for name,value in self.__dict__.items()]) + __repr__=__str__ + +def change_to_cuda(s): + if isinstance(s,list): + return [change_to_cuda(x) for x in s] + elif isinstance(s,dict): + return {i:change_to_cuda(x) for i,x in s.items()} + elif isinstance(s,torch.Tensor): + return s.cuda() + elif isinstance(s,torch_complex.ComplexTensor): + return torch_complex.ComplexTensor( change_to_cuda(s.real), change_to_cuda(s.imag) ) + else: + print(s) + raise TypeError("change_to_cuda") + +def update0(t): + return t.masked_fill(mask=(t==0), value=1E-10) + +def Nm(il): + return 2*il+1 \ No newline at end of file