Skip to content
This repository has been archived by the owner on May 19, 2022. It is now read-only.

Commit

Permalink
Merge pull request #77 from JisuJung928/master
Browse files Browse the repository at this point in the history
lammps and multi-component stress bug fix
  • Loading branch information
JisuJung928 authored Feb 21, 2020
2 parents 6328066 + 4df944a commit ee3359a
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 90 deletions.
143 changes: 58 additions & 85 deletions simple_nn/features/symmetry_function/pair_nn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ void PairNN::compute(int eflag, int vflag)
{
int i,j,k,n,ii,jj,kk,tt,nn,inum,jnum;
int itype,jtype,ktype,ielem,jelem,kelem;
int jbeg,jend,kbeg,kend;
double xtmp,ytmp,ztmp,dradtmp,tmpc,tmpE;
double dangtmp[3];
double tmpd[9];
Expand Down Expand Up @@ -188,47 +187,49 @@ void PairNN::compute(int eflag, int vflag)
}

// calc radial symfunc
jbeg = nets[ielem].ridx[jelem];
jend = jbeg + nets[ielem].rsym[jelem];
for (tt=jbeg; tt<jend; tt++) {
for (tt=0; tt<nsym; tt++) {
sym = &nets[ielem].slists[tt];
if (rRij > sym->coefs[0]) continue;
cutf2(rRij, sym->coefs[0], precal[0], precal[1], 0);

symvec[tt] += G2(rRij, precal, sym->coefs, dradtmp);

tmpd[0] = dradtmp*vecij[0];
tmpd[1] = dradtmp*vecij[1];
tmpd[2] = dradtmp*vecij[2];

tmpf[tt*(jnum+1)*3 + jj*3] += tmpd[0];
tmpf[tt*(jnum+1)*3 + jj*3 + 1] += tmpd[1];
tmpf[tt*(jnum+1)*3 + jj*3 + 2] += tmpd[2];

tmpf[tt*(jnum+1)*3 + jnum*3] -= tmpd[0];
tmpf[tt*(jnum+1)*3 + jnum*3 + 1] -= tmpd[1];
tmpf[tt*(jnum+1)*3 + jnum*3 + 2] -= tmpd[2];

if (vflag_atom) {
tmps[tt*3*6] += tmpd[0]*lcoeff[0]*cell[0][0];
tmps[tt*3*6 + 1] += tmpd[1]*lcoeff[0]*cell[0][1];
tmps[tt*3*6 + 2] += tmpd[2]*lcoeff[0]*cell[0][2];
tmps[tt*3*6 + 3] += tmpd[0]*lcoeff[0]*cell[0][1];
tmps[tt*3*6 + 4] += tmpd[1]*lcoeff[0]*cell[0][2];
tmps[tt*3*6 + 5] += tmpd[2]*lcoeff[0]*cell[0][0];
tmps[tt*3*6 + 6] += tmpd[0]*lcoeff[1]*cell[1][0];
tmps[tt*3*6 + 7] += tmpd[1]*lcoeff[1]*cell[1][1];
tmps[tt*3*6 + 8] += tmpd[2]*lcoeff[1]*cell[1][2];
tmps[tt*3*6 + 9] += tmpd[0]*lcoeff[1]*cell[1][1];
tmps[tt*3*6 + 10] += tmpd[1]*lcoeff[1]*cell[1][2];
tmps[tt*3*6 + 11] += tmpd[2]*lcoeff[1]*cell[1][0];
tmps[tt*3*6 + 12] += tmpd[0]*lcoeff[2]*cell[2][0];
tmps[tt*3*6 + 13] += tmpd[1]*lcoeff[2]*cell[2][1];
tmps[tt*3*6 + 14] += tmpd[2]*lcoeff[2]*cell[2][2];
tmps[tt*3*6 + 15] += tmpd[0]*lcoeff[2]*cell[2][1];
tmps[tt*3*6 + 16] += tmpd[1]*lcoeff[2]*cell[2][2];
tmps[tt*3*6 + 17] += tmpd[2]*lcoeff[2]*cell[2][0];
if (sym->atype[0] != jelem) continue;
if (sym->stype == 2) {
cutf2(rRij, sym->coefs[0], precal[0], precal[1], 0);

symvec[tt] += G2(rRij, precal, sym->coefs, dradtmp);

tmpd[0] = dradtmp*vecij[0];
tmpd[1] = dradtmp*vecij[1];
tmpd[2] = dradtmp*vecij[2];

tmpf[tt*(jnum+1)*3 + jj*3] += tmpd[0];
tmpf[tt*(jnum+1)*3 + jj*3 + 1] += tmpd[1];
tmpf[tt*(jnum+1)*3 + jj*3 + 2] += tmpd[2];

tmpf[tt*(jnum+1)*3 + jnum*3] -= tmpd[0];
tmpf[tt*(jnum+1)*3 + jnum*3 + 1] -= tmpd[1];
tmpf[tt*(jnum+1)*3 + jnum*3 + 2] -= tmpd[2];

if (vflag_atom) {
tmps[tt*3*6] += tmpd[0]*lcoeff[0]*cell[0][0];
tmps[tt*3*6 + 1] += tmpd[1]*lcoeff[0]*cell[0][1];
tmps[tt*3*6 + 2] += tmpd[2]*lcoeff[0]*cell[0][2];
tmps[tt*3*6 + 3] += tmpd[0]*lcoeff[0]*cell[0][1];
tmps[tt*3*6 + 4] += tmpd[1]*lcoeff[0]*cell[0][2];
tmps[tt*3*6 + 5] += tmpd[2]*lcoeff[0]*cell[0][0];
tmps[tt*3*6 + 6] += tmpd[0]*lcoeff[1]*cell[1][0];
tmps[tt*3*6 + 7] += tmpd[1]*lcoeff[1]*cell[1][1];
tmps[tt*3*6 + 8] += tmpd[2]*lcoeff[1]*cell[1][2];
tmps[tt*3*6 + 9] += tmpd[0]*lcoeff[1]*cell[1][1];
tmps[tt*3*6 + 10] += tmpd[1]*lcoeff[1]*cell[1][2];
tmps[tt*3*6 + 11] += tmpd[2]*lcoeff[1]*cell[1][0];
tmps[tt*3*6 + 12] += tmpd[0]*lcoeff[2]*cell[2][0];
tmps[tt*3*6 + 13] += tmpd[1]*lcoeff[2]*cell[2][1];
tmps[tt*3*6 + 14] += tmpd[2]*lcoeff[2]*cell[2][2];
tmps[tt*3*6 + 15] += tmpd[0]*lcoeff[2]*cell[2][1];
tmps[tt*3*6 + 16] += tmpd[1]*lcoeff[2]*cell[2][2];
tmps[tt*3*6 + 17] += tmpd[2]*lcoeff[2]*cell[2][0];
}
}
else continue;
}

if (rRij > max_rc_ang) continue;
Expand Down Expand Up @@ -288,11 +289,11 @@ void PairNN::compute(int eflag, int vflag)
}

// calc angular symfunc
kbeg = nets[ielem].aidx[jelem*nelements + kelem];
kend = kbeg + nets[ielem].asym[jelem*nelements + kelem];
for (tt=kbeg; tt<kend; tt++) {
for (tt=0; tt<nsym; tt++) {
sym = &nets[ielem].slists[tt];
if (rRik > sym->coefs[0]) continue;
if (!((sym->atype[0] == jelem && sym->atype[1] == kelem) || \
(sym->atype[0] == kelem && sym->atype[1] == jelem))) continue;
if ((sym->stype) == 4) {
if (rRjk > sym->coefs[0]) continue;
cutf2(rRij, nets[ielem].slists[tt].coefs[0], precal[0], precal[1], 0);
Expand Down Expand Up @@ -644,10 +645,6 @@ void PairNN::read_file(char *fname) {
nsym = atoi(strtok(NULL," \t\n\r\f"));
nets[nnet].slists = new Symc[nsym];
nets[nnet].powtwo = new double[nsym];
nets[nnet].rsym = new int[nelements]();
nets[nnet].asym = new int[nelements*nelements]();
nets[nnet].ridx = new int[nelements];
nets[nnet].aidx = new int[nelements*nelements];
nets[nnet].scale = new double*[2];
for (i=0; i<2; ++i) {
nets[nnet].scale[i] = new double[nsym];
Expand All @@ -661,34 +658,24 @@ void PairNN::read_file(char *fname) {
nets[nnet].slists[isym].coefs[2] = atof(strtok(NULL," \t\n\r\f"));
nets[nnet].slists[isym].coefs[3] = atof(strtok(NULL," \t\n\r\f"));

tstr = strtok(NULL," \t\n\r\f");
nets[nnet].slists[isym].atype[0] = nelements;
for (i=0; i<nelements; i++) {
if (strcmp(tstr, elements[i]) == 0) {
nets[nnet].slists[isym].atype[0] = i;
break;
}
}
// In this code, SF type >= 4 means that it is angular function.
if (nets[nnet].slists[isym].stype >= 4) {
// Find maximum cutoff distance among angular functions.
max_rc_ang = max(max_rc_ang, nets[nnet].slists[isym].coefs[0]);
nsf[nets[nnet].slists[isym].stype]++;
char *tstrj = strtok(NULL," \t\n\r\f");
char *tstrk = strtok(NULL," \t\n\r\f");
for (j=0; j<nelements; j++) {
if (strcmp(tstrj,elements[j]) == 0) {
nets[nnet].slists[isym].atype[0] = j;
for (k=0; k<nelements; k++) {
if (strcmp(tstrk,elements[k]) == 0) {
nets[nnet].slists[isym].atype[1] = k;
nets[nnet].asym[j*nelements + k]++;
if (j != k) {
nets[nnet].asym[k*nelements + j]++;
}
break;
}
}
break;
}
}
} else {
char *tstrj = strtok(NULL," \t\n\r\f");
for (j=0; j<nelements; j++) {
if (strcmp(tstrj,elements[j]) == 0) {
nets[nnet].slists[isym].atype[0] = j;
nets[nnet].rsym[j]++;
nsf[nets[nnet].slists[isym].stype] += 1;
tstr = strtok(NULL," \t\n\r\f");
nets[nnet].slists[isym].atype[1] = nelements;
for (i=0; i<nelements; i++) {
if (strcmp(tstr,elements[i]) == 0) {
nets[nnet].slists[isym].atype[1] = i;
break;
}
}
Expand All @@ -714,20 +701,6 @@ void PairNN::read_file(char *fname) {
if (isym == nsym) {
stats = 4;
iscale = 0;
int tmpidx = 0;
for (j=0; j<nelements; j++) {
nets[nnet].ridx[j] = tmpidx;
tmpidx += nets[nnet].rsym[j];
}
for (j=0; j<nelements; j++) {
for (k=j; k<nelements; k++) {
nets[nnet].aidx[j*nelements + k] = tmpidx;
if (j != k) {
nets[nnet].aidx[k*nelements + j] = tmpidx;
}
tmpidx += nets[nnet].asym[j*nelements + k];
}
}
}
} else if (stats == 4) { // scale
tstr = strtok(line," \t\n\r\f");
Expand Down
4 changes: 0 additions & 4 deletions simple_nn/features/symmetry_function/pair_nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ class PairNN : public Pair {
double **scale; // scale
Symc *slists; // symmetry function related parameters
double *powtwo; // power of two
int *rsym; // # of radial symmetry function
int *asym; // # of angular symmetry function
int *ridx; // start index of radial symmetry function
int *aidx; // start index of angular symmetry function
bool *powint;
};

Expand Down
2 changes: 1 addition & 1 deletion simple_nn/models/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def _calc_output(self):
tf.expand_dims(self.dys[item], axis=2),
axis=3)
tmp_stress = tf.cond(zero_cond,
lambda: tf.cast(0., tf.float64),
lambda: tf.cast(0., tf.float64) * tmp_stress,
lambda: tf.sparse_segment_sum(tmp_stress, self.next_elem['sparse_indices_'+item], self.next_elem['seg_id_'+item],
num_segments=self.next_elem['num_seg'])[1:])
self.S -= tf.reduce_sum(tmp_stress, axis=[1,2])/units.GPa*10
Expand Down

0 comments on commit ee3359a

Please sign in to comment.