Skip to content

Commit

Permalink
add opt_orb_pytorch_dpsi from PeizeLin
Browse files Browse the repository at this point in the history
  • Loading branch information
PeizeLin committed Aug 31, 2022
1 parent b6afc15 commit 48a6e42
Show file tree
Hide file tree
Showing 13 changed files with 1,195 additions and 0 deletions.
70 changes: 70 additions & 0 deletions SIAB/opt_orb_pytorch_dpsi/IO/cal_weight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import IO.read_istate
import torch
import re
import functools
import operator

def cal_weight(info_weight, flag_same_band, stru_file_list=None):
""" weight[ist][ib] """

if "bands_file" in info_weight.keys():
if "bands_range" in info_weight.keys():
raise IOError('"bands_file" and "bands_range" only once')

weight = [] # weight[ist][ib]
for weight_stru, file_name in zip(info_weight["stru"], info_weight["bands_file"]):
occ = IO.read_istate.read_istate(file_name)
weight += [occ_k * weight_stru for occ_k in occ]

elif "bands_range" in info_weight.keys():
k_weight = read_k_weight(stru_file_list) # k_weight[ist][ik]
nbands = read_nbands(stru_file_list) # nbands[ist]

st_weight = [] # st_weight[ist][ib]
for weight_stru, bands_range, nbands_ist in zip(info_weight["stru"], info_weight["bands_range"], nbands):
st_weight_tmp = torch.zeros((nbands_ist,))
st_weight_tmp[:bands_range] = weight_stru
st_weight.append( st_weight_tmp )

weight = [] # weight[ist][ib]
for ist,_ in enumerate(k_weight):
for ik,_ in enumerate(k_weight[ist]):
weight.append(st_weight[ist] * k_weight[ist][ik])

else:
raise IOError('"bands_file" and "bands_range" must once')


if not flag_same_band:
for ist,_ in enumerate(weight):
weight[ist] = torch.tensordot(weight[ist], weight[ist], dims=0)


normalization = functools.reduce(operator.add, map(torch.sum, weight), 0)
weight = list(map(lambda x:x/normalization, weight))

return weight


def read_k_weight(stru_file_list):
""" weight[ist][ik] """
weight = [] # weight[ist][ik]
for file_name in stru_file_list:
weight_k = [] # weight_k[ik]
with open(file_name,"r") as file:
data = re.compile(r"<WEIGHT_OF_KPOINTS>(.+)</WEIGHT_OF_KPOINTS>", re.S).search(file.read()).group(1).split("\n")
for line in data:
line = line.strip()
if line:
weight_k.append(float(line.split()[-1]))
weight.append(weight_k)
return weight


def read_nbands(stru_file_list):
""" nbands[ib] """
nbands = []
for file_name in stru_file_list:
with open(file_name,"r") as file:
nbands.append(int(re.compile(r"(\d+)\s+nbands").search(file.read()).group(1)))
return nbands
98 changes: 98 additions & 0 deletions SIAB/opt_orb_pytorch_dpsi/IO/change_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import addict
import util
import itertools

def change_info(info_old, weight_old):
info_stru = [None] * info_old.Nst
for ist in range(len(info_stru)):
info_stru[ist] = addict.Dict()
for ist,Na in enumerate(info_old.Na):
info_stru[ist].Na = Na
for ist,weight in enumerate(weight_old):
info_stru[ist].weight = weight
info_stru[ist].Nb = weight.shape[0]
for ib in range(weight.shape[0], 0, -1):
if weight[ib-1]>0:
info_stru[ist].Nb_true = ib
break

info_element = addict.Dict()
for it_index,it in enumerate(info_old.Nt_all):
info_element[it].index = it_index
for it,Nu in info_old.Nu.items():
info_element[it].Nu = Nu
info_element[it].Nl = len(Nu)
for it,Rcut in info_old.Rcut.items():
info_element[it].Rcut = Rcut
for it,dr in info_old.dr.items():
info_element[it].dr = dr
for it,Ecut in info_old.Ecut.items():
info_element[it].Ecut = Ecut
for it,Ne in info_old.Ne.items():
info_element[it].Ne = Ne

