-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add opt_orb_pytorch_dpsi from PeizeLin
- Loading branch information
Showing
13 changed files
with
1,195 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import IO.read_istate | ||
import torch | ||
import re | ||
import functools | ||
import operator | ||
|
||
def cal_weight(info_weight, flag_same_band, stru_file_list=None): | ||
""" weight[ist][ib] """ | ||
|
||
if "bands_file" in info_weight.keys(): | ||
if "bands_range" in info_weight.keys(): | ||
raise IOError('"bands_file" and "bands_range" only once') | ||
|
||
weight = [] # weight[ist][ib] | ||
for weight_stru, file_name in zip(info_weight["stru"], info_weight["bands_file"]): | ||
occ = IO.read_istate.read_istate(file_name) | ||
weight += [occ_k * weight_stru for occ_k in occ] | ||
|
||
elif "bands_range" in info_weight.keys(): | ||
k_weight = read_k_weight(stru_file_list) # k_weight[ist][ik] | ||
nbands = read_nbands(stru_file_list) # nbands[ist] | ||
|
||
st_weight = [] # st_weight[ist][ib] | ||
for weight_stru, bands_range, nbands_ist in zip(info_weight["stru"], info_weight["bands_range"], nbands): | ||
st_weight_tmp = torch.zeros((nbands_ist,)) | ||
st_weight_tmp[:bands_range] = weight_stru | ||
st_weight.append( st_weight_tmp ) | ||
|
||
weight = [] # weight[ist][ib] | ||
for ist,_ in enumerate(k_weight): | ||
for ik,_ in enumerate(k_weight[ist]): | ||
weight.append(st_weight[ist] * k_weight[ist][ik]) | ||
|
||
else: | ||
raise IOError('"bands_file" and "bands_range" must once') | ||
|
||
|
||
if not flag_same_band: | ||
for ist,_ in enumerate(weight): | ||
weight[ist] = torch.tensordot(weight[ist], weight[ist], dims=0) | ||
|
||
|
||
normalization = functools.reduce(operator.add, map(torch.sum, weight), 0) | ||
weight = list(map(lambda x:x/normalization, weight)) | ||
|
||
return weight | ||
|
||
|
||
def read_k_weight(stru_file_list): | ||
""" weight[ist][ik] """ | ||
weight = [] # weight[ist][ik] | ||
for file_name in stru_file_list: | ||
weight_k = [] # weight_k[ik] | ||
with open(file_name,"r") as file: | ||
data = re.compile(r"<WEIGHT_OF_KPOINTS>(.+)</WEIGHT_OF_KPOINTS>", re.S).search(file.read()).group(1).split("\n") | ||
for line in data: | ||
line = line.strip() | ||
if line: | ||
weight_k.append(float(line.split()[-1])) | ||
weight.append(weight_k) | ||
return weight | ||
|
||
|
||
def read_nbands(stru_file_list): | ||
""" nbands[ib] """ | ||
nbands = [] | ||
for file_name in stru_file_list: | ||
with open(file_name,"r") as file: | ||
nbands.append(int(re.compile(r"(\d+)\s+nbands").search(file.read()).group(1))) | ||
return nbands |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import addict | ||
import util | ||
import itertools | ||
|
||
def change_info(info_old, weight_old): | ||
info_stru = [None] * info_old.Nst | ||
for ist in range(len(info_stru)): | ||
info_stru[ist] = addict.Dict() | ||
for ist,Na in enumerate(info_old.Na): | ||
info_stru[ist].Na = Na | ||
for ist,weight in enumerate(weight_old): | ||
info_stru[ist].weight = weight | ||
info_stru[ist].Nb = weight.shape[0] | ||
for ib in range(weight.shape[0], 0, -1): | ||
if weight[ib-1]>0: | ||
info_stru[ist].Nb_true = ib | ||
break | ||
|
||
info_element = addict.Dict() | ||
for it_index,it in enumerate(info_old.Nt_all): | ||
info_element[it].index = it_index | ||
for it,Nu in info_old.Nu.items(): | ||
info_element[it].Nu = Nu | ||
info_element[it].Nl = len(Nu) | ||
for it,Rcut in info_old.Rcut.items(): | ||
info_element[it].Rcut = Rcut | ||
for it,dr in info_old.dr.items(): | ||
info_element[it].dr = dr | ||
for it,Ecut in info_old.Ecut.items(): | ||
info_element[it].Ecut = Ecut | ||
for it,Ne in info_old.Ne.items(): | ||
info_element[it].Ne = Ne | ||
|
||
info_opt = addict.Dict() | ||
info_opt.lr = info_old.lr | ||
info_opt.cal_T = info_old.cal_T | ||
info_opt.cal_smooth = info_old.cal_smooth | ||
|
||
return info_stru, info_element, info_opt | ||
|
||
""" | ||
info_stru = | ||
[{'Na': {'C': 1}, | ||
'Nb': 6, | ||
'Nb_true': 4, | ||
'weight': tensor([0.0333, 0.0111, 0.0111, 0.0111, 0.0000, 0.0000])}, | ||
{'Na': {'C': 1}, | ||
'Nb': 6, | ||
'Nb_true': 2, | ||
'weight': tensor([0.0667, 0.0667, 0.0000, 0.0000, 0.0000, 0.0000])}, | ||
{'Na': {'C': 1, 'O': 2}, | ||
'Nb': 10, | ||
'Nb_true': 8, | ||
'weight': tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.0000, 0.0000])}] | ||
info_element = | ||
{'C': { | ||
'Ecut': 200, | ||
'Ne': 19, | ||
'Nl': 3, | ||
'Nu': [2, 2, 1], | ||
'Rcut': 6, | ||
'dr': 0.01, | ||
'index': 0}, | ||
'O': { | ||
'Ecut': 200, | ||
'Ne': 19, | ||
'Nl': 3, | ||
'Nu': [3, 2, 1], | ||
'Rcut': 6, | ||
'dr': 0.01, | ||
'index': 1}} | ||
info_opt = | ||
{'cal_T': False, | ||
'cal_smooth': False, | ||
'lr': 0.01} | ||
""" | ||
|
||
|
||
def get_info_max(info_stru, info_element): | ||
info_max = [None] * len(info_stru) | ||
for ist in range(len(info_stru)): | ||
Nt = info_stru[ist].Na.keys() | ||
info_max[ist] = addict.Dict() | ||
info_max[ist].Nt = len(Nt) | ||
info_max[ist].Na = max((info_stru[ist].Na[it] for it in Nt)) | ||
info_max[ist].Nl = max([info_element[it].Nl for it in Nt]) | ||
info_max[ist].Nm = max((util.Nm(info_element[it].Nl-1) for it in Nt)) | ||
info_max[ist].Nu = max(itertools.chain.from_iterable([info_element[it].Nu for it in Nt])) | ||
info_max[ist].Ne = max((info_element[it].Ne for it in Nt)) | ||
info_max[ist].Nb = info_stru[ist].Nb | ||
return info_max | ||
|
||
""" | ||
[{'Na': 2, 'Nb': 6, 'Ne': 19, 'Nl': 3, 'Nm': 5, 'Nt': 1, 'Nu': 2}, | ||
{'Na': 2, 'Nb': 6, 'Ne': 19, 'Nl': 3, 'Nm': 5, 'Nt': 1, 'Nu': 2}] | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from util import * | ||
import torch | ||
import numpy as np | ||
|
||
def random_C_init(info_element): | ||
""" C[it][il][ie,iu] <jY|\phi> """ | ||
C = dict() | ||
for it in info_element.keys(): | ||
C[it] = ND_list(info_element[it].Nl) | ||
for il in range(info_element[it].Nl): | ||
C[it][il] = torch.tensor(np.random.uniform(-1,1, (info_element[it].Ne, info_element[it].Nu[il])), dtype=torch.float64, requires_grad=True) | ||
return C | ||
|
||
|
||
|
||
def read_C_init(file_name,info_element): | ||
""" C[it][il][ie,iu] <jY|\phi> """ | ||
C = random_C_init(info_element) | ||
|
||
with open(file_name,"r") as file: | ||
|
||
for line in file: | ||
if line.strip() == "<Coefficient>": | ||
line=None | ||
break | ||
ignore_line(file,1) | ||
|
||
C_read_index = set() | ||
while True: | ||
line = file.readline().strip() | ||
if line.startswith("Type"): | ||
it,il,iu = file.readline().split(); | ||
il = int(il) | ||
iu = int(iu)-1 | ||
C_read_index.add((it,il,iu)) | ||
line = file.readline().split() | ||
for ie in range(info_element[it].Ne): | ||
if not line: line = file.readline().split() | ||
C[it][il].data[ie,iu] = float(line.pop(0)) | ||
elif line.startswith("</Coefficient>"): | ||
break; | ||
else: | ||
raise IOError("unknown line in read_C_init "+file_name+"\n"+line) | ||
return C, C_read_index | ||
|
||
|
||
|
||
def copy_C(C,info_element): | ||
C_copy = dict() | ||
for it in info_element.keys(): | ||
C_copy[it] = ND_list(info_element[it].Nl) | ||
for il in range(info_element[it].Nl): | ||
C_copy[it][il] = C[it][il].clone() | ||
return C_copy | ||
|
||
|
||
|
||
def write_C(file_name,C,Spillage): | ||
with open(file_name,"w") as file: | ||
print("<Coefficient>", file=file) | ||
#print("\tTotal number of radial orbitals.", file=file) | ||
nTotal = 0 | ||
for it,C_t in C.items(): | ||
for il,C_tl in enumerate(C_t): | ||
for iu in range(C_tl.size()[1]): | ||
nTotal += 1 | ||
#nTotal = sum(info["Nu"][it]) | ||
print("\t %s Total number of radial orbitals."%nTotal , file=file) | ||
#print("\tTotal number of radial orbitals.", file=file) | ||
for it,C_t in C.items(): | ||
for il,C_tl in enumerate(C_t): | ||
for iu in range(C_tl.size()[1]): | ||
print("\tType\tL\tZeta-Orbital", file=file) | ||
print(f"\t {it} \t{il}\t {iu+1}", file=file) | ||
for ie in range(C_tl.size()[0]): | ||
print("\t", '%18.14f'%C_tl[ie,iu].item(), file=file) | ||
print("</Coefficient>", file=file) | ||
print("<Mkb>", file=file) | ||
print("Left spillage = %.10e"%Spillage.item(), file=file) | ||
print("</Mkb>", file=file) | ||
|
||
|
||
#def init_C(info): | ||
# """ C[it][il][ie,iu] """ | ||
# C = ND_list(max(info.Nt)) | ||
# for it in range(len(C)): | ||
# C[it] = ND_list(info.Nl[it]) | ||
# for il in range(info.Nl[it]): | ||
# C[it][il] = torch.autograd.Variable( torch.Tensor( info.Ne, info.Nu[it][il] ), requires_grad = True ) | ||
# | ||
# with open("C_init.dat","r") as file: | ||
# line = [] | ||
# for it in range(len(C)): | ||
# for il in range(info.Nl[it]): | ||
# for i_n in range(info.Nu[it][il]): | ||
# for ie in range(info.Ne[it]): | ||
# if not line: line=file.readline().split() | ||
# C[it][il].data[ie,i_n] = float(line.pop(0)) | ||
# return C |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
def print_V(V,file_name): | ||
""" V[ist][ib] """ | ||
with open(file_name,"w") as file: | ||
for V_s in V: | ||
for V_sb in V_s: | ||
print(1-V_sb.item(),end="\t",file=file) | ||
print(file=file) | ||
|
||
def print_S(S,file_name): | ||
""" S[ist][it1,it2][il1][il2][ia1*im1*in1,ia2*im2*in2] """ | ||
with open(file_name,"w") as file: | ||
for ist,S_s in enumerate(S): | ||
for (it1,it2),S_tt in S_s.items(): | ||
for il1,S_ttl in enumerate(S_tt): | ||
for il2,S_ttll in enumerate(S_ttl): | ||
print(ist,it1,it2,il1,il2,file=file) | ||
print(S_ttll.real.numpy(),file=file) | ||
print(S_ttll.imag.numpy(),"\n",file=file) | ||
|
||
def print_Q(Q,file_name): | ||
""" Q[ist][it][il][ib,ia*im*iu] """ | ||
with open(file_name,"w") as file: | ||
for ist,Q_s in enumerate(Q): | ||
for it,Q_st in Q_s.items(): | ||
for il,Q_stl in enumerate(Q_st): | ||
print(ist,it,il,file=file) | ||
print(Q_stl.real.numpy(),file=file) | ||
print(Q_stl.imag.numpy(),"\n",file=file) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
periodtable = { 'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, | ||
'O': 8, 'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13, | ||
'Si': 14, 'P': 15, 'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, | ||
'Ca': 20, 'Sc': 21, 'Ti': 22, 'V': 23, 'Cr': 24, 'Mn': 25, | ||
'Fe': 26, 'Co': 27, 'Ni': 28, 'Cu': 29, 'Zn': 30, 'Ga': 31, | ||
'Ge': 32, 'As': 33, 'Se': 34, 'Br': 35, 'Kr': 36, 'Rb': 37, | ||
'Sr': 38, 'Y': 39, 'Zr': 40, 'Nb': 41, 'Mo': 42, 'Tc': 43, | ||
'Ru': 44, 'Rh': 45, 'Pd': 46, 'Ag': 47, 'Cd': 48, 'In': 49, | ||
'Sn': 50, 'Sb': 51, 'Te': 52, 'I': 53, 'Xe': 54, 'Cs': 55, | ||
'Ba': 56, #'La': 57, 'Ce': 58, 'Pr': 59, 'Nd': 60, 'Pm': 61, | ||
## 'Sm': 62, 'Eu': 63, 'Gd': 64, 'Tb': 65, 'Dy': 66, 'Ho': 67, | ||
## 'Er': 68, 'Tm': 69, 'Yb': 70, | ||
## 'Lu': 71, | ||
'Hf': 72, 'Ta': 73, | ||
'W': 74, 'Re': 75, 'Os': 76, 'Ir': 77, 'Pt': 78, 'Au': 79, | ||
'Hg': 80, 'Tl': 81, 'Pb': 82, 'Bi': 83, | ||
## 'Po': 84, #'At': 85, | ||
## 'Rn': 86, #'Fr': 87, 'Ra': 88, 'Ac': 89, 'Th': 90, 'Pa': 91, | ||
## 'U': 92, 'Np': 93, 'Pu': 94, 'Am': 95, 'Cm': 96, 'Bk': 97, | ||
## 'Cf': 98, 'Es': 99, 'Fm': 100, 'Md': 101, 'No': 102, 'Lr': 103, | ||
## 'Rf': 104, 'Db': 105, 'Sg': 106, 'Bh': 107, 'Hs': 108, | ||
## 'Mt': 109, 'Ds': 110, 'Rg': 111, 'Cn': 112, 'Uut': 113, | ||
## 'Fl': 114, 'Uup': 115, 'Lv': 116, 'Uus': 117, 'Uuo': 118 | ||
} | ||
|
||
def print_orbital(orb,info_element): | ||
""" orb[it][il][iu][r] """ | ||
for it,orb_t in orb.items(): | ||
#with open("orb_{0}.dat".format(it),"w") as file: | ||
with open("ORBITAL_{0}U.dat".format( periodtable[it] ),"w") as file: | ||
print_orbital_head(file,info_element,it) | ||
for il,orb_tl in enumerate(orb_t): | ||
for iu,orb_tlu in enumerate(orb_tl): | ||
print(""" Type L N""",file=file) | ||
print(""" 0 {0} {1}""".format(il,iu),file=file) | ||
for ir,orb_tlur in enumerate(orb_tlu): | ||
print( '%.14e'%orb_tlur, end=" ",file=file) | ||
if ir%4==3: print(file=file) | ||
print(file=file) | ||
|
||
|
||
def plot_orbital(orb,Rcut,dr): | ||
for it,orb_t in orb.items(): | ||
#with open("orb_{0}_plot.dat".format(it),"w") as file: | ||
with open("ORBITAL_PLOTU.dat", "w") as file: | ||
Nr = int(Rcut[it]/dr[it])+1 | ||
for ir in range(Nr): | ||
print( '%10.6f'%(ir*dr[it]),end=" ",file=file) | ||
for il,orb_tl in enumerate(orb_t): | ||
for orb_tlu in orb_tl: | ||
print( '%18.14f'%orb_tlu[ir],end=" ",file=file) | ||
print(file=file) | ||
|
||
|
||
def print_orbital_head(file,info_element,it): | ||
print( "---------------------------------------------------------------------------", file=file ) | ||
print( "Element {0}".format(it), file=file ) | ||
print( "Energy Cutoff(Ry) {0}".format(info_element[it].Ecut), file=file ) | ||
print( "Radius Cutoff(a.u.) {0}".format(info_element[it].Rcut), file=file ) | ||
print( "Lmax {0}".format(info_element[it].Nl-1), file=file ) | ||
l_name = ["S","P","D"]+list(map(chr,range(ord('F'),ord('Z')+1))) | ||
for il,iu in enumerate(info_element[it].Nu): | ||
print( "Number of {0}orbital--> {1}".format(l_name[il],iu), file=file ) | ||
print( "---------------------------------------------------------------------------", file=file ) | ||
print( "SUMMARY END", file=file ) | ||
print( file=file ) | ||
print( "Mesh {0}".format(int(info_element[it].Rcut/info_element[it].dr)+1), file=file ) | ||
print( "dr {0}".format(info_element[it].dr), file=file ) | ||
|
||
|
Oops, something went wrong.