Skip to content

Commit

Permalink
modify main code
Browse files Browse the repository at this point in the history
  • Loading branch information
hswoo369 committed Apr 4, 2021
1 parent 8809008 commit 078550a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 50 deletions.
65 changes: 25 additions & 40 deletions simple_nn_v2/features/symmetry_function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def __init__(self, inputs=None):
self.default_inputs = {'symmetry_function':
{
'params': dict(),
'refdata_format':'vasp-out',
'compress_outcar':True,
'refdata_format': 'vasp-out',
'compress_outcar': True,
'data_per_tfrecord': 150,
'valid_rate': 0.1,
'shuffle':True,
'shuffle': True,
'add_NNP_ref': False, # atom E to tfrecord
'remain_pickle': False,
'continue': False,
Expand All @@ -46,8 +46,8 @@ def __init__(self, inputs=None):
'scale_type': 'minmax',
'scale_scale': 1.0,
'scale_rho': None,
'save_to_pickle':False,
'save_directory':'./data'
'save_to_pickle': False, # default format is .pt / if True, format is .pickle
'save_directory': './data' # directory of data files
}
}
self.structure_list = './str_list'
Expand All @@ -59,7 +59,6 @@ def __init__(self, inputs=None):
def set_inputs(self):
self.inputs = self.parent.inputs['symmetry_function']

# Genreate Method
def generate(self):
""" Generate structure data files(format: pickle/pt) that listed in "structure_list" file
Expand All @@ -70,29 +69,15 @@ def generate(self):
'total': full parameters in file
'int': int value parameters in file
'double': double value parameters in file
'int_p': convert int value parameters to c type array
'double_p': convert dobule value parameters to c type array
'int_p': convert int value parameters to C type array
'double_p': convert dobule value parameters to C type array
3. Load structure information using ase module (format: ase.atoms.Atoms object iterator)
4. Extract structure information from snapshot
5. Calculate symmetry functon values using C implemented code
result keys:
'x':
'dx':
'da':
'params':
'N':
'tot_num':
'partition':
'partition_XX':
'struct_type':
'struct_weight':
'atom_idx':
6. Save to data file (format: pickle or pt)
"""

# Data_generator object for handling [str_list], OUTCAR files, pickle/pt files
#data_generator = Data_generator(self.inputs, self.structure_list, self.pickle_list, self.parent)
data_generator = Data_generator(self.inputs, self.parent.logfile, self.structure_list, self.pickle_list)

# 1. Get structure list from "structure_list" file
Expand All @@ -107,7 +92,7 @@ def generate(self):

# Parsing C type symmetry function parameter to symf_params_set dictionary
for element in self.parent.inputs['atom_types']:
symf_params_set[element]['int_p'] = _gen_2Darray_for_ffi(symf_params_set[element]['int'], ffi, "int")
symf_params_set[element]['int_p'] = _gen_2Darray_for_ffi(symf_params_set[element]['int'], ffi, 'int')
symf_params_set[element]['double_p'] = _gen_2Darray_for_ffi(symf_params_set[element]['double'], ffi)

for item, tag_idx in zip(structures, structure_tag_idx):
Expand All @@ -116,21 +101,21 @@ def generate(self):

for snapshot in snapshots:
# 4. Extract structure information from snapshot (atom_num, cart, scale, cell)
# atom_type_idx(int list): list of type index for each atoms ex) [1,1,2,2,2,2]
# atom_type_idx(int list): list of type index for each atoms(start from 1) ex) [1,1,2,2,2,2]
# type_num(int dic): number of atoms for each types ex) {'Si': 2, 'O': 4}
# type_atom_idx(int list dic): list of atoms index that for each atom types ex) {'Si': [0,1], 'O': [2,3,4,5]}
atom_num, atom_type_idx, type_num, type_atom_idx, cart, scale, cell = self._get_structure_info(snapshot)