info_opt = addict.Dict()
info_opt.lr = info_old.lr
info_opt.cal_T = info_old.cal_T
info_opt.cal_smooth = info_old.cal_smooth

return info_stru, info_element, info_opt

"""
info_stru =
[{'Na': {'C': 1},
'Nb': 6,
'Nb_true': 4,
'weight': tensor([0.0333, 0.0111, 0.0111, 0.0111, 0.0000, 0.0000])},
{'Na': {'C': 1},
'Nb': 6,
'Nb_true': 2,
'weight': tensor([0.0667, 0.0667, 0.0000, 0.0000, 0.0000, 0.0000])},
{'Na': {'C': 1, 'O': 2},
'Nb': 10,
'Nb_true': 8,
'weight': tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.0000, 0.0000])}]
info_element =
{'C': {
'Ecut': 200,
'Ne': 19,
'Nl': 3,
'Nu': [2, 2, 1],
'Rcut': 6,
'dr': 0.01,
'index': 0},
'O': {
'Ecut': 200,
'Ne': 19,
'Nl': 3,
'Nu': [3, 2, 1],
'Rcut': 6,
'dr': 0.01,
'index': 1}}
info_opt =
{'cal_T': False,
'cal_smooth': False,
'lr': 0.01}
"""


def get_info_max(info_stru, info_element):
info_max = [None] * len(info_stru)
for ist in range(len(info_stru)):
Nt = info_stru[ist].Na.keys()
info_max[ist] = addict.Dict()
info_max[ist].Nt = len(Nt)
info_max[ist].Na = max((info_stru[ist].Na[it] for it in Nt))
info_max[ist].Nl = max([info_element[it].Nl for it in Nt])
info_max[ist].Nm = max((util.Nm(info_element[it].Nl-1) for it in Nt))
info_max[ist].Nu = max(itertools.chain.from_iterable([info_element[it].Nu for it in Nt]))
info_max[ist].Ne = max((info_element[it].Ne for it in Nt))
info_max[ist].Nb = info_stru[ist].Nb
return info_max

"""
[{'Na': 2, 'Nb': 6, 'Ne': 19, 'Nl': 3, 'Nm': 5, 'Nt': 1, 'Nu': 2},
{'Na': 2, 'Nb': 6, 'Ne': 19, 'Nl': 3, 'Nm': 5, 'Nt': 1, 'Nu': 2}]
"""
99 changes: 99 additions & 0 deletions SIAB/opt_orb_pytorch_dpsi/IO/func_C.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from util import *
import torch
import numpy as np

def random_C_init(info_element):
""" C[it][il][ie,iu] <jY|\phi> """
C = dict()
for it in info_element.keys():
C[it] = ND_list(info_element[it].Nl)
for il in range(info_element[it].Nl):
C[it][il] = torch.tensor(np.random.uniform(-1,1, (info_element[it].Ne, info_element[it].Nu[il])), dtype=torch.float64, requires_grad=True)
return C



def read_C_init(file_name,info_element):
""" C[it][il][ie,iu] <jY|\phi> """
C = random_C_init(info_element)

with open(file_name,"r") as file:

for line in file:
if line.strip() == "<Coefficient>":
line=None
break
ignore_line(file,1)

C_read_index = set()
while True:
line = file.readline().strip()
if line.startswith("Type"):
it,il,iu = file.readline().split();
il = int(il)
iu = int(iu)-1
C_read_index.add((it,il,iu))
line = file.readline().split()
for ie in range(info_element[it].Ne):
if not line: line = file.readline().split()
C[it][il].data[ie,iu] = float(line.pop(0))
elif line.startswith("</Coefficient>"):
break;
else:
raise IOError("unknown line in read_C_init "+file_name+"\n"+line)
return C, C_read_index



