Skip to content

Commit

Permalink
chore: vectorize for loop
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Nov 19, 2024
1 parent 9071e73 commit 1dc7fd9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
3 changes: 1 addition & 2 deletions deepmd/dpmodel/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def apply_out_stat(
for kk in self.bias_keys:
ntypes = out_bias[kk].shape[0]
temp = np.zeros(ntypes, dtype=dtype)
for i in range(ntypes):
temp[i] = np.mean(np.diagonal(out_bias[kk][i].reshape(3, 3)))
temp = np.mean(np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2), axis=1)
modified_bias = temp[atype]

# (nframes, nloc, 1)
Expand Down
3 changes: 1 addition & 2 deletions deepmd/pt/model/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def apply_out_stat(
for kk in self.bias_keys:
ntypes = out_bias[kk].shape[0]
temp = torch.zeros(ntypes, dtype=dtype, device=device)
for i in range(ntypes):
temp[i] = torch.mean(torch.diagonal(out_bias[kk][i].reshape(3, 3)))
temp = torch.mean(torch.diagonal(out_bias[kk].reshape(ntypes, 3, 3), dim1=-2, dim2=-1), dim=-1)
modified_bias = temp[atype]

# (nframes, nloc, 1)
Expand Down

0 comments on commit 1dc7fd9

Please sign in to comment.