From 1dc7fd95c35e7727672e7ef765fe87344a379553 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:26:41 +0800 Subject: [PATCH] chore: vectorize for loop --- deepmd/dpmodel/atomic_model/polar_atomic_model.py | 3 +-- deepmd/pt/model/atomic_model/polar_atomic_model.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/polar_atomic_model.py b/deepmd/dpmodel/atomic_model/polar_atomic_model.py index faff2444ea..3dfd0dfc99 100644 --- a/deepmd/dpmodel/atomic_model/polar_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/polar_atomic_model.py @@ -42,8 +42,7 @@ def apply_out_stat( for kk in self.bias_keys: ntypes = out_bias[kk].shape[0] temp = np.zeros(ntypes, dtype=dtype) - for i in range(ntypes): - temp[i] = np.mean(np.diagonal(out_bias[kk][i].reshape(3, 3))) + temp = np.mean(np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2), axis=1) modified_bias = temp[atype] # (nframes, nloc, 1) diff --git a/deepmd/pt/model/atomic_model/polar_atomic_model.py b/deepmd/pt/model/atomic_model/polar_atomic_model.py index fc6c351830..7d5014d5a2 100644 --- a/deepmd/pt/model/atomic_model/polar_atomic_model.py +++ b/deepmd/pt/model/atomic_model/polar_atomic_model.py @@ -43,8 +43,7 @@ def apply_out_stat( for kk in self.bias_keys: ntypes = out_bias[kk].shape[0] temp = torch.zeros(ntypes, dtype=dtype, device=device) - for i in range(ntypes): - temp[i] = torch.mean(torch.diagonal(out_bias[kk][i].reshape(3, 3))) + temp = torch.mean(torch.diagonal(out_bias[kk].reshape(ntypes, 3, 3), dim1=-2, dim2=-1), dim=-1) modified_bias = temp[atype] # (nframes, nloc, 1)