def copy_C(C,info_element):
C_copy = dict()
for it in info_element.keys():
C_copy[it] = ND_list(info_element[it].Nl)
for il in range(info_element[it].Nl):
C_copy[it][il] = C[it][il].clone()
return C_copy



def write_C(file_name,C,Spillage):
with open(file_name,"w") as file:
print("<Coefficient>", file=file)
#print("\tTotal number of radial orbitals.", file=file)
nTotal = 0
for it,C_t in C.items():
for il,C_tl in enumerate(C_t):
for iu in range(C_tl.size()[1]):
nTotal += 1
#nTotal = sum(info["Nu"][it])
print("\t %s Total number of radial orbitals."%nTotal , file=file)
#print("\tTotal number of radial orbitals.", file=file)
for it,C_t in C.items():
for il,C_tl in enumerate(C_t):
for iu in range(C_tl.size()[1]):
print("\tType\tL\tZeta-Orbital", file=file)
print(f"\t {it} \t{il}\t {iu+1}", file=file)
for ie in range(C_tl.size()[0]):
print("\t", '%18.14f'%C_tl[ie,iu].item(), file=file)
print("</Coefficient>", file=file)
print("<Mkb>", file=file)
print("Left spillage = %.10e"%Spillage.item(), file=file)
print("</Mkb>", file=file)


#def init_C(info):
# """ C[it][il][ie,iu] """
# C = ND_list(max(info.Nt))
# for it in range(len(C)):
# C[it] = ND_list(info.Nl[it])
# for il in range(info.Nl[it]):
# C[it][il] = torch.autograd.Variable( torch.Tensor( info.Ne, info.Nu[it][il] ), requires_grad = True )
#
# with open("C_init.dat","r") as file:
# line = []
# for it in range(len(C)):
# for il in range(info.Nl[it]):
# for i_n in range(info.Nu[it][il]):
# for ie in range(info.Ne[it]):
# if not line: line=file.readline().split()
# C[it][il].data[ie,i_n] = float(line.pop(0))
# return C
29 changes: 29 additions & 0 deletions SIAB/opt_orb_pytorch_dpsi/IO/print_QSV.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
def print_V(V,file_name):
""" V[ist][ib] """
with open(file_name,"w") as file:
for V_s in V:
for V_sb in V_s:
print(1-V_sb.item(),end="\t",file=file)
print(file=file)

def print_S(S,file_name):
""" S[ist][it1,it2][il1][il2][ia1*im1*in1,ia2*im2*in2] """
with open(file_name,"w") as file:
for ist,S_s in enumerate(S):
for (it1,it2),S_tt in S_s.items():
for il1,S_ttl in enumerate(S_tt):
for il2,S_ttll in enumerate(S_ttl):
print(ist,it1,it2,il1,il2,file=file)
print(S_ttll.real.numpy(),file=file)
print(S_ttll.imag.numpy(),"\n",file=file)

def print_Q(Q,file_name):
""" Q[ist][it][il][ib,ia*im*iu] """
with open(file_name,"w") as file:
for ist,Q_s in enumerate(Q):
for it,Q_st in Q_s.items():
for il,Q_stl in enumerate(Q_st):
print(ist,it,il,file=file)
print(Q_stl.real.numpy(),file=file)
print(Q_stl.imag.numpy(),"\n",file=file)

