From 8e2e3f3a99b705e93fda4c85a455fe05db13d08d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 21 Nov 2024 00:37:33 -0500 Subject: [PATCH] chore: optimize compute_smooth_weight (#4390) New implmenetation is obviously more efficient. ## Summary by CodeRabbit - **New Features** - Enhanced the `compute_smooth_weight` functionality for improved efficiency and clarity by simplifying the distance handling logic. - **Bug Fixes** - Removed unnecessary masking conditions, ensuring smoother calculations within defined distance ranges. Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/utils/env_mat.py | 11 ++++------- deepmd/pt/utils/preprocess.py | 9 ++++----- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index abbd68945b..ee69716627 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -25,14 +25,11 @@ def compute_smooth_weight( if rmin >= rmax: raise ValueError("rmin should be less than rmax.") xp = array_api_compat.array_namespace(distance) - min_mask = distance <= rmin - max_mask = distance >= rmax - mid_mask = xp.logical_not(xp.logical_or(min_mask, max_mask)) + distance = xp.clip(distance, min=rmin, max=rmax) uu = (distance - rmin) / (rmax - rmin) - vv = uu * uu * uu * (-6.0 * uu * uu + 15.0 * uu - 10.0) + 1.0 - return vv * xp.astype(mid_mask, distance.dtype) + xp.astype( - min_mask, distance.dtype - ) + uu2 = uu * uu + vv = uu2 * uu * (-6.0 * uu2 + 15.0 * uu - 10.0) + 1.0 + return vv def _make_env_mat( diff --git a/deepmd/pt/utils/preprocess.py b/deepmd/pt/utils/preprocess.py index 7d5b0cf314..8ab489dede 100644 --- a/deepmd/pt/utils/preprocess.py +++ b/deepmd/pt/utils/preprocess.py @@ -10,9 +10,8 @@ def compute_smooth_weight(distance, rmin: float, rmax: float): """Compute smooth weight for descriptor elements.""" if rmin >= rmax: raise ValueError("rmin should be less than rmax.") - min_mask = distance <= rmin - max_mask = distance >= rmax - mid_mask = torch.logical_not(torch.logical_or(min_mask, max_mask)) + distance = torch.clamp(distance, min=rmin, max=rmax) uu = (distance - rmin) / (rmax - rmin) - vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1 - return vv * mid_mask + min_mask + uu2 = uu * uu + vv = uu2 * uu * (-6 * uu2 + 15 * uu - 10) + 1 + return vv