# Make C type data & 2D array from atom_type_idx, cart, scale, cell
atom_type_idx_p = ffi.cast("int *", atom_type_idx.ctypes.data)
# Make C type data of atom_type_idx, cart, scale, cell
atom_type_idx_p = ffi.cast('int *', atom_type_idx.ctypes.data)
cart_p = _gen_2Darray_for_ffi(cart, ffi)
scale_p = _gen_2Darray_for_ffi(scale, ffi)
cell_p = _gen_2Darray_for_ffi(cell, ffi)

# Initialize result dictionary
result = self._init_result(type_num, structure_tags, structure_weights, tag_idx, atom_type_idx)

for _ ,jtem in enumerate(self.parent.inputs['atom_types']):
for _, jtem in enumerate(self.parent.inputs['atom_types']):
# Set number of MPI
#begin , end = self._set_mpi(type_num , jtem)
#cal_atom_num , cal_atom_idx_p , x , dx , da , x_p , dx_p , da_p = self._get_sf_input(type_atom_idx ,\
Expand All @@ -139,11 +124,11 @@ def generate(self):
# Initialize variables for calculation
# cal_atom_idx(int list): atom index for calculation ex) [2,3,4]
# cal_atom_num(int): atom numbers for calculation ex) 3
cal_atom_idx, cal_atom_num, x, dx, da = self._init_sf_variables(type_atom_idx,\
cal_atom_idx, cal_atom_num, x, dx, da = self._init_sf_variables(type_atom_idx,\
jtem, symf_params_set, atom_num)

# Make C array from x, dx, da
cal_atom_idx_p = ffi.cast("int *", cal_atom_idx.ctypes.data)
cal_atom_idx_p = ffi.cast('int *', cal_atom_idx.ctypes.data)
x_p = _gen_2Darray_for_ffi(x, ffi)
dx_p = _gen_2Darray_for_ffi(dx, ffi)
da_p = _gen_2Darray_for_ffi(da, ffi)
Expand All @@ -161,7 +146,7 @@ def generate(self):
#self._check_error(errno)

# Set result to dictionary format from calculated value
self._set_result(result, x , dx, da, type_num, jtem, symf_params_set, atom_num)
self._set_result(result, x, dx, da, type_num, jtem, symf_params_set, atom_num)
# End of for loop

# Extract E, F, S from snapshot and append to result dictionary
Expand All @@ -188,18 +173,18 @@ def _parsing_symf_params(self):
return symf_params_set

def __read_params(self, filename):
params_i = list()
params_d = list()
params_int = list()
params_double = list()
with open(filename, 'r') as fil:
for line in fil:
tmp = line.split()
params_i += [list(map(int, tmp[:3]))]
params_d += [list(map(float, tmp[3:]))]
params_int += [list(map(int, tmp[:3]))]
params_double += [list(map(float, tmp[3:]))]

params_i = np.asarray(params_i, dtype=np.intc, order='C')
params_d = np.asarray(params_d, dtype=np.float64, order='C')
params_int = np.asarray(params_int, dtype=np.intc, order='C')
params_double = np.asarray(params_double, dtype=np.float64, order='C')

return params_i, params_d
return params_int, params_double

# Extract structure information from snapshot (atom numbers, cart, scale, cell)
# Return variables related to structure information (atom_type_idx, type_num, type_atom_idx)
Expand All @@ -213,7 +198,7 @@ def _get_structure_info(self, snapshot):
atom_type_idx = np.zeros([len(symbols)], dtype=np.intc, order='C')
type_num = dict()
type_atom_idx = dict()
for j,jtem in enumerate(self.parent.inputs['atom_types']):
for j, jtem in enumerate(self.parent.inputs['atom_types']):
tmp = symbols==jtem
atom_type_idx[tmp] = j+1
type_num[jtem] = np.sum(tmp).astype(np.int64)
Expand All @@ -222,7 +207,7 @@ def _get_structure_info(self, snapshot):
# if not, it could generate bug in training process for force training
type_atom_idx[jtem] = np.arange(atom_num)[tmp]

return atom_num, atom_type_idx, type_num, type_atom_idx , cart , scale , cell
return atom_num, atom_type_idx, type_num, type_atom_idx, cart, scale, cell

# Init result Dictionary
def _init_result(self, type_num, structure_tags, structure_weights, idx, atom_type_idx):
Expand Down Expand Up @@ -282,7 +267,7 @@ def _check_error(self, errnos):
assert errno == 0

# Set resulatant Dictionary
def _set_result(self, result, x, dx, da, type_num, jtem, symf_params_set, atom_num):
def _set_result(self, result, x, dx, da, type_num, jtem, symf_params_set, atom_num):
if type_num[jtem] != 0:
# IF MPI available
#result['x'][jtem] = np.array(comm.gather(x, root=0))
Expand Down
21 changes: 11 additions & 10 deletions simple_nn_v2/utils/datagenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ def parse_structure_list(self):
structure_tags(str list): list of structure tag
structure_weights(float list): list of structure weight
"""

structures = []
structure_tag_idx = []
structure_tags = ["None"]
structure_tags = ['None']
structure_weights = [1.0]
tag = "None"
tag = 'None'
weight = 1.0

with open(self.structure_list, 'r') as fil:
Expand All @@ -87,7 +88,7 @@ def parse_structure_list(self):
if len(line) == 0 or line.isspace():
continue
# 1. Extract structure tag and weight in between "[ ]"
elif line[0] == "[" and line[-1] == "]":
elif line[0] == '[' and line[-1] == ']':
tag_line = line[1:-1]
tag, weight = self._get_tag_and_weight(tag_line)

Expand Down Expand Up @@ -159,18 +160,18 @@ def load_snapshots(self, item):
file_path = item[0]
if len(item) == 1:
index = 0
self.logfile.write('{} 0'.format(file_path))
self.logfile.write("{} 0".format(file_path))
else:
if ':' in item[1]:
index = item[1]
else:
index = int(item[1])
self.logfile.write('{} {}'.format(file_path, item[1]))
self.logfile.write("{} {}".format(file_path, item[1]))

if self.inputs['refdata_format'] == 'vasp-out':
if self.inputs['compress_outcar']:
tmp_name = compress_outcar(file_path)
print(tmp_name)

if ase.__version__ >= '3.18.0':
snapshots = io.read(tmp_name, index=index, format=self.inputs['refdata_format'])
else:
Expand Down Expand Up @@ -211,11 +212,11 @@ def save_to_datafile(self, data, tag_idx):
self._data_idx += 1
try:
if self.inputs['save_to_pickle'] == False:
tmp_filename = os.path.join(self.data_dir, "data{}.pt".format(self._data_idx))
tmp_filename = os.path.join(self.data_dir, 'data{}.pt'.format(self._data_idx))
torch.save(data, tmp_filename)
elif self.inputs['save_to_pickle'] == True:
tmp_filename = os.path.join(self.data_dir, "data{}.pickle".format(self._data_idx))
with open(tmp_filename, "wb") as fil:
tmp_filename = os.path.join(self.data_dir, 'data{}.pickle'.format(self._data_idx))
with open(tmp_filename, 'wb') as fil:
pickle.dump(data, fil, protocol=2)
except:
self._data_idx -= 1
Expand All @@ -226,6 +227,6 @@ def save_to_datafile(self, data, tag_idx):
self.logfile.write("\nError: {:}\n".format(err))
raise NotImplementedError(err)

self._data_list_fil.write('{}:{}\n'.format(tag_idx, tmp_filename))
self._data_list_fil.write("{}:{}\n".format(tag_idx, tmp_filename))

return tmp_filename

0 comments on commit 078550a

Please sign in to comment.