diff --git a/deepmd/dpmodel/atomic_model/polar_atomic_model.py b/deepmd/dpmodel/atomic_model/polar_atomic_model.py index faff2444ea..3dfd0dfc99 100644 --- a/deepmd/dpmodel/atomic_model/polar_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/polar_atomic_model.py @@ -42,8 +42,7 @@ def apply_out_stat( for kk in self.bias_keys: ntypes = out_bias[kk].shape[0] temp = np.zeros(ntypes, dtype=dtype) - for i in range(ntypes): - temp[i] = np.mean(np.diagonal(out_bias[kk][i].reshape(3, 3))) + temp = np.mean(np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2), axis=1) modified_bias = temp[atype] # (nframes, nloc, 1) diff --git a/deepmd/pt/model/atomic_model/polar_atomic_model.py b/deepmd/pt/model/atomic_model/polar_atomic_model.py index fc6c351830..7d5014d5a2 100644 --- a/deepmd/pt/model/atomic_model/polar_atomic_model.py +++ b/deepmd/pt/model/atomic_model/polar_atomic_model.py @@ -43,8 +43,7 @@ def apply_out_stat( for kk in self.bias_keys: ntypes = out_bias[kk].shape[0] temp = torch.zeros(ntypes, dtype=dtype, device=device) - for i in range(ntypes): - temp[i] = torch.mean(torch.diagonal(out_bias[kk][i].reshape(3, 3))) + temp = torch.mean(torch.diagonal(out_bias[kk].reshape(ntypes, 3, 3), dim1=-2, dim2=-1), dim=-1) modified_bias = temp[atype] # (nframes, nloc, 1)