Skip to content

Commit

Permalink
rms_norm and rotary embedding back to combined impl
Browse files Browse the repository at this point in the history
  • Loading branch information
jingguo-st committed Sep 11, 2024
1 parent 2b1b705 commit eb4daba
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
3 changes: 2 additions & 1 deletion deeplink_ext/internevo_ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
from ._rotary_embedding_npu import ApplyRotaryEmb
# from ._rotary_embedding_npu import ApplyRotaryEmb
from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
elif platform_type == PlatformType.TORCH_DIPU:
from ._rotary_embedding_dipu import ApplyRotaryEmb
else:
Expand Down
3 changes: 2 additions & 1 deletion deeplink_ext/interntrain_ops/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
from ._mixed_rms_norm_npu import MixedFusedRMSNorm
# from ._mixed_rms_norm_npu import MixedFusedRMSNorm
from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
elif platform_type == PlatformType.TORCH_DIPU:
from ._mixed_rms_norm_dipu import MixedFusedRMSNorm
else:
Expand Down
4 changes: 3 additions & 1 deletion deeplink_ext/interntrain_ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_NPU:
from ._rotary_embedding_npu import ApplyRotaryEmb, ApplyRotaryEmbQKV_
# from ._rotary_embedding_npu import ApplyRotaryEmb, ApplyRotaryEmbQKV_
from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
from .rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_
elif platform_type == PlatformType.TORCH_DIPU:
from ._rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_
else:
Expand Down

0 comments on commit eb4daba

Please sign in to comment.