Skip to content

Commit

Permalink
add code comment
Browse files Browse the repository at this point in the history
  • Loading branch information
POI-WX authored Sep 13, 2024
1 parent 599ec8d commit 082ac72
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deeplink_ext/interntrain_ops/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
# from ._mixed_rms_norm_npu import MixedFusedRMSNorm
# Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative.
from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
elif platform_type == PlatformType.TORCH_DIPU:
# from ._mixed_rms_norm_dipu import MixedFusedRMSNorm
# Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative.
from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
else:
raise ImportError
Expand Down

0 comments on commit 082ac72

Please sign in to comment.