Skip to content

Commit

Permalink
chore: optimize compute_smooth_weight (#4390)
Browse files Browse the repository at this point in the history
New implmenetation is obviously more efficient.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced the `compute_smooth_weight` functionality for improved
efficiency and clarity by simplifying the distance handling logic.
  
- **Bug Fixes**
- Removed unnecessary masking conditions, ensuring smoother calculations
within defined distance ranges.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 21, 2024
1 parent 6febc71 commit 8e2e3f3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
11 changes: 4 additions & 7 deletions deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,11 @@ def compute_smooth_weight(
if rmin >= rmax:
raise ValueError("rmin should be less than rmax.")
xp = array_api_compat.array_namespace(distance)
min_mask = distance <= rmin
max_mask = distance >= rmax
mid_mask = xp.logical_not(xp.logical_or(min_mask, max_mask))
distance = xp.clip(distance, min=rmin, max=rmax)
uu = (distance - rmin) / (rmax - rmin)
vv = uu * uu * uu * (-6.0 * uu * uu + 15.0 * uu - 10.0) + 1.0
return vv * xp.astype(mid_mask, distance.dtype) + xp.astype(
min_mask, distance.dtype
)
uu2 = uu * uu
vv = uu2 * uu * (-6.0 * uu2 + 15.0 * uu - 10.0) + 1.0
return vv


def _make_env_mat(
Expand Down
9 changes: 4 additions & 5 deletions deepmd/pt/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ def compute_smooth_weight(distance, rmin: float, rmax: float):
"""Compute smooth weight for descriptor elements."""
if rmin >= rmax:
raise ValueError("rmin should be less than rmax.")
min_mask = distance <= rmin
max_mask = distance >= rmax
mid_mask = torch.logical_not(torch.logical_or(min_mask, max_mask))
distance = torch.clamp(distance, min=rmin, max=rmax)
uu = (distance - rmin) / (rmax - rmin)
vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1
return vv * mid_mask + min_mask
uu2 = uu * uu
vv = uu2 * uu * (-6 * uu2 + 15 * uu - 10) + 1
return vv

0 comments on commit 8e2e3f3

Please sign in to comment.