revise expression
Seungwoo-Hwang committed Jun 15, 2021
1 parent f2347e5 commit 28007d6
Showing 5 changed files with 120 additions and 119 deletions.
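The hunks rendered below are whitespace-only edits: a space is added after commas and around binary operators, and blank lines between methods are normalized, in line with PEP 8. A runnable sketch of the pattern (the array and names here are stand-ins, not from the repo):

    import numpy as np

    res = np.zeros((2, 4), dtype=int)
    i, idx = 0, 0
    j, joffset = 0, 1
    nodes = [4, 4]

    res[0,idx] = i                     # before: no space after the comma
    res[0, idx] = i                    # after:  PEP 8 spacing
    n_old = range(nodes[j+joffset])    # before: no space around the operator
    n_new = range(nodes[j + joffset])  # after:  PEP 8 spacing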
10 changes: 5 additions & 5 deletions simple_nn_v2/models/data_handler.py
@@ -63,7 +63,7 @@ def save_filename(self):
         for f in self.filelist:
             tmp_dict = torch.load(f)
             tmp_dict['filename'] = f
-            torch.save(tmp_dict,f)
+            torch.save(tmp_dict, f)
 
 #Used in Structure rmse
 class StructlistDataset(FilelistDataset):
@@ -182,7 +182,7 @@ def my_collate(batch, atom_types, scale_factor=None, pca=None, pca_min_whiten_le
     if use_stress:
         S = torch.cat(S, axis=0)
 
-    return {'x': x, 'dx': dx, 'da': da, 'n': n, 'E': E, 'F': F, 'S': S, 'sp_idx': sparse_index, 'struct_weight': struct_weight, 'tot_num':tot_num}
+    return {'x': x, 'dx': dx, 'da': da, 'n': n, 'E': E, 'F': F, 'S': S, 'sp_idx': sparse_index, 'struct_weight': struct_weight, 'tot_num': tot_num}
 
 
 
@@ -200,14 +200,14 @@ def gen_sparse_index(nlist):
     idx = 0
     for i,item in enumerate(nlist):
         for jtem in range(item):
-            res[0,idx] = i
-            res[1,idx] = idx
+            res[0, idx] = i
+            res[1, idx] = idx
             idx += 1
     return res
 
 
 #Load collate from train, valid dataset
-def _load_collate(inputs, logfile, scale_factor, pca, train_dataset, valid_dataset, batch_size=1,my_collate=my_collate):
+def _load_collate(inputs, logfile, scale_factor, pca, train_dataset, valid_dataset, batch_size=1, my_collate=my_collate):
     partial_collate = partial(
         my_collate,
         atom_types=inputs['atom_types'],
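For context on the gen_sparse_index hunk above: the function flattens a list of per-structure atom counts into a 2 x sum(nlist) index pair, with row 0 holding the structure id and row 1 a running atom id. A minimal sketch under that reading (the torch.zeros allocation is an assumption; the diff only shows the loop body):

    import torch

    def gen_sparse_index(nlist):
        # row 0: which structure each atom belongs to; row 1: running atom index
        res = torch.zeros(2, sum(nlist), dtype=torch.long)  # assumed allocation
        idx = 0
        for i, item in enumerate(nlist):
            for jtem in range(item):
                res[0, idx] = i
                res[1, idx] = idx
                idx += 1
        return res

    print(gen_sparse_index([2, 3]))
    # tensor([[0, 0, 1, 1, 1],
    #         [0, 1, 2, 3, 4]])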
14 changes: 7 additions & 7 deletions simple_nn_v2/models/neural_network.py
@@ -3,11 +3,11 @@
 import shutil
 
 class FCNDict(torch.nn.Module):
-
     def __init__(self, nets):
         super(FCNDict, self).__init__()
         self.nets = torch.nn.ModuleDict(nets)
         self.keys = self.nets.keys()
+
     def forward(self, x):
         assert [item for item in self.nets.keys()].sort() == [item for item in x.keys()].sort()
         res = {}
@@ -36,9 +36,9 @@ def write_lammps_potential(self, filename, inputs, scale_factor=None, pca=None):
         FIL.write('SYM {}\n'.format(len(params)))
 
         for ctem in params:
-            tmp_types = inputs['atom_types'][int(ctem[1])-1]
+            tmp_types = inputs['atom_types'][int(ctem[1]) - 1]
             if int(ctem[0]) > 3:
-                tmp_types += ' {}'.format(inputs['atom_types'][int(ctem[2])-1])
+                tmp_types += ' {}'.format(inputs['atom_types'][int(ctem[2]) - 1])
             if len(ctem) != 7:
                 raise ValueError("params file must have lines with 7 columns.")
 
@@ -97,7 +97,7 @@ def write_lammps_potential(self, filename, inputs, scale_factor=None, pca=None):
 
                 FIL.write('LAYER {} {}\n'.format(j+joffset, acti))
 
-                for k in range(nodes[j+joffset]):
+                for k in range(nodes[j + joffset]):
                     FIL.write('w{} {}\n'.format(k, ' '.join(weights[j][k,:].astype(np.str))))
                     FIL.write('b{} {}\n'.format(k, biases[j][k]))
 
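The two write_lammps_potential hunks above emit the text blocks of the potential file: a SYM header for the symmetry-function parameters, then per-layer LAYER blocks with w{k}/b{k} rows. A toy emitter for just the LAYER block (the helper name and shapes are hypothetical; astype(str) stands in for the diff's np.str alias, which modern NumPy has removed):

    import io

    import numpy as np

    def write_layer_block(fil, j, joffset, acti, weights, biases, nodes):
        # mirrors the FIL.write calls in the hunk above
        fil.write('LAYER {} {}\n'.format(j + joffset, acti))
        for k in range(nodes[j + joffset]):
            fil.write('w{} {}\n'.format(k, ' '.join(weights[j][k, :].astype(str))))
            fil.write('b{} {}\n'.format(k, biases[j][k]))

    buf = io.StringIO()
    write_layer_block(buf, 0, 1, 'sigmoid', [np.ones((2, 3))], [np.zeros(2)], [3, 2])
    print(buf.getvalue())
    # LAYER 1 sigmoid
    # w0 1.0 1.0 1.0
    # b0 0.0
    # w1 1.0 1.0 1.0
    # b1 0.0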
@@ … @@
         FIL.close()
 
 class FCN(torch.nn.Module):
-    def __init__(self, dim_input, dim_hidden, acti_func='sigmoid',dropout=None):
+    def __init__(self, dim_input, dim_hidden, acti_func='sigmoid', dropout=None):
         super(FCN, self).__init__()
 
         self.lin = torch.nn.Sequential()
 
         dim_in = dim_input
-        for i,hn in enumerate(dim_hidden):
+        for i, hn in enumerate(dim_hidden):
             if dropout:
                 self.lin.add_module(f'drop_{i}', torch.nn.Dropout(p=dropout))
             self.lin.add_module(f'lin_{i}', torch.nn.Linear(dim_in, hn))
@@ -157,7 +157,7 @@ def _read_until(fil, stop_tag):
 
     weights = dict()
     with open(filename) as fil:
-        atom_types = fil.readline().replace('\n','').split()[1:]
+        atom_types = fil.readline().replace('\n', '').split()[1:]
         for item in atom_types:
             weights[item] = dict()
 
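The FCN hunk above builds the network as a torch.nn.Sequential of optional Dropout plus Linear modules named drop_{i} and lin_{i}. A self-contained sketch of that construction pattern (the activation placement and forward pass are assumptions; the diff shows only the loop head and the two add_module calls):

    import torch

    class FCN(torch.nn.Module):
        def __init__(self, dim_input, dim_hidden, acti_func='sigmoid', dropout=None):
            super(FCN, self).__init__()
            self.lin = torch.nn.Sequential()
            dim_in = dim_input
            for i, hn in enumerate(dim_hidden):
                if dropout:
                    self.lin.add_module(f'drop_{i}', torch.nn.Dropout(p=dropout))
                self.lin.add_module(f'lin_{i}', torch.nn.Linear(dim_in, hn))
                if i < len(dim_hidden) - 1 and acti_func == 'sigmoid':
                    self.lin.add_module(f'acti_{i}', torch.nn.Sigmoid())  # assumed placement
                dim_in = hn

        def forward(self, x):
            return self.lin(x)

    net = FCN(10, [32, 32, 1], dropout=0.1)
    print(net(torch.randn(4, 10)).shape)  # torch.Size([4, 1])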
(Diff for the remaining 3 changed files not shown.)