70 changes: 70 additions & 0 deletions SIAB/opt_orb_pytorch_dpsi/IO/print_orbital.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
periodtable = { 'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7,
'O': 8, 'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13,
'Si': 14, 'P': 15, 'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19,
'Ca': 20, 'Sc': 21, 'Ti': 22, 'V': 23, 'Cr': 24, 'Mn': 25,
'Fe': 26, 'Co': 27, 'Ni': 28, 'Cu': 29, 'Zn': 30, 'Ga': 31,
'Ge': 32, 'As': 33, 'Se': 34, 'Br': 35, 'Kr': 36, 'Rb': 37,
'Sr': 38, 'Y': 39, 'Zr': 40, 'Nb': 41, 'Mo': 42, 'Tc': 43,
'Ru': 44, 'Rh': 45, 'Pd': 46, 'Ag': 47, 'Cd': 48, 'In': 49,
'Sn': 50, 'Sb': 51, 'Te': 52, 'I': 53, 'Xe': 54, 'Cs': 55,
'Ba': 56, #'La': 57, 'Ce': 58, 'Pr': 59, 'Nd': 60, 'Pm': 61,
## 'Sm': 62, 'Eu': 63, 'Gd': 64, 'Tb': 65, 'Dy': 66, 'Ho': 67,
## 'Er': 68, 'Tm': 69, 'Yb': 70,
## 'Lu': 71,
'Hf': 72, 'Ta': 73,
'W': 74, 'Re': 75, 'Os': 76, 'Ir': 77, 'Pt': 78, 'Au': 79,
'Hg': 80, 'Tl': 81, 'Pb': 82, 'Bi': 83,
## 'Po': 84, #'At': 85,
## 'Rn': 86, #'Fr': 87, 'Ra': 88, 'Ac': 89, 'Th': 90, 'Pa': 91,
## 'U': 92, 'Np': 93, 'Pu': 94, 'Am': 95, 'Cm': 96, 'Bk': 97,
## 'Cf': 98, 'Es': 99, 'Fm': 100, 'Md': 101, 'No': 102, 'Lr': 103,
## 'Rf': 104, 'Db': 105, 'Sg': 106, 'Bh': 107, 'Hs': 108,
## 'Mt': 109, 'Ds': 110, 'Rg': 111, 'Cn': 112, 'Uut': 113,
## 'Fl': 114, 'Uup': 115, 'Lv': 116, 'Uus': 117, 'Uuo': 118
}

def print_orbital(orb,info_element):
""" orb[it][il][iu][r] """
for it,orb_t in orb.items():
#with open("orb_{0}.dat".format(it),"w") as file:
with open("ORBITAL_{0}U.dat".format( periodtable[it] ),"w") as file:
print_orbital_head(file,info_element,it)
for il,orb_tl in enumerate(orb_t):
for iu,orb_tlu in enumerate(orb_tl):
print(""" Type L N""",file=file)
print(""" 0 {0} {1}""".format(il,iu),file=file)
for ir,orb_tlur in enumerate(orb_tlu):
print( '%.14e'%orb_tlur, end=" ",file=file)
if ir%4==3: print(file=file)
print(file=file)


def plot_orbital(orb,Rcut,dr):
for it,orb_t in orb.items():
#with open("orb_{0}_plot.dat".format(it),"w") as file:
with open("ORBITAL_PLOTU.dat", "w") as file:
Nr = int(Rcut[it]/dr[it])+1
for ir in range(Nr):
print( '%10.6f'%(ir*dr[it]),end=" ",file=file)
for il,orb_tl in enumerate(orb_t):
for orb_tlu in orb_tl:
print( '%18.14f'%orb_tlu[ir],end=" ",file=file)
print(file=file)


def print_orbital_head(file,info_element,it):
print( "---------------------------------------------------------------------------", file=file )
print( "Element {0}".format(it), file=file )
print( "Energy Cutoff(Ry) {0}".format(info_element[it].Ecut), file=file )
print( "Radius Cutoff(a.u.) {0}".format(info_element[it].Rcut), file=file )
print( "Lmax {0}".format(info_element[it].Nl-1), file=file )
l_name = ["S","P","D"]+list(map(chr,range(ord('F'),ord('Z')+1)))
for il,iu in enumerate(info_element[it].Nu):
print( "Number of {0}orbital--> {1}".format(l_name[il],iu), file=file )
print( "---------------------------------------------------------------------------", file=file )
print( "SUMMARY END", file=file )
print( file=file )
print( "Mesh {0}".format(int(info_element[it].Rcut/info_element[it].dr)+1), file=file )
print( "dr {0}".format(info_element[it].dr), file=file )


Loading

0 comments on commit 48a6e42

Please sign in to comment.