diff --git a/deeplink_ext/internevo_ops/rotary_embedding.py b/deeplink_ext/internevo_ops/rotary_embedding.py index 7764b9b..1a2a36d 100644 --- a/deeplink_ext/internevo_ops/rotary_embedding.py +++ b/deeplink_ext/internevo_ops/rotary_embedding.py @@ -4,7 +4,8 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_NPU: - from ._rotary_embedding_npu import ApplyRotaryEmb + # from ._rotary_embedding_npu import ApplyRotaryEmb + from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb elif platform_type == PlatformType.TORCH_DIPU: from ._rotary_embedding_dipu import ApplyRotaryEmb else: diff --git a/deeplink_ext/interntrain_ops/rms_norm.py b/deeplink_ext/interntrain_ops/rms_norm.py index e6834cb..e6e3f06 100644 --- a/deeplink_ext/interntrain_ops/rms_norm.py +++ b/deeplink_ext/interntrain_ops/rms_norm.py @@ -4,7 +4,8 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_NPU: - from ._mixed_rms_norm_npu import MixedFusedRMSNorm + # from ._mixed_rms_norm_npu import MixedFusedRMSNorm + from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm elif platform_type == PlatformType.TORCH_DIPU: from ._mixed_rms_norm_dipu import MixedFusedRMSNorm else: diff --git a/deeplink_ext/interntrain_ops/rotary_embedding.py b/deeplink_ext/interntrain_ops/rotary_embedding.py index 8609a38..1805b67 100644 --- a/deeplink_ext/interntrain_ops/rotary_embedding.py +++ b/deeplink_ext/interntrain_ops/rotary_embedding.py @@ -4,7 +4,9 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_NPU: - from ._rotary_embedding_npu import ApplyRotaryEmb, ApplyRotaryEmbQKV_ + # from ._rotary_embedding_npu import ApplyRotaryEmb, ApplyRotaryEmbQKV_ + from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb + from .rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_ elif platform_type == PlatformType.TORCH_DIPU: from ._rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_ else: