From b422fa86a48604a4757933985ad06f403694a2b9 Mon Sep 17 00:00:00 2001 From: HuayiL <442488254@qq.com> Date: Tue, 5 Nov 2024 17:37:24 +0800 Subject: [PATCH 1/7] refactor python code refactor python code by ops --- .../ascend_speed/_flash_attention_dipu.py | 1 - deeplink_ext/ascend_speed/_rms_norm_dipu.py | 1 - .../_scaled_masked_softmax_dipu.py | 1 - .../_scaled_masked_softmax_npu.py | 1 - deeplink_ext/easyllm_ops/__init__.py | 40 +-- deeplink_ext/easyllm_ops/adamw.py | 5 - deeplink_ext/easyllm_ops/flash_attention.py | 19 - .../easyllm_ops/flash_attention_fallback.py | 20 -- deeplink_ext/internevo_ops/__init__.py | 45 +-- .../internevo_ops/_flash_attention_npu.py | 337 ------------------ .../internevo_ops/_rotary_embedding_npu.py | 75 ---- deeplink_ext/internevo_ops/adamw.py | 5 - deeplink_ext/internevo_ops/rms_norm.py | 5 - .../internevo_ops/rms_norm_fallback.py | 5 - .../internevo_ops/rotary_embedding.py | 14 - deeplink_ext/interntrain_ops/__init__.py | 28 +- .../interntrain_ops/_flash_attention_npu.py | 291 --------------- .../interntrain_ops/_mixed_rms_norm_npu.py | 58 --- .../interntrain_ops/_rotary_embedding_npu.py | 171 --------- .../interntrain_ops/flash_attention.py | 13 - deeplink_ext/interntrain_ops/rms_norm.py | 17 - .../interntrain_ops/rotary_embedding.py | 15 - .../adamw.py => ops/adamw/__init__.py} | 5 +- .../adamw}/_adamw_dipu.py | 1 - deeplink_ext/ops/bert_padding/__init__.py | 3 + .../bert_padding}/bert_padding.py | 0 deeplink_ext/ops/flash_attention/__init__.py | 32 ++ .../internevo_flash_attention.py} | 13 +- .../internevo_flash_attention_dipu.py} | 12 - .../internevo_flash_attention_fallback.py} | 0 .../interntrain_flash_attention.py | 14 + .../interntrain_flash_attention_dipu.py} | 9 - .../interntrain_flash_attention_fallback.py} | 0 deeplink_ext/ops/rms_norm/__init__.py | 17 + .../rms_norm/easyllm_rms_norm.py} | 2 +- .../ops/rms_norm/easyllm_rms_norm_dipu.py | 49 +++ .../rms_norm/easyllm_rms_norm_fallback.py} | 0 .../internevo_mixed_rms_norm_dipu.py} | 2 - .../ops/rms_norm/internevo_rms_norm.py | 13 + .../rms_norm/internevo_rms_norm_fallback.py} | 0 deeplink_ext/ops/rotary_embedding/__init__.py | 20 ++ .../internevo_rotary_embedding.py | 11 + .../internevo_rotary_embedding_dipu.py} | 0 .../internevo_rotary_embedding_fallback.py} | 0 .../interntrain_rotary_embedding.py | 11 + .../interntrain_rotary_embedding_dipu.py} | 0 .../interntrain_rotary_embedding_fallback.py} | 0 47 files changed, 200 insertions(+), 1181 deletions(-) delete mode 100644 deeplink_ext/easyllm_ops/adamw.py delete mode 100644 deeplink_ext/easyllm_ops/flash_attention.py delete mode 100644 deeplink_ext/easyllm_ops/flash_attention_fallback.py delete mode 100644 deeplink_ext/internevo_ops/_flash_attention_npu.py delete mode 100644 deeplink_ext/internevo_ops/_rotary_embedding_npu.py delete mode 100644 deeplink_ext/internevo_ops/adamw.py delete mode 100644 deeplink_ext/internevo_ops/rms_norm.py delete mode 100644 deeplink_ext/internevo_ops/rms_norm_fallback.py delete mode 100644 deeplink_ext/internevo_ops/rotary_embedding.py delete mode 100644 deeplink_ext/interntrain_ops/_flash_attention_npu.py delete mode 100644 deeplink_ext/interntrain_ops/_mixed_rms_norm_npu.py delete mode 100644 deeplink_ext/interntrain_ops/_rotary_embedding_npu.py delete mode 100644 deeplink_ext/interntrain_ops/flash_attention.py delete mode 100644 deeplink_ext/interntrain_ops/rms_norm.py delete mode 100644 deeplink_ext/interntrain_ops/rotary_embedding.py rename deeplink_ext/{interntrain_ops/adamw.py => ops/adamw/__init__.py} (68%) 
rename deeplink_ext/{interntrain_ops => ops/adamw}/_adamw_dipu.py (99%) create mode 100644 deeplink_ext/ops/bert_padding/__init__.py rename deeplink_ext/{easyllm_ops => ops/bert_padding}/bert_padding.py (100%) create mode 100644 deeplink_ext/ops/flash_attention/__init__.py rename deeplink_ext/{internevo_ops/flash_attention.py => ops/flash_attention/internevo_flash_attention.py} (61%) rename deeplink_ext/{internevo_ops/_flash_attention_dipu.py => ops/flash_attention/internevo_flash_attention_dipu.py} (99%) rename deeplink_ext/{internevo_ops/flash_attention_fallback.py => ops/flash_attention/internevo_flash_attention_fallback.py} (100%) create mode 100644 deeplink_ext/ops/flash_attention/interntrain_flash_attention.py rename deeplink_ext/{interntrain_ops/_flash_attention_dipu.py => ops/flash_attention/interntrain_flash_attention_dipu.py} (99%) rename deeplink_ext/{interntrain_ops/flash_attention_fallback.py => ops/flash_attention/interntrain_flash_attention_fallback.py} (100%) create mode 100644 deeplink_ext/ops/rms_norm/__init__.py rename deeplink_ext/{easyllm_ops/rms_norm.py => ops/rms_norm/easyllm_rms_norm.py} (71%) create mode 100644 deeplink_ext/ops/rms_norm/easyllm_rms_norm_dipu.py rename deeplink_ext/{easyllm_ops/rms_norm_fallback.py => ops/rms_norm/easyllm_rms_norm_fallback.py} (100%) rename deeplink_ext/{interntrain_ops/_mixed_rms_norm_dipu.py => ops/rms_norm/internevo_mixed_rms_norm_dipu.py} (99%) create mode 100644 deeplink_ext/ops/rms_norm/internevo_rms_norm.py rename deeplink_ext/{interntrain_ops/rms_norm_fallback.py => ops/rms_norm/internevo_rms_norm_fallback.py} (100%) create mode 100644 deeplink_ext/ops/rotary_embedding/__init__.py create mode 100644 deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py rename deeplink_ext/{internevo_ops/_rotary_embedding_dipu.py => ops/rotary_embedding/internevo_rotary_embedding_dipu.py} (100%) rename deeplink_ext/{internevo_ops/rotary_embedding_fallback.py => ops/rotary_embedding/internevo_rotary_embedding_fallback.py} (100%) create mode 100644 deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py rename deeplink_ext/{interntrain_ops/_rotary_embedding_dipu.py => ops/rotary_embedding/interntrain_rotary_embedding_dipu.py} (100%) rename deeplink_ext/{interntrain_ops/rotary_embedding_fallback.py => ops/rotary_embedding/interntrain_rotary_embedding_fallback.py} (100%) diff --git a/deeplink_ext/ascend_speed/_flash_attention_dipu.py b/deeplink_ext/ascend_speed/_flash_attention_dipu.py index e5ee61d1..d6f3b41a 100644 --- a/deeplink_ext/ascend_speed/_flash_attention_dipu.py +++ b/deeplink_ext/ascend_speed/_flash_attention_dipu.py @@ -9,7 +9,6 @@ class FlashSelfAttention(torch.autograd.Function): - @staticmethod def forward( ctx, q, k, v, attention_mask, dropout_p, softmax_scale, head_num, input_layout diff --git a/deeplink_ext/ascend_speed/_rms_norm_dipu.py b/deeplink_ext/ascend_speed/_rms_norm_dipu.py index 16d3502d..7d4c237f 100644 --- a/deeplink_ext/ascend_speed/_rms_norm_dipu.py +++ b/deeplink_ext/ascend_speed/_rms_norm_dipu.py @@ -9,7 +9,6 @@ class RMSNorm(torch.autograd.Function): - @staticmethod def forward(ctx, hidden_states, weight, eps): output = torch.empty_like(hidden_states) diff --git a/deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py b/deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py index b20f7ee7..47f324d8 100644 --- a/deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py +++ b/deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py @@ -11,7 +11,6 @@ class 
ScaledMaskedSoftmax(torch.autograd.Function): - @staticmethod def forward(ctx, input, mask, scale, fixed_triu_mask): out = torch.empty_like(input) diff --git a/deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py b/deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py index b02e3c81..a4a1d066 100644 --- a/deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py +++ b/deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py @@ -7,7 +7,6 @@ class ScaledMaskedSoftmax(torch.autograd.Function): - @staticmethod def forward(ctx, input, mask, scale, fixed_triu_mask): out = torch_npu.npu_scaled_masked_softmax(input, mask, scale, fixed_triu_mask) diff --git a/deeplink_ext/easyllm_ops/__init__.py b/deeplink_ext/easyllm_ops/__init__.py index 439bd0be..a0fb1f22 100644 --- a/deeplink_ext/easyllm_ops/__init__.py +++ b/deeplink_ext/easyllm_ops/__init__.py @@ -3,40 +3,22 @@ _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." try: - from .adamw import AdamW + from deeplink_ext.ops.adamw import AdamW except Exception as e: print(_not_impl.format(op_name="adamw")) from torch.optim import AdamW -try: - from .flash_attention import ( - flash_attn_qkvpacked_func, - flash_attn_kvpacked_func, - flash_attn_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_func, - ) -except Exception as e: - print(_not_impl.format(op_name="flash attention")) - from .flash_attention_fallback import ( - flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func, - flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func, - flash_attn_func_torch as flash_attn_func, - flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func, - flash_attn_varlen_func_torch as flash_attn_varlen_func, - ) - -try: - from .rms_norm import rms_norm -except: - print( - _not_impl.format(op_name="RMSNorm"), - ) - from .rms_norm_fallback import rms_norm_torch as rms_norm +from deeplink_ext.ops.flash_attention import ( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, +) -from .bert_padding import pad_input, unpad_input, index_first_axis +from deeplink_ext.ops.rms_norm import rms_norm +from deeplink_ext.ops.bert_padding import pad_input, unpad_input, index_first_axis __all__ = [ "AdamW", diff --git a/deeplink_ext/easyllm_ops/adamw.py b/deeplink_ext/easyllm_ops/adamw.py deleted file mode 100644 index 59afad50..00000000 --- a/deeplink_ext/easyllm_ops/adamw.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from deeplink_ext.interntrain_ops.adamw import AdamW - -__all__ = ["AdamW"] diff --git a/deeplink_ext/easyllm_ops/flash_attention.py b/deeplink_ext/easyllm_ops/flash_attention.py deleted file mode 100644 index b9c9437f..00000000 --- a/deeplink_ext/easyllm_ops/flash_attention.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
- -from deeplink_ext.internevo_ops.flash_attention import ( - flash_attn_qkvpacked_func, - flash_attn_kvpacked_func, - flash_attn_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_func, -) - -__all__ = [ - "flash_attn_qkvpacked_func", - "flash_attn_kvpacked_func", - "flash_attn_func", - "flash_attn_varlen_qkvpacked_func", - "flash_attn_varlen_kvpacked_func", - "flash_attn_varlen_func", -] diff --git a/deeplink_ext/easyllm_ops/flash_attention_fallback.py b/deeplink_ext/easyllm_ops/flash_attention_fallback.py deleted file mode 100644 index e781ae1e..00000000 --- a/deeplink_ext/easyllm_ops/flash_attention_fallback.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from deeplink_ext.internevo_ops.flash_attention_fallback import ( - flash_attn_qkvpacked_func_torch, - flash_attn_kvpacked_func_torch, - flash_attn_func_torch, - flash_attn_varlen_qkvpacked_func_torch, - flash_attn_varlen_kvpacked_func_torch, - flash_attn_varlen_func_torch, -) - - -__all__ = [ - "flash_attn_qkvpacked_func_torch", - "flash_attn_kvpacked_func_torch", - "flash_attn_func_torch", - "flash_attn_varlen_qkvpacked_func_torch", - "flash_attn_varlen_kvpacked_func_torch", - "flash_attn_varlen_func_torch", -] diff --git a/deeplink_ext/internevo_ops/__init__.py b/deeplink_ext/internevo_ops/__init__.py index b2c86f6e..41bbae8d 100644 --- a/deeplink_ext/internevo_ops/__init__.py +++ b/deeplink_ext/internevo_ops/__init__.py @@ -1,46 +1,25 @@ # Copyright (c) 2024, DeepLink. _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." try: - from .adamw import AdamW + from deeplink_ext.ops.adamw import AdamW except Exception as e: print(_not_impl.format(op_name="adamw")) from torch.optim import AdamW -try: - from .flash_attention import ( - flash_attn_qkvpacked_func, - flash_attn_kvpacked_func, - flash_attn_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_func, - ) -except Exception as e: - print(_not_impl.format(op_name="flash attention")) - from .flash_attention_fallback import ( - flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func, - flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func, - flash_attn_func_torch as flash_attn_func, - flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func, - flash_attn_varlen_func_torch as flash_attn_varlen_func, - ) +from deeplink_ext.ops.flash_attention import ( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, +) -try: - from .rms_norm import MixedFusedRMSNorm -except: - print( - _not_impl.format(op_name="RMSNorm"), - ) - from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm +from deeplink_ext.ops.rms_norm import MixedFusedRMSNorm -try: - from .rotary_embedding import ApplyRotaryEmb -except: - print(_not_impl.format(op_name="rotary embedding")) - from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb +from deeplink_ext.ops.rotary_embedding import ApplyRotaryEmb __all__ = [ "AdamW", diff --git a/deeplink_ext/internevo_ops/_flash_attention_npu.py b/deeplink_ext/internevo_ops/_flash_attention_npu.py deleted file mode 100644 index 37110b94..00000000 --- a/deeplink_ext/internevo_ops/_flash_attention_npu.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright (c) 2024,
DeepLink. - -import torch -import torch_npu - -__all__ = [ - "flash_attn_func", - "flash_attn_varlen_func", - "flash_attn_qkvpacked_func", - "flash_attn_kvpacked_func", - "flash_attn_varlen_qkvpacked_func", - "flash_attn_varlen_kvpacked_func", -] - - -def flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - seqlen_q = q.shape[1] - seqlen_k = k.shape[1] - head_num = q.shape[-2] - - if seqlen_q == seqlen_k and seqlen_q < 2048 and seqlen_k < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - seqlen_q = min(seqlen_q, 2048) - seqlen_k = min(seqlen_k, 2048) - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) - - out = torch_npu.npu_fusion_attention( - q, - k, - v, - head_num, - "BSND", - atten_mask=attention_mask, - scale=softmax_scale, - keep_prob=1 - dropout_p, - pre_tockens=seqlen_q, - next_tockens=0, - sparse_mode=sparse_mode, - )[0] - - return out - - -def flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - block_table=None, -): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_num = q.shape[-2] - - cu_seqlens_q = cu_seqlens_q[1:].tolist() - cu_seqlens_k = cu_seqlens_k[1:].tolist() - seqlen_q = min(max_seqlen_q, 2048) - seqlen_k = min(max_seqlen_k, 2048) - - if max_seqlen_q < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) - - out = torch_npu.npu_fusion_attention( - q, - k, - v, - head_num, - "TND", - atten_mask=attention_mask, - scale=softmax_scale, - pre_tockens=q.shape[0], # seq_len - next_tockens=0, # 0 - keep_prob=1 - dropout_p, - sparse_mode=sparse_mode, - actual_seq_qlen=cu_seqlens_q, - actual_seq_kvlen=cu_seqlens_k, - )[0] - return out - - -def flash_attn_qkvpacked_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q = qkv[:, :, 0] - k = qkv[:, :, 1] - v = qkv[:, :, 2] - - seqlen_qkv = qkv.shape[1] - head_num = q.shape[-2] - - if seqlen_qkv < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - seqlen_qkv = min(qkv.shape[1], 2048) - - attention_mask = ( - torch.triu( - torch.ones([seqlen_qkv, seqlen_qkv], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) - - out = torch_npu.npu_fusion_attention( - q, - k, - v, - head_num, - "BSND", - atten_mask=attention_mask, - scale=softmax_scale, - keep_prob=1 - dropout_p, - pre_tockens=seqlen_qkv, - next_tockens=0, - sparse_mode=sparse_mode, - )[0] - - return out - - -def flash_attn_kvpacked_func( - q, - kv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - k = kv[:, :, 0] - v = kv[:, :, 1] - - s0 = q.shape[1] - s1 = kv.shape[1] - head_num = q.shape[-2] - - if s0 == s1 and s0 < 
2048 and s1 < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - seqlen_q = min(s0, 2048) - seqlen_k = min(s1, 2048) - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) - - out = torch_npu.npu_fusion_attention( - q, - k, - v, - head_num, - "BSND", - atten_mask=attention_mask, - scale=softmax_scale, - keep_prob=1 - dropout_p, - pre_tockens=seqlen_k, - next_tockens=0, - sparse_mode=sparse_mode, - )[0] - - return out - - -def flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q = qkv[:, 0] - k = qkv[:, 1] - v = qkv[:, 2] - n = q.shape[1] - if max_seqlen > 2048: - sparse_mode = 2 - else: - sparse_mode = 0 - cu_seqlens_q = cu_seqlens[1:].tolist() - cu_seqlens_k = cu_seqlens[1:].tolist() - seqlen = min(max_seqlen, 2048) - attention_mask = ( - torch.triu( - torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) - out = torch_npu.npu_fusion_attention( - q, - k, - v, - n, - "TND", - atten_mask=attention_mask, - scale=softmax_scale, - pre_tockens=q.shape[0], # seq_len - next_tockens=0, # 0 - keep_prob=1 - dropout_p, - sparse_mode=sparse_mode, - actual_seq_qlen=cu_seqlens_q, - actual_seq_kvlen=cu_seqlens_k, - )[0] - return out - - -def flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - k = kv[:, 0] - v = kv[:, 1] - n = q.shape[1] - cu_seqlens_q = cu_seqlens_q[1:].tolist() - cu_seqlens_k = cu_seqlens_k[1:].tolist() - seqlen_q = min(max_seqlen_q, 2048) - seqlen_k = min(max_seqlen_k, 2048) - - if max_seqlen_q > 2048: - sparse_mode = 2 - else: - sparse_mode = 0 - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) - out = torch_npu.npu_fusion_attention( - q, - k, - v, - n, - "TND", - atten_mask=attention_mask, - scale=softmax_scale, - pre_tockens=q.shape[0], # seq_len - next_tockens=0, # 0 - keep_prob=1 - dropout_p, - sparse_mode=sparse_mode, - actual_seq_qlen=cu_seqlens_q, - actual_seq_kvlen=cu_seqlens_k, - )[0] - return out diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py deleted file mode 100644 index 4e27d045..00000000 --- a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
- -import torch -import torch_npu -from einops import rearrange - -__all__ = ["ApplyRotaryEmb"] - - -# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35 -class ApplyRotaryEmb(torch.autograd.Function): - """ - ApplyRotaryEmb - """ - - @staticmethod - def forward( - ctx, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False, - in_place: bool = False, - ): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. - """ - *_, seqlen, _, head_dim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - - assert rotary_dim <= head_dim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - - re_cos = rearrange(cos[:seqlen], "s d -> s 1 d") - re_sin = rearrange(sin[:seqlen], "s d -> s 1 d") - - cat_cos = torch.cat([re_cos, re_cos], -1) - cat_sin = torch.cat([re_sin, re_sin], -1) - - rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin) - ctx.save_for_backward(cat_cos, cat_sin) - ctx.interleaved = interleaved - ctx.in_place = in_place - if in_place: - x[..., :rotary_dim].copy_(rot) - return x - else: - out = x.detach().clone() - if rotary_dim < head_dim and not in_place: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - return out - - @staticmethod - def backward(ctx, do): - cat_cos, cat_sin = ctx.saved_tensors - *_, seqlen, _, head_dim = do.shape - rotary_dim = cat_cos.shape[-1] - - dx_out = torch_npu.npu_rotary_mul( - do[..., :rotary_dim], cat_cos, torch.neg(cat_sin) - ) - if ctx.in_place: - do[..., :rotary_dim].copy_(dx_out) - return do, None, None, None, None - else: - dx = do.detach().clone() - dx[..., :rotary_dim].copy_(dx_out) - return dx, None, None, None, None diff --git a/deeplink_ext/internevo_ops/adamw.py b/deeplink_ext/internevo_ops/adamw.py deleted file mode 100644 index 59afad50..00000000 --- a/deeplink_ext/internevo_ops/adamw.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from deeplink_ext.interntrain_ops.adamw import AdamW - -__all__ = ["AdamW"] diff --git a/deeplink_ext/internevo_ops/rms_norm.py b/deeplink_ext/internevo_ops/rms_norm.py deleted file mode 100644 index 77e43dee..00000000 --- a/deeplink_ext/internevo_ops/rms_norm.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from deeplink_ext.interntrain_ops import MixedFusedRMSNorm - -__all__ = ["MixedFusedRMSNorm"] diff --git a/deeplink_ext/internevo_ops/rms_norm_fallback.py b/deeplink_ext/internevo_ops/rms_norm_fallback.py deleted file mode 100644 index db6ae97f..00000000 --- a/deeplink_ext/internevo_ops/rms_norm_fallback.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from deeplink_ext.interntrain_ops.rms_norm_fallback import MixedRMSNormTorch - -__all__ = ["MixedRMSNormTorch"] diff --git a/deeplink_ext/internevo_ops/rotary_embedding.py b/deeplink_ext/internevo_ops/rotary_embedding.py deleted file mode 100644 index 1a2a36d9..00000000 --- a/deeplink_ext/internevo_ops/rotary_embedding.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
- -from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type - -platform_type = deeplink_ext_get_platform_type() -if platform_type == PlatformType.TORCH_NPU: - # from ._rotary_embedding_npu import ApplyRotaryEmb - from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb -elif platform_type == PlatformType.TORCH_DIPU: - from ._rotary_embedding_dipu import ApplyRotaryEmb -else: - raise ImportError - -__all__ = ["ApplyRotaryEmb"] diff --git a/deeplink_ext/interntrain_ops/__init__.py b/deeplink_ext/interntrain_ops/__init__.py index f41e35d1..60166bd6 100644 --- a/deeplink_ext/interntrain_ops/__init__.py +++ b/deeplink_ext/interntrain_ops/__init__.py @@ -3,34 +3,14 @@ _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." try: - from .adamw import AdamW + from deeplink_ext.ops.adamw import AdamW except Exception as e: print(_not_impl.format(op_name="adamw")) from torch.optim import AdamW -try: - from .flash_attention import FlashSelfAttention, FlashCrossAttention -except Exception as e: - print(_not_impl.format(op_name="flash attention")) - from .flash_attention_fallback import SelfAttention as FlashSelfAttention - from .flash_attention_fallback import CrossAttention as FlashCrossAttention - - -try: - from .rms_norm import MixedFusedRMSNorm -except: - print( - _not_impl.format(op_name="RMSNorm"), - ) - from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm - - -try: - from .rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_ -except: - print(_not_impl.format(op_name="rotary embedding")) - from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb - from .rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_ +from deeplink_ext.ops.flash_attention import FlashSelfAttention, FlashCrossAttention +from deeplink_ext.ops.rms_norm import MixedFusedRMSNorm +from deeplink_ext.ops.rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_ __all__ = [ diff --git a/deeplink_ext/interntrain_ops/_flash_attention_npu.py b/deeplink_ext/interntrain_ops/_flash_attention_npu.py deleted file mode 100644 index 0db00faf..00000000 --- a/deeplink_ext/interntrain_ops/_flash_attention_npu.py +++ /dev/null @@ -1,291 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import torch -import torch_npu -import torch.nn as nn - -__all__ = ["FlashSelfAttention", "FlashCrossAttention"] - - -class FlashSelfAttention(nn.Module): - """Performs self-attention with support for both padded and unpadded sequences. - - Args: - causal (bool, optional): If True, applies causal self-attention, meaning each - position can only attend to previous positions. Default is False. - softmax_scale (float, optional): Scaling factor applied to the softmax - operation. If not provided, will be D^{-0.5}. Default is None. - dropout_p (float, optional): Dropout probability applied to the attention - scores. Default is 0.0. - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward( - self, - qkv=None, - q=None, - k=None, - v=None, - kv=None, - causal=None, - cu_seqlens=None, - max_seqlen=None, - cu_seqlens_q=None, - cu_seqlens_k=None, - max_seqlen_q=None, - max_seqlen_k=None, - softmax_scale=None, - dropout_p=0.0, - ): - """Performs self-attention on the input sequences. 
- - Args: - qkv (torch.Tensor): Input tensor representing queries, keys, and values - concatenated together. (B, S, 3, H, D) for padded; (total, 3, H, D) - for unpadded. - causal (bool, optional): If provided, overrides the class-level 'causal' - argument for this forward pass. Default is None. - cu_seqlens (torch.Tensor((batch_size + 1,), dtype=torch.int32), optional): - Sequence lengths tensor for unpadded sequences. If provided, performs - attention on unpadded sequences. Default is None. - max_seqlen (int, optional): Maximum sequence length for unpadded sequences. - If provided, defines the maximum length of the sequences. Default is - None. - - Returns: - torch.Tensor: Output tensor after applying self-attention. - """ - padded = all(x is None for x in (cu_seqlens, cu_seqlens_q, cu_seqlens_k)) - if padded: - if qkv is not None: - query, key, value = qkv.unbind(dim=2) - elif kv is not None: - assert q is not None, "q should not be None, when kv is not None" - assert q.device == kv.device, "the devices of q and kv should be same" - query = q - key, value = kv.unbind(dim=2) - else: - assert ( - q is not None and k is not None and q is not None - ), "q, k, v should not be None" - assert ( - q.device == k.device and k.device == v.device - ), "the devices of q, k and v should be same" - query, key, value = q, k, v - - if softmax_scale is None: - softmax_scale = query.shape[-1] ** (-0.5) - head_num = query.shape[-2] - - seqlen_q = min(query.shape[1], 2048) - seqlen_kv = min(key.shape[1], 2048) - - if seqlen_q < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - attention_mask = ( - torch.triu( - torch.ones( - [seqlen_q, seqlen_kv], dtype=torch.bool, device=query.device - ), - diagonal=1, - ) - if causal - else None - ) - - out = torch_npu.npu_fusion_attention( - query, - key, - value, - head_num, - "BSND", - atten_mask=attention_mask, - scale=softmax_scale, - keep_prob=1 - dropout_p, - pre_tockens=seqlen_q, - next_tockens=0, - sparse_mode=sparse_mode, - )[0] - - return out - else: - # unpadded - if qkv is not None: - query, key, value = qkv.unbind(dim=1) - elif kv is not None: - assert q is not None, "q should not be None, when kv is not None" - assert q.device == kv.device, "the devices of q and kv should be same" - query = q - key, value = kv.unbind(dim=1) - else: - assert ( - q is not None and k is not None and q is not None - ), "q, k, v should not be None" - assert ( - q.device == k.device and k.device == v.device - ), "the devices of q, k and v should be same" - query, key, value = q, k, v - - cu_seqlens = next( - (x for x in (cu_seqlens, cu_seqlens_q, cu_seqlens_k) if x is not None), - None, - ) - max_seqlen = next( - (x for x in (max_seqlen, max_seqlen_q, max_seqlen_k) if x is not None), - None, - ) - - if softmax_scale is None: - softmax_scale = query.shape[-1] ** (-0.5) - head_num = query.shape[-2] - - assert ( - cu_seqlens is not None - ), "cu_seqlens should not be None, when using varlen flash attention" - cu_seqlens = cu_seqlens[1:].tolist() - seqlen = min(max_seqlen, 2048) - attention_mask = ( - torch.triu( - torch.ones([seqlen, seqlen], dtype=torch.bool, device=query.device), - diagonal=1, - ) - if causal - else None - ) - - if max_seqlen < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - out = torch_npu.npu_fusion_attention( - query, - key, - value, - head_num, - "TND", - atten_mask=attention_mask, - scale=softmax_scale, - pre_tockens=query.shape[0], - next_tockens=0, - keep_prob=1 - dropout_p, - sparse_mode=sparse_mode, - actual_seq_qlen=cu_seqlens, - 
actual_seq_kvlen=cu_seqlens, - )[0] - - return out - - -class FlashCrossAttention(nn.Module): - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward( - self, - q, - kv, - causal=None, - cu_seqlens=None, - max_seqlen=None, - cu_seqlens_k=None, - max_seqlen_k=None, - ): - padded = all(x is None for x in (cu_seqlens, cu_seqlens_k)) - if padded: - # padded - if self.softmax_scale is None: - self.softmax_scale = q.shape[-1] ** (-0.5) - k = kv[:, :, 0] - v = kv[:, :, 1] - - s0 = q.shape[1] - s1 = kv.shape[1] - head_num = q.shape[-2] - - if s0 == s1 and s0 < 2048 and s1 < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - seqlen_q = min(s0, 2048) - seqlen_k = min(s1, 2048) - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) - - out = torch_npu.npu_fusion_attention( - q, - k, - v, - head_num, - "BSND", - atten_mask=attention_mask, - scale=self.softmax_scale, - keep_prob=1 - self.dropout_p, - pre_tockens=seqlen_k, - next_tockens=0, - sparse_mode=sparse_mode, - )[0] - - return out - - else: - # unpadded - if self.softmax_scale is None: - self.softmax_scale = q.shape[-1] ** (-0.5) - k = kv[:, 0] - v = kv[:, 1] - n = q.shape[1] - cu_seqlens_q = cu_seqlens[1:].tolist() - cu_seqlens_k = cu_seqlens_k[1:].tolist() - seqlen_q = min(max_seqlen, 2048) - seqlen_k = min(max_seqlen_k, 2048) - - if max_seqlen > 2048: - sparse_mode = 2 - else: - sparse_mode = 0 - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) - out = torch_npu.npu_fusion_attention( - q, - k, - v, - n, - "TND", - atten_mask=attention_mask, - scale=self.softmax_scale, - pre_tockens=q.shape[0], # seq_len - next_tockens=0, # 0 - keep_prob=1 - self.dropout_p, - sparse_mode=sparse_mode, - actual_seq_qlen=cu_seqlens_q, - actual_seq_kvlen=cu_seqlens_k, - )[0] - return out diff --git a/deeplink_ext/interntrain_ops/_mixed_rms_norm_npu.py b/deeplink_ext/interntrain_ops/_mixed_rms_norm_npu.py deleted file mode 100644 index 435be28d..00000000 --- a/deeplink_ext/interntrain_ops/_mixed_rms_norm_npu.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2024, DeepLink. -import numbers -import torch -import torch_npu -from torch import Tensor -from torch.nn import init -from torch.nn.parameter import Parameter -from torch_npu import npu_rms_norm - -__all__ = ["MixedFusedRMSNorm"] - - -# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype -# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype. 
-# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" -def manual_rms_norm( - my_input: Tensor, normalized_shape, weight: Tensor, eps, add_unit_offset=False -): - assert add_unit_offset == False - assert len(normalized_shape) == 1 - - input_dtype = my_input.dtype - weight_dtype = weight.dtype - - acc_dtype = torch.promote_types(input_dtype, weight_dtype) - out = npu_rms_norm(my_input.to(dtype=acc_dtype), weight.to(dtype=acc_dtype), eps)[0] - if out.dtype != weight_dtype: - out = out.to(dtype=weight_dtype) - return out - - -class MixedFusedRMSNorm(torch.nn.Module): - """A custom PyTorch module for RMS normalization.""" - - def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False): - super().__init__() - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.weight = Parameter(torch.empty(*normalized_shape)) - self.add_unit_offset = add_unit_offset - self.reset_parameters() - - def forward(self, _input: torch.Tensor): - return manual_rms_norm( - _input, self.normalized_shape, self.weight, self.eps, self.add_unit_offset - ) - - def reset_parameters(self): - if self.add_unit_offset: - init.zeros_(self.weight) - else: - init.ones_(self.weight) - - def extra_repr(self): - return "{normalized_shape}, eps={eps}, ".format(**self.__dict__) diff --git a/deeplink_ext/interntrain_ops/_rotary_embedding_npu.py b/deeplink_ext/interntrain_ops/_rotary_embedding_npu.py deleted file mode 100644 index 0784d216..00000000 --- a/deeplink_ext/interntrain_ops/_rotary_embedding_npu.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import torch -import torch_npu -from einops import rearrange - -__all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_"] - - -class ApplyRotaryEmb(torch.autograd.Function): - """ - Apply rotary positional embedding to input tensor x. - Args: - x (Tensor): Input tensor x is of shape [seq_length, ... , dim] - cos (Tensor): Input tensor cos is of shape [seq_length, ..., dim] - sin (Tensor): Input tensor sin is of shape [seq_length, ..., dim] - - Returns: - Tensor: The input tensor after applying RoPE - """ - - @staticmethod - def forward(ctx, x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. 
- """ - _, seqlen, _, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - out = torch.empty_like(x) - - re_cos = rearrange(cos[:seqlen], "s d -> s 1 d") - re_sin = rearrange(sin[:seqlen], "s d -> s 1 d") - - cat_cos = torch.cat([re_cos, re_cos], -1) - cat_sin = torch.cat([re_sin, re_sin], -1) - - rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin) - out[..., :rotary_dim].copy_(rot) - if rotary_dim < headdim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - ctx.save_for_backward(cat_cos, cat_sin) - ctx.interleaved = interleaved - return out - - @staticmethod - def backward(ctx, do): - cat_cos, cat_sin = ctx.saved_tensors - headdim = do.shape[-1] - rotary_dim = cat_cos.shape[-1] - - dx = torch.empty_like(do) - dx_rot = torch_npu.npu_rotary_mul( - do[..., :rotary_dim], cat_cos, torch.neg(cat_sin) - ) - dx.copy_(dx_rot) - - if rotary_dim < headdim: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None - - -class ApplyRotaryEmbQKV_(torch.autograd.Function): - """ - ApplyRotaryEmbQKV_ - """ - - @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): - """ - qkv: (total, 3, nheads, headdim) / (batch_size, seqlen, 3, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of q and k. - """ - # len(qkv.shape) == 4 means the format of qkv is (total, 3, nheads, headdim) which is packed, - # otherwise the format of qkv is (batch_size, seqlen, 3, nheads, headdim) which is unpacked. - # We handle both packed qkv and unpacked qkv scenario in this class. 
- three = qkv.shape[1] if len(qkv.shape) == 4 else qkv.shape[2] - assert three == 3 - seqlen = None if len(qkv.shape) == 4 else qkv.shape[1] - rotary_seqlen, rotary_dim = cos.shape - if len(qkv.shape) != 4: - assert seqlen <= rotary_seqlen - headdim = qkv.shape[-1] - rotary_dim *= 2 - assert rotary_dim <= headdim - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert ( - sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - ) - - q_ro = ( - qkv[:, 0, :, :rotary_dim] - if len(qkv.shape) == 4 - else qkv[:, :, 0, :, :rotary_dim] - ) - re_cos = ( - rearrange(cos, "s d -> s 1 d") - if len(qkv.shape) == 4 - else rearrange(cos[:seqlen], "s d -> s 1 d") - ) - re_sin = ( - rearrange(sin, "s d -> s 1 d") - if len(qkv.shape) == 4 - else rearrange(sin[:seqlen], "s d -> s 1 d") - ) - cat_cos = torch.cat([re_cos, re_cos], -1) - cat_sin = torch.cat([re_sin, re_sin], -1) - q_out = torch_npu.npu_rotary_mul(q_ro, cat_cos, cat_sin) - q_ro.copy_(q_out) - - k_ro = ( - qkv[:, 1, :, :rotary_dim] - if len(qkv.shape) == 4 - else qkv[:, :, 1, :, :rotary_dim] - ) - re_cos_k = ( - rearrange(cos_k, "s d -> s 1 d") - if len(qkv.shape) == 4 - else rearrange(cos_k[:seqlen], "s d -> s 1 d") - ) - re_sin_k = ( - rearrange(sin_k, "s d -> s 1 d") - if len(qkv.shape) == 4 - else rearrange(sin_k[:seqlen], "s d -> s 1 d") - ) - cat_cos_k = torch.cat([re_cos_k, re_cos_k], -1) - cat_sin_k = torch.cat([re_sin_k, re_sin_k], -1) - k_out = torch_npu.npu_rotary_mul(k_ro, cat_cos_k, cat_sin_k) - k_ro.copy_(k_out) - - ctx.save_for_backward(cat_cos, cat_sin, cat_cos_k, cat_sin_k) - ctx.interleaved = interleaved - return qkv - - @staticmethod - def backward(ctx, dqkv): - cat_cos, cat_sin, cat_cos_k, cat_sin_k = ctx.saved_tensors - rotary_dim = cat_cos.shape[-1] - - dq_ro = ( - dqkv[:, 0, :, :rotary_dim] - if len(dqkv.shape) == 4 - else dqkv[:, :, 0, :, :rotary_dim] - ) - dq_out = torch_npu.npu_rotary_mul(dq_ro, cat_cos, torch.neg(cat_sin)) - dq_ro.copy_(dq_out) - - dk_ro = ( - dqkv[:, 1, :, :rotary_dim] - if len(dqkv.shape) == 4 - else dqkv[:, :, 1, :, :rotary_dim] - ) - dk_out = torch_npu.npu_rotary_mul(dk_ro, cat_cos_k, torch.neg(cat_sin_k)) - dk_ro.copy_(dk_out) - - return dqkv, None, None, None, None, None diff --git a/deeplink_ext/interntrain_ops/flash_attention.py b/deeplink_ext/interntrain_ops/flash_attention.py deleted file mode 100644 index 84f83a5e..00000000 --- a/deeplink_ext/interntrain_ops/flash_attention.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type - -platform_type = deeplink_ext_get_platform_type() -if platform_type == PlatformType.TORCH_NPU: - from ._flash_attention_npu import FlashSelfAttention, FlashCrossAttention -elif platform_type == PlatformType.TORCH_DIPU: - from ._flash_attention_dipu import FlashSelfAttention, FlashCrossAttention -else: - raise ImportError - -__all__ = ["FlashSelfAttention", "FlashCrossAttention"] diff --git a/deeplink_ext/interntrain_ops/rms_norm.py b/deeplink_ext/interntrain_ops/rms_norm.py deleted file mode 100644 index 301ab9e1..00000000 --- a/deeplink_ext/interntrain_ops/rms_norm.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
- -from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type - -platform_type = deeplink_ext_get_platform_type() -if platform_type == PlatformType.TORCH_NPU: - # from ._mixed_rms_norm_npu import MixedFusedRMSNorm - # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative. - from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm -elif platform_type == PlatformType.TORCH_DIPU: - # from ._mixed_rms_norm_dipu import MixedFusedRMSNorm - # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative. - from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm -else: - raise ImportError - -__all__ = ["MixedFusedRMSNorm"] diff --git a/deeplink_ext/interntrain_ops/rotary_embedding.py b/deeplink_ext/interntrain_ops/rotary_embedding.py deleted file mode 100644 index 1805b678..00000000 --- a/deeplink_ext/interntrain_ops/rotary_embedding.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type - -platform_type = deeplink_ext_get_platform_type() -if platform_type == PlatformType.TORCH_NPU: - # from ._rotary_embedding_npu import ApplyRotaryEmb, ApplyRotaryEmbQKV_ - from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb - from .rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_ -elif platform_type == PlatformType.TORCH_DIPU: - from ._rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_ -else: - raise ImportError - -__all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_"] diff --git a/deeplink_ext/interntrain_ops/adamw.py b/deeplink_ext/ops/adamw/__init__.py similarity index 68% rename from deeplink_ext/interntrain_ops/adamw.py rename to deeplink_ext/ops/adamw/__init__.py index 110421fb..be9097a4 100644 --- a/deeplink_ext/interntrain_ops/adamw.py +++ b/deeplink_ext/ops/adamw/__init__.py @@ -7,10 +7,7 @@ from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type platform_type = deeplink_ext_get_platform_type() -if platform_type == PlatformType.TORCH_NPU: - import torch_npu - from torch_npu.optim.npu_fused_adamw import NpuFusedAdamW as AdamW -elif platform_type == PlatformType.TORCH_DIPU: +if platform_type == PlatformType.TORCH_DIPU: # import torch_dipu # assert torch_dipu.vendor_type == 'NPU', "ascend_speed framework only support NPU accelerators." 
from ._adamw_dipu import AdamW diff --git a/deeplink_ext/interntrain_ops/_adamw_dipu.py b/deeplink_ext/ops/adamw/_adamw_dipu.py similarity index 99% rename from deeplink_ext/interntrain_ops/_adamw_dipu.py rename to deeplink_ext/ops/adamw/_adamw_dipu.py index 25a7b679..855c0842 100644 --- a/deeplink_ext/interntrain_ops/_adamw_dipu.py +++ b/deeplink_ext/ops/adamw/_adamw_dipu.py @@ -60,7 +60,6 @@ def fused_adamw( class AdamW(Optimizer): - def __init__( self, params, diff --git a/deeplink_ext/ops/bert_padding/__init__.py b/deeplink_ext/ops/bert_padding/__init__.py new file mode 100644 index 00000000..246dc2cc --- /dev/null +++ b/deeplink_ext/ops/bert_padding/__init__.py @@ -0,0 +1,3 @@ +from .bert_padding import pad_input, unpad_input, index_first_axis + +__all__ = ["pad_input", "unpad_input", "index_first_axis"] diff --git a/deeplink_ext/easyllm_ops/bert_padding.py b/deeplink_ext/ops/bert_padding/bert_padding.py similarity index 100% rename from deeplink_ext/easyllm_ops/bert_padding.py rename to deeplink_ext/ops/bert_padding/bert_padding.py diff --git a/deeplink_ext/ops/flash_attention/__init__.py b/deeplink_ext/ops/flash_attention/__init__.py new file mode 100644 index 00000000..486b9c37 --- /dev/null +++ b/deeplink_ext/ops/flash_attention/__init__.py @@ -0,0 +1,32 @@ +_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." + +try: + from .internevo_flash_attention import ( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, + ) +except Exception as e: + print(_not_impl.format(op_name="flash attention")) + from .internevo_flash_attention_fallback import ( + flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func, + flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func, + flash_attn_func_torch as flash_attn_func, + flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func_torch as flash_attn_varlen_func, + ) + +try: + from .interntrain_flash_attention import FlashSelfAttention, FlashCrossAttention +except Exception as e: + print(_not_impl.format(op_name="flash attention")) + from .interntrain_flash_attention_fallback import ( + SelfAttention as FlashSelfAttention, + ) + from .interntrain_flash_attention_fallback import ( + CrossAttention as FlashCrossAttention, + ) diff --git a/deeplink_ext/internevo_ops/flash_attention.py b/deeplink_ext/ops/flash_attention/internevo_flash_attention.py similarity index 61% rename from deeplink_ext/internevo_ops/flash_attention.py rename to deeplink_ext/ops/flash_attention/internevo_flash_attention.py index 12a2422e..2408b207 100644 --- a/deeplink_ext/internevo_ops/flash_attention.py +++ b/deeplink_ext/ops/flash_attention/internevo_flash_attention.py @@ -3,17 +3,8 @@ from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type platform_type = deeplink_ext_get_platform_type() -if platform_type == PlatformType.TORCH_NPU: - from ._flash_attention_npu import ( - flash_attn_func, - flash_attn_kvpacked_func, - flash_attn_varlen_func, - flash_attn_qkvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, - ) -elif platform_type == PlatformType.TORCH_DIPU: - from ._flash_attention_dipu import ( +if platform_type == PlatformType.TORCH_DIPU: + from .internevo_flash_attention_dipu import ( flash_attn_func, 
flash_attn_kvpacked_func, flash_attn_varlen_func, diff --git a/deeplink_ext/internevo_ops/_flash_attention_dipu.py b/deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py similarity index 99% rename from deeplink_ext/internevo_ops/_flash_attention_dipu.py rename to deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py index 03be3ba9..09dff62f 100644 --- a/deeplink_ext/internevo_ops/_flash_attention_dipu.py +++ b/deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py @@ -22,7 +22,6 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -108,7 +107,6 @@ def backward(ctx, dout, *args): class CustomizedFlashAttnQKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -254,7 +252,6 @@ def flash_attn_qkvpacked_func( class FlashAttnKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -344,7 +341,6 @@ def backward(ctx, dout, *args): class CustomizedFlashAttnKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -498,7 +494,6 @@ def flash_attn_kvpacked_func( class FlashAttnFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -592,7 +587,6 @@ def backward(ctx, dout, *args): class CustomizedFlashAttnFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -753,7 +747,6 @@ def flash_attn_func( class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -850,7 +843,6 @@ def backward(ctx, dout, *args): class CustomizedFlashAttnVarlenQKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -1007,7 +999,6 @@ def flash_attn_varlen_qkvpacked_func( class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -1112,7 +1103,6 @@ def backward(ctx, dout, *args): class CustomizedFlashAttnVarlenKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -1288,7 +1278,6 @@ def flash_attn_varlen_kvpacked_func( class FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -1414,7 +1403,6 @@ def backward(ctx, dout, *args): class CustomizedFlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod def forward( ctx, diff --git a/deeplink_ext/internevo_ops/flash_attention_fallback.py b/deeplink_ext/ops/flash_attention/internevo_flash_attention_fallback.py similarity index 100% rename from deeplink_ext/internevo_ops/flash_attention_fallback.py rename to deeplink_ext/ops/flash_attention/internevo_flash_attention_fallback.py diff --git a/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py b/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py new file mode 100644 index 00000000..0ba83c65 --- /dev/null +++ b/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024, DeepLink. 
+ +from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + +platform_type = deeplink_ext_get_platform_type() +if platform_type == PlatformType.TORCH_DIPU: + from .interntrain_flash_attention_dipu import ( + FlashSelfAttention, + FlashCrossAttention, + ) +else: + raise ImportError + +__all__ = ["FlashSelfAttention", "FlashCrossAttention"] diff --git a/deeplink_ext/interntrain_ops/_flash_attention_dipu.py b/deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py similarity index 99% rename from deeplink_ext/interntrain_ops/_flash_attention_dipu.py rename to deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py index a4f59d2a..5b3822ae 100644 --- a/deeplink_ext/interntrain_ops/_flash_attention_dipu.py +++ b/deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py @@ -16,7 +16,6 @@ class CustomizedFlashAttentionQKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -206,7 +205,6 @@ def backward(ctx, dout): class FlashAttentionQKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -359,7 +357,6 @@ def backward(ctx, dout): class CustomizedFlashAttentionVarlenQKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -560,7 +557,6 @@ def backward(ctx, dout): class FlashAttentionVarlenQKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -738,7 +734,6 @@ def backward(ctx, dout): class CustomizedFlashAttentionKVPackedFunc(torch.autograd.Function): - @staticmethod def forward(ctx, q, kv, dropout_p, softmax_scale, causal): assert q.device == kv.device, "the devices of q and kv should be same" @@ -842,7 +837,6 @@ def backward(ctx, dout): class FlashAttentionKVPackedFunc(torch.autograd.Function): - @staticmethod def forward(ctx, q, kv, dropout_p, softmax_scale, causal): assert q.device == kv.device, "the devices of q and kv should be same" @@ -920,7 +914,6 @@ def backward(ctx, dout): class CustomizedFlashAttentionVarlenKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -1045,7 +1038,6 @@ def backward(ctx, dout): class FlashAttentionVarlenKVPackedFunc(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -1266,7 +1258,6 @@ def forward( class FlashCrossAttention(nn.Module): - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): super().__init__() self.causal = causal diff --git a/deeplink_ext/interntrain_ops/flash_attention_fallback.py b/deeplink_ext/ops/flash_attention/interntrain_flash_attention_fallback.py similarity index 100% rename from deeplink_ext/interntrain_ops/flash_attention_fallback.py rename to deeplink_ext/ops/flash_attention/interntrain_flash_attention_fallback.py diff --git a/deeplink_ext/ops/rms_norm/__init__.py b/deeplink_ext/ops/rms_norm/__init__.py new file mode 100644 index 00000000..f2145764 --- /dev/null +++ b/deeplink_ext/ops/rms_norm/__init__.py @@ -0,0 +1,17 @@ +_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." 
+ +try: + from .easyllm_rms_norm import rms_norm +except: + print( + _not_impl.format(op_name="RMSNorm"), + ) + from .easyllm_rms_norm_fallback import rms_norm_torch as rms_norm + +try: + from .internevo_rms_norm import MixedFusedRMSNorm +except: + print( + _not_impl.format(op_name="RMSNorm"), + ) + from .internevo_rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm diff --git a/deeplink_ext/easyllm_ops/rms_norm.py b/deeplink_ext/ops/rms_norm/easyllm_rms_norm.py similarity index 71% rename from deeplink_ext/easyllm_ops/rms_norm.py rename to deeplink_ext/ops/rms_norm/easyllm_rms_norm.py index 02d13c67..927f560e 100644 --- a/deeplink_ext/easyllm_ops/rms_norm.py +++ b/deeplink_ext/ops/rms_norm/easyllm_rms_norm.py @@ -1,6 +1,6 @@ # Copyright (c) 2024, DeepLink. -from deeplink_ext.ascend_speed.rms_norm import RMSNorm +from .easyllm_rms_norm_dipu import RMSNorm __all__ = ["rms_norm"] diff --git a/deeplink_ext/ops/rms_norm/easyllm_rms_norm_dipu.py b/deeplink_ext/ops/rms_norm/easyllm_rms_norm_dipu.py new file mode 100644 index 00000000..7d4c237f --- /dev/null +++ b/deeplink_ext/ops/rms_norm/easyllm_rms_norm_dipu.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024, DeepLink. + +import torch +import deeplink_ext.cpp_extensions as ext + +assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward") + +__all__ = ["RMSNorm"] + + +class RMSNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, hidden_states, weight, eps): + output = torch.empty_like(hidden_states) + input_dtype = hidden_states.dtype + acc_dtype = ( + torch.float32 + if input_dtype in [torch.bfloat16, torch.float16] + else input_dtype + ) + n = weight.dim() + inv_rms = torch.empty( + list(hidden_states.shape[:-n]), + dtype=acc_dtype, + device=hidden_states.device, + ) + ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, None, eps) + ctx.save_for_backward(hidden_states, inv_rms, weight) + ctx.eps = eps + return output + + @staticmethod + def backward(ctx, grad_output): + hidden_states, inv_rms, weight = ctx.saved_tensors + grad_input = torch.empty_like(hidden_states) + grad_weight = torch.empty_like(weight) + ext.rms_norm_backward( + grad_input, + grad_weight, + None, + grad_output, + hidden_states, + weight, + None, + inv_rms, + weight.shape, + ctx.eps, + ) + return grad_input, grad_weight, None, None diff --git a/deeplink_ext/easyllm_ops/rms_norm_fallback.py b/deeplink_ext/ops/rms_norm/easyllm_rms_norm_fallback.py similarity index 100% rename from deeplink_ext/easyllm_ops/rms_norm_fallback.py rename to deeplink_ext/ops/rms_norm/easyllm_rms_norm_fallback.py diff --git a/deeplink_ext/interntrain_ops/_mixed_rms_norm_dipu.py b/deeplink_ext/ops/rms_norm/internevo_mixed_rms_norm_dipu.py similarity index 99% rename from deeplink_ext/interntrain_ops/_mixed_rms_norm_dipu.py rename to deeplink_ext/ops/rms_norm/internevo_mixed_rms_norm_dipu.py index 3782e858..74118cd6 100644 --- a/deeplink_ext/interntrain_ops/_mixed_rms_norm_dipu.py +++ b/deeplink_ext/ops/rms_norm/internevo_mixed_rms_norm_dipu.py @@ -14,7 +14,6 @@ # as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype. # See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" class _MixedFusedRMSNormFunction(torch.autograd.Function): - @staticmethod def forward(ctx, hidden_states, weight, eps, normalized_shape): # ascend currently does not support dtype of hidden_states with higher precision than weight. 
@@ -94,7 +93,6 @@ def backward(ctx, grad_output): class MixedFusedRMSNorm: - def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False): # TODO: Further optimization when there are device and dtype available. # factory_kwargs = {"device": device, "dtype": dtype} diff --git a/deeplink_ext/ops/rms_norm/internevo_rms_norm.py b/deeplink_ext/ops/rms_norm/internevo_rms_norm.py new file mode 100644 index 00000000..5b828ce9 --- /dev/null +++ b/deeplink_ext/ops/rms_norm/internevo_rms_norm.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, DeepLink. + +from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + +platform_type = deeplink_ext_get_platform_type() +if platform_type == PlatformType.TORCH_DIPU: + # from .internevo_mixed_rms_norm_dipu import MixedFusedRMSNorm + # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative. + from .internevo_rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm +else: + raise ImportError + +__all__ = ["MixedFusedRMSNorm"] diff --git a/deeplink_ext/interntrain_ops/rms_norm_fallback.py b/deeplink_ext/ops/rms_norm/internevo_rms_norm_fallback.py similarity index 100% rename from deeplink_ext/interntrain_ops/rms_norm_fallback.py rename to deeplink_ext/ops/rms_norm/internevo_rms_norm_fallback.py diff --git a/deeplink_ext/ops/rotary_embedding/__init__.py b/deeplink_ext/ops/rotary_embedding/__init__.py new file mode 100644 index 00000000..b1528327 --- /dev/null +++ b/deeplink_ext/ops/rotary_embedding/__init__.py @@ -0,0 +1,20 @@ +_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." + +try: + from .internevo_rotary_embedding import ApplyRotaryEmb +except: + print(_not_impl.format(op_name="rotary embedding")) + from .internevo_rotary_embedding_fallback import ( + ApplyRotaryEmbTorch as ApplyRotaryEmb, + ) + +try: + from .interntrain_rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_ +except: + print(_not_impl.format(op_name="rotary embedding")) + from .interntrain_rotary_embedding_fallback import ( + ApplyRotaryEmbTorch as ApplyRotaryEmb, + ) + from .interntrain_rotary_embedding_fallback import ( + ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_, + ) diff --git a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py b/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py new file mode 100644 index 00000000..764e0206 --- /dev/null +++ b/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py @@ -0,0 +1,11 @@ +# Copyright (c) 2024, DeepLink.
+ +from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + +platform_type = deeplink_ext_get_platform_type() +if platform_type == PlatformType.TORCH_DIPU: + from .internevo_rotary_embedding_dipu import ApplyRotaryEmb +else: + raise ImportError + +__all__ = ["ApplyRotaryEmb"] diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_dipu.py b/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_dipu.py similarity index 100% rename from deeplink_ext/internevo_ops/_rotary_embedding_dipu.py rename to deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_dipu.py diff --git a/deeplink_ext/internevo_ops/rotary_embedding_fallback.py b/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_fallback.py similarity index 100% rename from deeplink_ext/internevo_ops/rotary_embedding_fallback.py rename to deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_fallback.py diff --git a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py b/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py new file mode 100644 index 00000000..3555e625 --- /dev/null +++ b/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py @@ -0,0 +1,11 @@ +# Copyright (c) 2024, DeepLink. + +from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + +platform_type = deeplink_ext_get_platform_type() +if platform_type == PlatformType.TORCH_DIPU: + from ._rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_ +else: + raise ImportError + +__all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_"] diff --git a/deeplink_ext/interntrain_ops/_rotary_embedding_dipu.py b/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_dipu.py similarity index 100% rename from deeplink_ext/interntrain_ops/_rotary_embedding_dipu.py rename to deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_dipu.py diff --git a/deeplink_ext/interntrain_ops/rotary_embedding_fallback.py b/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_fallback.py similarity index 100% rename from deeplink_ext/interntrain_ops/rotary_embedding_fallback.py rename to deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_fallback.py From a0ad15b9f2051bbda3227862bfb9e7f0f81a5ed5 Mon Sep 17 00:00:00 2001 From: HuayiL <442488254@qq.com> Date: Tue, 5 Nov 2024 18:20:59 +0800 Subject: [PATCH 2/7] update tests --- tests/easyllm/test_rms_norm_dipu.py | 4 ++-- tests/internevo/test_flash_attention.py | 4 ++-- tests/internevo/test_rotary_embedding.py | 6 ++++-- tests/internevo/test_varlen_flash_attention.py | 4 ++-- tests/interntrain/test_adamw_dipu.py | 2 +- tests/interntrain/test_flash_attention.py | 4 ++-- tests/interntrain/test_rms_norm.py | 4 ++-- tests/interntrain/test_rotary_embedding.py | 4 ++-- tests/interntrain/test_varlen_flash_attention.py | 4 ++-- 9 files changed, 19 insertions(+), 17 deletions(-) diff --git a/tests/easyllm/test_rms_norm_dipu.py b/tests/easyllm/test_rms_norm_dipu.py index 1273b155..65c76a77 100644 --- a/tests/easyllm/test_rms_norm_dipu.py +++ b/tests/easyllm/test_rms_norm_dipu.py @@ -2,8 +2,8 @@ import torch from tests.core import calculate_fwd_and_bwd, allclose -from deeplink_ext.easyllm_ops.rms_norm import rms_norm -from deeplink_ext.easyllm_ops.rms_norm_fallback import rms_norm_torch +from deeplink_ext.ops.rms_norm.easyllm_rms_norm import rms_norm +from deeplink_ext.ops.rms_norm.easyllm_rms_norm_fallback import rms_norm_torch def test_rms_norm(): diff --git a/tests/internevo/test_flash_attention.py 
b/tests/internevo/test_flash_attention.py index 5126551a..9e9697d6 100644 --- a/tests/internevo/test_flash_attention.py +++ b/tests/internevo/test_flash_attention.py @@ -3,12 +3,12 @@ import torch from tests.core import copy_to_cpu, allclose, calculate_fwd_and_bwd -from deeplink_ext.internevo_ops.flash_attention_fallback import ( +from deeplink_ext.ops.flash_attention.internevo_flash_attention_fallback import ( flash_attn_qkvpacked_func_torch, flash_attn_kvpacked_func_torch, flash_attn_func_torch, ) -from deeplink_ext.internevo_ops.flash_attention import ( +from deeplink_ext.ops.flash_attention.internevo_flash_attention import ( flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_func, diff --git a/tests/internevo/test_rotary_embedding.py b/tests/internevo/test_rotary_embedding.py index 981c2f00..27968e05 100644 --- a/tests/internevo/test_rotary_embedding.py +++ b/tests/internevo/test_rotary_embedding.py @@ -2,8 +2,10 @@ import torch from tests.core import call_autograd_func, allclose -from deeplink_ext.internevo_ops.rotary_embedding import ApplyRotaryEmb -from deeplink_ext.internevo_ops.rotary_embedding_fallback import ApplyRotaryEmbTorch +from deeplink_ext.ops.rotary_embedding.internevo_rotary_embedding import ApplyRotaryEmb +from deeplink_ext.ops.rotary_embedding.internevo_rotary_embedding_fallback import ( + ApplyRotaryEmbTorch, +) def test_ApplyRotaryEmb(): diff --git a/tests/internevo/test_varlen_flash_attention.py b/tests/internevo/test_varlen_flash_attention.py index 97b8d64c..a2fa0874 100644 --- a/tests/internevo/test_varlen_flash_attention.py +++ b/tests/internevo/test_varlen_flash_attention.py @@ -3,12 +3,12 @@ import torch from tests.core import allclose, calculate_fwd_and_bwd, copy_to_cpu -from deeplink_ext.internevo_ops.flash_attention_fallback import ( +from deeplink_ext.ops.flash_attention.internevo_flash_attention_fallback import ( flash_attn_varlen_qkvpacked_func_torch, flash_attn_varlen_kvpacked_func_torch, flash_attn_varlen_func_torch, ) -from deeplink_ext.internevo_ops.flash_attention import ( +from deeplink_ext.ops.flash_attention.internevo_flash_attention import ( flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func, flash_attn_varlen_func, diff --git a/tests/interntrain/test_adamw_dipu.py b/tests/interntrain/test_adamw_dipu.py index 0fff76ef..1f6c5af2 100644 --- a/tests/interntrain/test_adamw_dipu.py +++ b/tests/interntrain/test_adamw_dipu.py @@ -3,7 +3,7 @@ import copy import torch from torch import nn -from deeplink_ext.interntrain_ops.adamw import AdamW +from deeplink_ext.ops.adamw import AdamW def test_AdamW(): diff --git a/tests/interntrain/test_flash_attention.py b/tests/interntrain/test_flash_attention.py index 8ac86337..0727f884 100644 --- a/tests/interntrain/test_flash_attention.py +++ b/tests/interntrain/test_flash_attention.py @@ -3,11 +3,11 @@ import torch from tests.core import copy_to_cpu, allclose, call_module -from deeplink_ext.interntrain_ops.flash_attention import ( +from deeplink_ext.ops.flash_attention.interntrain_flash_attention import ( FlashSelfAttention, FlashCrossAttention, ) -from deeplink_ext.interntrain_ops.flash_attention_fallback import ( +from deeplink_ext.ops.flash_attention.interntrain_flash_attention_fallback import ( SelfAttention, CrossAttention, ) diff --git a/tests/interntrain/test_rms_norm.py b/tests/interntrain/test_rms_norm.py index ac24b2ca..37da340f 100644 --- a/tests/interntrain/test_rms_norm.py +++ b/tests/interntrain/test_rms_norm.py @@ -2,8 +2,8 @@ import torch from tests.core import 
call_module, allclose -from deeplink_ext.interntrain_ops.rms_norm import MixedFusedRMSNorm -from deeplink_ext.interntrain_ops.rms_norm_fallback import MixedRMSNormTorch +from deeplink_ext.ops.rms_norm.internevo_rms_norm import MixedFusedRMSNorm +from deeplink_ext.ops.rms_norm.internevo_rms_norm_fallback import MixedRMSNormTorch def test_MixedFusedRMSNorm(): diff --git a/tests/interntrain/test_rotary_embedding.py b/tests/interntrain/test_rotary_embedding.py index deb9dae0..5982bd9d 100644 --- a/tests/interntrain/test_rotary_embedding.py +++ b/tests/interntrain/test_rotary_embedding.py @@ -3,11 +3,11 @@ import torch from tests.core import call_autograd_func, allclose -from deeplink_ext.interntrain_ops.rotary_embedding import ( +from deeplink_ext.ops.rotary_embedding.interntrain_rotary_embedding import ( ApplyRotaryEmb, ApplyRotaryEmbQKV_, ) -from deeplink_ext.interntrain_ops.rotary_embedding_fallback import ( +from deeplink_ext.ops.rotary_embedding.interntrain_rotary_embedding_fallback import ( ApplyRotaryEmbTorch, ApplyRotaryEmbQKV_Torch, ) diff --git a/tests/interntrain/test_varlen_flash_attention.py b/tests/interntrain/test_varlen_flash_attention.py index a66a8222..92f9c200 100644 --- a/tests/interntrain/test_varlen_flash_attention.py +++ b/tests/interntrain/test_varlen_flash_attention.py @@ -3,11 +3,11 @@ import torch from tests.core import allclose, call_module -from deeplink_ext.interntrain_ops.flash_attention import ( +from deeplink_ext.ops.flash_attention.interntrain_flash_attention import ( FlashSelfAttention, FlashCrossAttention, ) -from deeplink_ext.interntrain_ops.flash_attention_fallback import ( +from deeplink_ext.ops.flash_attention.interntrain_flash_attention_fallback import ( SelfAttention, CrossAttention, ) From 10fa4ba811aa256aea70386014e48a85234ce101 Mon Sep 17 00:00:00 2001 From: HuayiL <442488254@qq.com> Date: Wed, 6 Nov 2024 10:49:37 +0800 Subject: [PATCH 3/7] fix some bugs --- deeplink_ext/ops/rms_norm/internevo_rms_norm.py | 2 +- .../ops/rotary_embedding/interntrain_rotary_embedding.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deeplink_ext/ops/rms_norm/internevo_rms_norm.py b/deeplink_ext/ops/rms_norm/internevo_rms_norm.py index 5b828ce9..11356346 100644 --- a/deeplink_ext/ops/rms_norm/internevo_rms_norm.py +++ b/deeplink_ext/ops/rms_norm/internevo_rms_norm.py @@ -6,7 +6,7 @@ if platform_type == PlatformType.TORCH_DIPU: # from ._mixed_rms_norm_dipu import MixedFusedRMSNorm # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative. 
- from .interntrain_rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm + from .internevo_rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm else: raise ImportError diff --git a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py b/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py index 3555e625..61f98c64 100644 --- a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py +++ b/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py @@ -4,7 +4,7 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_DIPU: - from ._rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_ + from .interntrain_rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_ else: raise ImportError From eecdb8fef9ebb3d97165a779b8d83754701df648 Mon Sep 17 00:00:00 2001 From: HuayiL <442488254@qq.com> Date: Mon, 11 Nov 2024 11:48:21 +0800 Subject: [PATCH 4/7] add support for InterEvo flash-attn --- deeplink_ext/internevo_ops/__init__.py | 2 ++ .../flash_attention/internevo_flash_attention_dipu.py | 8 ++------ .../flash_attention/interntrain_flash_attention_dipu.py | 9 +++------ 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/deeplink_ext/internevo_ops/__init__.py b/deeplink_ext/internevo_ops/__init__.py index 41bbae8d..aa78e6f5 100644 --- a/deeplink_ext/internevo_ops/__init__.py +++ b/deeplink_ext/internevo_ops/__init__.py @@ -13,6 +13,8 @@ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func, flash_attn_varlen_func, + FlashCrossAttention, + FlashSelfAttention, ) from deeplink_ext.ops.rms_norm import MixedFusedRMSNorm diff --git a/deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py b/deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py index 09dff62f..4833b474 100644 --- a/deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py +++ b/deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py @@ -4,12 +4,8 @@ import torch_dipu import deeplink_ext.cpp_extensions as ext -if torch_dipu.dipu.vendor_type == "NPU": - assert hasattr(ext, "custom_fa_fwd") and hasattr(ext, "custom_fa_bwd") - assert hasattr(ext, "custom_fa_varlen_fwd") and hasattr(ext, "custom_fa_varlen_bwd") -else: - assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd") - assert hasattr(ext, "fa_varlen_fwd") and hasattr(ext, "fa_varlen_bwd") +assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd") +assert hasattr(ext, "fa_varlen_fwd") and hasattr(ext, "fa_varlen_bwd") __all__ = [ "flash_attn_qkvpacked_func", diff --git a/deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py b/deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py index 5b3822ae..4aa56d05 100644 --- a/deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py +++ b/deeplink_ext/ops/flash_attention/interntrain_flash_attention_dipu.py @@ -5,12 +5,9 @@ import torch.nn as nn import deeplink_ext.cpp_extensions as ext -if torch_dipu.dipu.vendor_type == "NPU": - assert hasattr(ext, "custom_fa_fwd") and hasattr(ext, "custom_fa_bwd") - assert hasattr(ext, "custom_fa_varlen_fwd") and hasattr(ext, "custom_fa_varlen_bwd") -else: - assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd") - assert hasattr(ext, "fa_varlen_fwd") and hasattr(ext, "fa_varlen_bwd") + +assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd") +assert hasattr(ext, "fa_varlen_fwd") and hasattr(ext, "fa_varlen_bwd") __all__ = ["FlashSelfAttention", "FlashCrossAttention"] From 
2ef9df0e303b084b828d0dcf1de2b7434bc0633b Mon Sep 17 00:00:00 2001 From: HuayiL <442488254@qq.com> Date: Thu, 14 Nov 2024 11:57:04 +0800 Subject: [PATCH 5/7] refactor the import ways --- deeplink_ext/__init__.py | 3 - deeplink_ext/ops/flash_attention/__init__.py | 51 +++---- .../ops/flash_attention/flash_attn_utils.py | 80 ++++++++++ deeplink_ext/ops/rms_norm/__init__.py | 10 +- deeplink_ext/ops/rms_norm/rms_norm_utils.py | 22 +++ deeplink_ext/ops/rotary_embedding/__init__.py | 14 +- ...tary_embedding.py => _rotary_embedding.py} | 2 +- ...ding_dipu.py => _rotary_embedding_dipu.py} | 0 ...lback.py => _rotary_embedding_fallback.py} | 0 .../internevo_rotary_embedding.py | 11 -- .../internevo_rotary_embedding_dipu.py | 90 ----------- .../internevo_rotary_embedding_fallback.py | 141 ------------------ 12 files changed, 130 insertions(+), 294 deletions(-) create mode 100644 deeplink_ext/ops/flash_attention/flash_attn_utils.py create mode 100644 deeplink_ext/ops/rms_norm/rms_norm_utils.py rename deeplink_ext/ops/rotary_embedding/{interntrain_rotary_embedding.py => _rotary_embedding.py} (76%) rename deeplink_ext/ops/rotary_embedding/{interntrain_rotary_embedding_dipu.py => _rotary_embedding_dipu.py} (100%) rename deeplink_ext/ops/rotary_embedding/{interntrain_rotary_embedding_fallback.py => _rotary_embedding_fallback.py} (100%) delete mode 100644 deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py delete mode 100644 deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_dipu.py delete mode 100644 deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_fallback.py diff --git a/deeplink_ext/__init__.py b/deeplink_ext/__init__.py index 787302ac..293335bf 100644 --- a/deeplink_ext/__init__.py +++ b/deeplink_ext/__init__.py @@ -9,9 +9,6 @@ def _init(): platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_DIPU: import torch_dipu - elif platform_type == PlatformType.TORCH_NPU: - warnings.warn("DeepLinkExt using torch_npu ...", ImportWarning) - import torch_npu else: raise ImportError diff --git a/deeplink_ext/ops/flash_attention/__init__.py b/deeplink_ext/ops/flash_attention/__init__.py index 486b9c37..9b6026f6 100644 --- a/deeplink_ext/ops/flash_attention/__init__.py +++ b/deeplink_ext/ops/flash_attention/__init__.py @@ -1,32 +1,25 @@ _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." 
-try: - from .internevo_flash_attention import ( - flash_attn_qkvpacked_func, - flash_attn_kvpacked_func, - flash_attn_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_func, - ) -except Exception as e: - print(_not_impl.format(op_name="flash attention")) - from .internevo_flash_attention_fallback import ( - flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func, - flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func, - flash_attn_func_torch as flash_attn_func, - flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func, - flash_attn_varlen_func_torch as flash_attn_varlen_func, - ) +from .flash_attn_utils import import_flash_attn_modules, import_flash_attn_funcs -try: - from .interntrain_flash_attention import FlashSelfAttention, FlashCrossAttention -except Exception as e: - print(_not_impl.format(op_name="flash attention")) - from .interntrain_flash_attention_fallback import ( - SelfAttention as FlashSelfAttention, - ) - from .interntrain_flash_attention_fallback import ( - CrossAttention as FlashCrossAttention, - ) +FlashSelfAttention, FlashCrossAttention = import_flash_attn_modules() +( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, +) = import_flash_attn_funcs() + +from .flash_attn_utils import patch_mha, patch_flash_attn_funcs + +patch_mha(FlashSelfAttention, FlashCrossAttention) +patch_flash_attn_funcs( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, +) diff --git a/deeplink_ext/ops/flash_attention/flash_attn_utils.py b/deeplink_ext/ops/flash_attention/flash_attn_utils.py new file mode 100644 index 00000000..121c5d0c --- /dev/null +++ b/deeplink_ext/ops/flash_attention/flash_attn_utils.py @@ -0,0 +1,80 @@ +_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." 
+ + +def patch_mha(CustomFlashSelfAttention, CustomFlashCrossAttention): + FlashSelfAttention, FlashCrossAttention = ( + CustomFlashSelfAttention, + CustomFlashCrossAttention, + ) + try: + import flash_attn.modules.mha as mha + except Exception as e: + print("Unable to import flash_attn, skip mocking flash_attn") + return + + mha.FlashSelfAttention = FlashSelfAttention + mha.FlashCrossAttention = FlashCrossAttention + + +def import_flash_attn_modules(): + try: + from .interntrain_flash_attention import FlashSelfAttention + from .interntrain_flash_attention import FlashCrossAttention + except Exception as e: + print(_not_impl.format(op_name="flash attention")) + from .interntrain_flash_attention import SelfAttention as FlashSelfAttention + from .interntrain_flash_attention import CrossAttention as FlashCrossAttention + + return (FlashSelfAttention, FlashCrossAttention) + + +def patch_flash_attn_funcs( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, +): + try: + import flash_attn + except Exception as e: + print("Unable to import flash_attn, skip mocking flash_attn") + return + + flash_attn.flash_attn_qkvpacked_func = flash_attn_qkvpacked_func + flash_attn.flash_attn_kvpacked_func = flash_attn_kvpacked_func + flash_attn.flash_attn_func = flash_attn_func + flash_attn.flash_attn_varlen_qkvpacked_func = flash_attn_varlen_qkvpacked_func + flash_attn.flash_attn_varlen_kvpacked_func = flash_attn_varlen_kvpacked_func + flash_attn.flash_attn_varlen_func = flash_attn_varlen_func + + +def import_flash_attn_funcs(): + try: + from .internevo_flash_attention import ( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, + ) + except Exception as e: + print(_not_impl.format(op_name="flash attention")) + from .internevo_flash_attention_fallback import ( + flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func, + flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func, + flash_attn_func_torch as flash_attn_func, + flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func_torch as flash_attn_varlen_func, + ) + return ( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, + ) diff --git a/deeplink_ext/ops/rms_norm/__init__.py b/deeplink_ext/ops/rms_norm/__init__.py index f2145764..e6745aa0 100644 --- a/deeplink_ext/ops/rms_norm/__init__.py +++ b/deeplink_ext/ops/rms_norm/__init__.py @@ -8,10 +8,6 @@ ) from .easyllm_rms_norm_fallback import rms_norm_torch as rms_norm -try: - from .internevo_rms_norm import MixedFusedRMSNorm -except: - print( - _not_impl.format(op_name="RMSNorm"), - ) - from .internevo_rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm +from .rms_norm_utils import import_RMSNorm, patch_RMSNorm + +MixedFusedRMSNorm = import_RMSNorm() diff --git a/deeplink_ext/ops/rms_norm/rms_norm_utils.py b/deeplink_ext/ops/rms_norm/rms_norm_utils.py new file mode 100644 index 00000000..f2b533d6 --- /dev/null +++ b/deeplink_ext/ops/rms_norm/rms_norm_utils.py @@ -0,0 +1,22 @@ +_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." 
+ + +def patch_RMSNorm(MixedFusedRMSNorm): + try: + import apex.normalization.fused_layer_norm as fused_layer_norm + except Exception as e: + print("Unable to import fused_layer_norm, skip mocking fused_layer_norm") + return + + fused_layer_norm.MixedFusedRMSNorm = MixedFusedRMSNorm + + +def import_RMSNorm(): + try: + from .internevo_rms_norm import MixedFusedRMSNorm + except: + print( + _not_impl.format(op_name="RMSNorm"), + ) + from .internevo_rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm + return MixedFusedRMSNorm diff --git a/deeplink_ext/ops/rotary_embedding/__init__.py b/deeplink_ext/ops/rotary_embedding/__init__.py index b1528327..2c7d08b1 100644 --- a/deeplink_ext/ops/rotary_embedding/__init__.py +++ b/deeplink_ext/ops/rotary_embedding/__init__.py @@ -1,20 +1,10 @@ _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." try: - from .internevo_rotary_embedding import ApplyRotaryEmb + from ._rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_ except: print(_not_impl.format(op_name="rotary embedding")) - from .internevo_rotary_embedding_fallback import ( + from ._rotary_embedding_fallback import ( ApplyRotaryEmbTorch as ApplyRotaryEmb, - ) - -try: - from .interntrain_rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_ -except: - print(_not_impl.format(op_name="rotary embedding")) - from .interntrain_rotary_embedding_fallback import ( - ApplyRotaryEmbTorch as ApplyRotaryEmb, - ) - from .interntrain_rotary_embedding_fallback import ( ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_, ) diff --git a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py b/deeplink_ext/ops/rotary_embedding/_rotary_embedding.py similarity index 76% rename from deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py rename to deeplink_ext/ops/rotary_embedding/_rotary_embedding.py index 61f98c64..3555e625 100644 --- a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py +++ b/deeplink_ext/ops/rotary_embedding/_rotary_embedding.py @@ -4,7 +4,7 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_DIPU: - from .interntrain_rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_ + from ._rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_ else: raise ImportError diff --git a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_dipu.py b/deeplink_ext/ops/rotary_embedding/_rotary_embedding_dipu.py similarity index 100% rename from deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_dipu.py rename to deeplink_ext/ops/rotary_embedding/_rotary_embedding_dipu.py diff --git a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_fallback.py b/deeplink_ext/ops/rotary_embedding/_rotary_embedding_fallback.py similarity index 100% rename from deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_fallback.py rename to deeplink_ext/ops/rotary_embedding/_rotary_embedding_fallback.py diff --git a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py b/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py deleted file mode 100644 index 764e0206..00000000 --- a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
- -from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type - -platform_type = deeplink_ext_get_platform_type() -if platform_type == PlatformType.TORCH_DIPU: - from .internevo_rotary_embedding_dipu import ApplyRotaryEmb -else: - raise ImportError - -__all__ = ["ApplyRotaryEmb"] diff --git a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_dipu.py b/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_dipu.py deleted file mode 100644 index 7d5a4eb4..00000000 --- a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_dipu.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2024, DeepLink. - -import torch -from einops import rearrange -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "apply_rotary") - -__all__ = ["ApplyRotaryEmb"] - - -# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35 -class ApplyRotaryEmb(torch.autograd.Function): - """ - ApplyRotaryEmb - """ - - @staticmethod - def forward( - ctx, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False, - in_place: bool = False, - ): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. - """ - *_, seqlen, _, head_dim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - - assert rotary_dim <= head_dim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - - if in_place: - out = x - else: - out = torch.empty_like(x) - - ext.apply_rotary( - out[..., :rotary_dim], - x[..., :rotary_dim], - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - interleaved, - ) - - if rotary_dim < head_dim and not in_place: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - ctx.in_place = in_place - - return out - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - *_, seqlen, _, head_dim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - - if ctx.in_place: - dx = do - else: - dx = torch.empty_like(do) - - ext.apply_rotary( - dx[..., :rotary_dim], - do[..., :rotary_dim], - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - ctx.interleaved, - ) - - if rotary_dim < head_dim and not ctx.in_place: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - - return dx, None, None, None, None diff --git a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_fallback.py b/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_fallback.py deleted file mode 100644 index c8956025..00000000 --- a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_fallback.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
- -import torch -from einops import rearrange - -__all__ = ["ApplyRotaryEmbTorch"] - - -def _torch_apply_rotary_func( - x1: torch.Tensor, - x2: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - out1: torch.Tensor, - out2: torch.Tensor, - conj: bool = False, -): - assert ( - x1.device == x2.device == cos.device == sin.device - ), "All inputs must be on the same device" - assert ( - x1.dtype == x2.dtype == cos.dtype == sin.dtype - ), "All inputs must have the same dtype" - assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes" - assert cos.size() == sin.size(), "Input cos and sin must have the same sizes" - - x1, x2, cos, sin = x1.float(), x2.float(), cos.float(), sin.float() - - if conj: - out1.copy_(x1 * cos + x2 * sin) - out2.copy_(-x1 * sin + x2 * cos) - else: - out1.copy_(x1 * cos - x2 * sin) - out2.copy_(x1 * sin + x2 * cos) - - -# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35 -class ApplyRotaryEmbTorch(torch.autograd.Function): - """ - ApplyRotaryEmbTorch - """ - - @staticmethod - def forward( - ctx, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False, - in_place: bool = False, - ): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. - """ - *_, seqlen, _, head_dim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - - assert rotary_dim <= head_dim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - - x_ro = x[..., :rotary_dim] - x1, x2 = ( - (x_ro[..., ::2], x_ro[..., 1::2]) if interleaved else x_ro.chunk(2, dim=-1) - ) - - if in_place: - out, o1, o2 = x, x1, x2 - else: - out = torch.empty_like(x) - out_ro = out[..., :rotary_dim] - o1, o2 = ( - (out_ro[..., ::2], out_ro[..., 1::2]) - if interleaved - else out_ro.chunk(2, dim=-1) - ) - - _torch_apply_rotary_func( - x1, - x2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - o1, - o2, - False, - ) - - if rotary_dim < head_dim and not in_place: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - ctx.in_place = in_place - - return out - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - *_, seqlen, _, head_dim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - - do_ro = do[..., :rotary_dim] - do1, do2 = ( - (do_ro[..., ::2], do_ro[..., 1::2]) - if ctx.interleaved - else do_ro.chunk(2, dim=-1) - ) - - if ctx.in_place: - dx, dx1, dx2 = do, do1, do2 - else: - dx = torch.empty_like(do) - dx_ro = dx[..., :rotary_dim] - dx1, dx2 = ( - (dx_ro[..., ::2], dx_ro[..., 1::2]) - if ctx.interleaved - else dx_ro.chunk(2, dim=-1) - ) - - _torch_apply_rotary_func( - do1, - do2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - dx1, - dx2, - True, - ) - - if rotary_dim < head_dim and not ctx.in_place: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - - return dx, None, None, None, None From 8e62d9946b753c191c29e2287559e0bc86f91e2e Mon Sep 17 00:00:00 2001 From: HuayiL <442488254@qq.com> Date: Tue, 19 Nov 2024 16:58:38 +0800 Subject: [PATCH 6/7] refactor the way of import deeplink_ext Now apply_rotary is same with rotary_emb.apply_rotary and import deeplink_ext 
will do mock now. --- csrc/extensions.cpp | 9 +- deeplink_ext/__init__.py | 3 - deeplink_ext/internevo_ops/__init__.py | 4 +- deeplink_ext/interntrain_ops/__init__.py | 3 +- deeplink_ext/ops/flash_attention/__init__.py | 51 +++---- .../ops/flash_attention/flash_attn_utils.py | 80 ++++++++++ .../internevo_flash_attention.py | 2 +- .../interntrain_flash_attention.py | 2 +- deeplink_ext/ops/rms_norm/__init__.py | 10 +- .../ops/rms_norm/internevo_rms_norm.py | 2 +- deeplink_ext/ops/rms_norm/rms_norm_utils.py | 22 +++ deeplink_ext/ops/rotary_embedding/__init__.py | 22 +-- .../ops/rotary_embedding/_rotary_embedding.py | 11 ++ ...ding_dipu.py => _rotary_embedding_dipu.py} | 79 ++++++---- ...lback.py => _rotary_embedding_fallback.py} | 2 +- .../internevo_rotary_embedding.py | 11 -- .../internevo_rotary_embedding_dipu.py | 90 ----------- .../internevo_rotary_embedding_fallback.py | 141 ------------------ .../interntrain_rotary_embedding.py | 11 -- .../ops/rotary_embedding/rotary_emb_utils.py | 32 ++++ 20 files changed, 238 insertions(+), 349 deletions(-) create mode 100644 deeplink_ext/ops/flash_attention/flash_attn_utils.py create mode 100644 deeplink_ext/ops/rms_norm/rms_norm_utils.py create mode 100644 deeplink_ext/ops/rotary_embedding/_rotary_embedding.py rename deeplink_ext/ops/rotary_embedding/{interntrain_rotary_embedding_dipu.py => _rotary_embedding_dipu.py} (80%) rename deeplink_ext/ops/rotary_embedding/{interntrain_rotary_embedding_fallback.py => _rotary_embedding_fallback.py} (98%) delete mode 100644 deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py delete mode 100644 deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_dipu.py delete mode 100644 deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_fallback.py delete mode 100644 deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py create mode 100644 deeplink_ext/ops/rotary_embedding/rotary_emb_utils.py diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp index bb46f710..4a5edfd3 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -60,10 +60,11 @@ void extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight, eps); } -void extApplyRotary(at::Tensor& output, const at::Tensor& input, +void extApplyRotary(const at::Tensor& input1, const at::Tensor& input2, const at::Tensor& cos, const at::Tensor& sin, - const bool conj, const bool interleaved) { - callDiopi(diopiRotaryEmbedding, output, input, cos, sin, conj, interleaved); + at::Tensor& output1, at::Tensor& output2, + const bool conj) { + callDiopi(diopiApplyRotary, output1, output2, input1, input2, cos, sin, conj, false); } auto extMultiHeadAttention(at::Tensor& q, at::Tensor& k, at::Tensor& v, @@ -443,7 +444,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rms_norm_backward", &extRmsNormBackward, "deeplink ext_rms_norm_backward"); } - if (&diopiRotaryEmbedding != nullptr) { + if (&diopiApplyRotary != nullptr) { m.def("apply_rotary", &extApplyRotary, "deeplink ext_apply_rotary"); } if (&diopiMultiHeadAttention != nullptr) { diff --git a/deeplink_ext/__init__.py b/deeplink_ext/__init__.py index 787302ac..293335bf 100644 --- a/deeplink_ext/__init__.py +++ b/deeplink_ext/__init__.py @@ -9,9 +9,6 @@ def _init(): platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_DIPU: import torch_dipu - elif platform_type == PlatformType.TORCH_NPU: - warnings.warn("DeepLinkExt using torch_npu ...", ImportWarning) - import torch_npu else: raise ImportError diff --git 
a/deeplink_ext/internevo_ops/__init__.py b/deeplink_ext/internevo_ops/__init__.py index aa78e6f5..4f6049e2 100644 --- a/deeplink_ext/internevo_ops/__init__.py +++ b/deeplink_ext/internevo_ops/__init__.py @@ -19,7 +19,7 @@ from deeplink_ext.ops.rms_norm import MixedFusedRMSNorm -from deeplink_ext.ops.rotary_embedding import ApplyRotaryEmb +from deeplink_ext.ops.rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_, apply_rotary __all__ = [ "AdamW", @@ -31,4 +31,6 @@ "flash_attn_varlen_func", "MixedFusedRMSNorm", "ApplyRotaryEmb", + "ApplyRotaryEmbQKV_", + "apply_rotary", ] diff --git a/deeplink_ext/interntrain_ops/__init__.py b/deeplink_ext/interntrain_ops/__init__.py index 60166bd6..3acaf6f5 100644 --- a/deeplink_ext/interntrain_ops/__init__.py +++ b/deeplink_ext/interntrain_ops/__init__.py @@ -10,7 +10,7 @@ from deeplink_ext.ops.flash_attention import FlashSelfAttention, FlashCrossAttention from deeplink_ext.ops.rms_norm import MixedFusedRMSNorm -from deeplink_ext.ops.rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_ +from deeplink_ext.ops.rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_, apply_rotary __all__ = [ @@ -20,4 +20,5 @@ "MixedFusedRMSNorm", "ApplyRotaryEmb", "ApplyRotaryEmbQKV_", + "apply_rotary", ] diff --git a/deeplink_ext/ops/flash_attention/__init__.py b/deeplink_ext/ops/flash_attention/__init__.py index 486b9c37..9b6026f6 100644 --- a/deeplink_ext/ops/flash_attention/__init__.py +++ b/deeplink_ext/ops/flash_attention/__init__.py @@ -1,32 +1,25 @@ _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." -try: - from .internevo_flash_attention import ( - flash_attn_qkvpacked_func, - flash_attn_kvpacked_func, - flash_attn_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_func, - ) -except Exception as e: - print(_not_impl.format(op_name="flash attention")) - from .internevo_flash_attention_fallback import ( - flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func, - flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func, - flash_attn_func_torch as flash_attn_func, - flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func, - flash_attn_varlen_func_torch as flash_attn_varlen_func, - ) +from .flash_attn_utils import import_flash_attn_modules, import_flash_attn_funcs -try: - from .interntrain_flash_attention import FlashSelfAttention, FlashCrossAttention -except Exception as e: - print(_not_impl.format(op_name="flash attention")) - from .interntrain_flash_attention_fallback import ( - SelfAttention as FlashSelfAttention, - ) - from .interntrain_flash_attention_fallback import ( - CrossAttention as FlashCrossAttention, - ) +FlashSelfAttention, FlashCrossAttention = import_flash_attn_modules() +( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, +) = import_flash_attn_funcs() + +from .flash_attn_utils import patch_mha, patch_flash_attn_funcs + +patch_mha(FlashSelfAttention, FlashCrossAttention) +patch_flash_attn_funcs( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, +) diff --git a/deeplink_ext/ops/flash_attention/flash_attn_utils.py b/deeplink_ext/ops/flash_attention/flash_attn_utils.py new file mode 
100644 index 00000000..121c5d0c --- /dev/null +++ b/deeplink_ext/ops/flash_attention/flash_attn_utils.py @@ -0,0 +1,80 @@ +_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." + + +def patch_mha(CustomFlashSelfAttention, CustomFlashCrossAttention): + FlashSelfAttention, FlashCrossAttention = ( + CustomFlashSelfAttention, + CustomFlashCrossAttention, + ) + try: + import flash_attn.modules.mha as mha + except Exception as e: + print("Unable to import flash_attn, skip mocking flash_attn") + return + + mha.FlashSelfAttention = FlashSelfAttention + mha.FlashCrossAttention = FlashCrossAttention + + +def import_flash_attn_modules(): + try: + from .interntrain_flash_attention import FlashSelfAttention + from .interntrain_flash_attention import FlashCrossAttention + except Exception as e: + print(_not_impl.format(op_name="flash attention")) + from .interntrain_flash_attention import SelfAttention as FlashSelfAttention + from .interntrain_flash_attention import CrossAttention as FlashCrossAttention + + return (FlashSelfAttention, FlashCrossAttention) + + +def patch_flash_attn_funcs( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, +): + try: + import flash_attn + except Exception as e: + print("Unable to import flash_attn, skip mocking flash_attn") + return + + flash_attn.flash_attn_qkvpacked_func = flash_attn_qkvpacked_func + flash_attn.flash_attn_kvpacked_func = flash_attn_kvpacked_func + flash_attn.flash_attn_func = flash_attn_func + flash_attn.flash_attn_varlen_qkvpacked_func = flash_attn_varlen_qkvpacked_func + flash_attn.flash_attn_varlen_kvpacked_func = flash_attn_varlen_kvpacked_func + flash_attn.flash_attn_varlen_func = flash_attn_varlen_func + + +def import_flash_attn_funcs(): + try: + from .internevo_flash_attention import ( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, + ) + except Exception as e: + print(_not_impl.format(op_name="flash attention")) + from .internevo_flash_attention_fallback import ( + flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func, + flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func, + flash_attn_func_torch as flash_attn_func, + flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func_torch as flash_attn_varlen_func, + ) + return ( + flash_attn_qkvpacked_func, + flash_attn_kvpacked_func, + flash_attn_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_func, + ) diff --git a/deeplink_ext/ops/flash_attention/internevo_flash_attention.py b/deeplink_ext/ops/flash_attention/internevo_flash_attention.py index 2408b207..b5d951ac 100644 --- a/deeplink_ext/ops/flash_attention/internevo_flash_attention.py +++ b/deeplink_ext/ops/flash_attention/internevo_flash_attention.py @@ -3,7 +3,7 @@ from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type platform_type = deeplink_ext_get_platform_type() -if platform_type == PlatformType.TORCH_DIPU: +if platform_type == PlatformType.TORCH_DIPU or platform_type == PlatformType.TORCH_CUDA: from .internevo_flash_attention_dipu import ( flash_attn_func, flash_attn_kvpacked_func, diff --git 
a/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py b/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py index 0ba83c65..facad039 100644 --- a/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py +++ b/deeplink_ext/ops/flash_attention/interntrain_flash_attention.py @@ -3,7 +3,7 @@ from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type platform_type = deeplink_ext_get_platform_type() -if platform_type == PlatformType.TORCH_DIPU: +if platform_type == PlatformType.TORCH_DIPU or platform_type == PlatformType.TORCH_CUDA: from .interntrain_flash_attention_dipu import ( FlashSelfAttention, FlashCrossAttention, diff --git a/deeplink_ext/ops/rms_norm/__init__.py b/deeplink_ext/ops/rms_norm/__init__.py index f2145764..e6745aa0 100644 --- a/deeplink_ext/ops/rms_norm/__init__.py +++ b/deeplink_ext/ops/rms_norm/__init__.py @@ -8,10 +8,6 @@ ) from .easyllm_rms_norm_fallback import rms_norm_torch as rms_norm -try: - from .internevo_rms_norm import MixedFusedRMSNorm -except: - print( - _not_impl.format(op_name="RMSNorm"), - ) - from .internevo_rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm +from .rms_norm_utils import import_RMSNorm, patch_RMSNorm + +MixedFusedRMSNorm = import_RMSNorm() diff --git a/deeplink_ext/ops/rms_norm/internevo_rms_norm.py b/deeplink_ext/ops/rms_norm/internevo_rms_norm.py index 11356346..2a3c2a45 100644 --- a/deeplink_ext/ops/rms_norm/internevo_rms_norm.py +++ b/deeplink_ext/ops/rms_norm/internevo_rms_norm.py @@ -3,7 +3,7 @@ from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type platform_type = deeplink_ext_get_platform_type() -if platform_type == PlatformType.TORCH_DIPU: +if platform_type == PlatformType.TORCH_DIPU or platform_type == PlatformType.TORCH_CUDA: # from ._mixed_rms_norm_dipu import MixedFusedRMSNorm # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative. from .internevo_rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm diff --git a/deeplink_ext/ops/rms_norm/rms_norm_utils.py b/deeplink_ext/ops/rms_norm/rms_norm_utils.py new file mode 100644 index 00000000..f2b533d6 --- /dev/null +++ b/deeplink_ext/ops/rms_norm/rms_norm_utils.py @@ -0,0 +1,22 @@ +_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." + + +def patch_RMSNorm(MixedFusedRMSNorm): + try: + import apex.normalization.fused_layer_norm as fused_layer_norm + except Exception as e: + print("Unable to import fused_layer_norm, skip mocking fused_layer_norm") + return + + fused_layer_norm.MixedFusedRMSNorm = MixedFusedRMSNorm + + +def import_RMSNorm(): + try: + from .internevo_rms_norm import MixedFusedRMSNorm + except: + print( + _not_impl.format(op_name="RMSNorm"), + ) + from .internevo_rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm + return MixedFusedRMSNorm diff --git a/deeplink_ext/ops/rotary_embedding/__init__.py b/deeplink_ext/ops/rotary_embedding/__init__.py index b1528327..3761959d 100644 --- a/deeplink_ext/ops/rotary_embedding/__init__.py +++ b/deeplink_ext/ops/rotary_embedding/__init__.py @@ -1,20 +1,6 @@ _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation." 
-try:
-    from .internevo_rotary_embedding import ApplyRotaryEmb
-except:
-    print(_not_impl.format(op_name="rotary embedding"))
-    from .internevo_rotary_embedding_fallback import (
-        ApplyRotaryEmbTorch as ApplyRotaryEmb,
-    )
-
-try:
-    from .interntrain_rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_
-except:
-    print(_not_impl.format(op_name="rotary embedding"))
-    from .interntrain_rotary_embedding_fallback import (
-        ApplyRotaryEmbTorch as ApplyRotaryEmb,
-    )
-    from .interntrain_rotary_embedding_fallback import (
-        ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_,
-    )
+from .rotary_emb_utils import import_rotary_emb_funcs, patch_rotary_emb_funcs
+ApplyRotaryEmb, ApplyRotaryEmbQKV_, apply_rotary = import_rotary_emb_funcs()
+print(dir(apply_rotary))
+patch_rotary_emb_funcs(apply_rotary)
diff --git a/deeplink_ext/ops/rotary_embedding/_rotary_embedding.py b/deeplink_ext/ops/rotary_embedding/_rotary_embedding.py
new file mode 100644
index 00000000..c5157ead
--- /dev/null
+++ b/deeplink_ext/ops/rotary_embedding/_rotary_embedding.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2024, DeepLink.
+
+from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type
+
+platform_type = deeplink_ext_get_platform_type()
+if platform_type == PlatformType.TORCH_DIPU or platform_type == PlatformType.TORCH_CUDA:
+    from ._rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_, apply_rotary
+else:
+    raise ImportError
+
+__all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_","apply_rotary"]
diff --git a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_dipu.py b/deeplink_ext/ops/rotary_embedding/_rotary_embedding_dipu.py
similarity index 80%
rename from deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_dipu.py
rename to deeplink_ext/ops/rotary_embedding/_rotary_embedding_dipu.py
index ec1f55e0..e73e50f7 100644
--- a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_dipu.py
+++ b/deeplink_ext/ops/rotary_embedding/_rotary_embedding_dipu.py
@@ -6,8 +6,9 @@
 import deeplink_ext.cpp_extensions as ext
 
 assert hasattr(ext, "apply_rotary")
+from deeplink_ext.cpp_extensions import apply_rotary
 
-__all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_"]
+__all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_","apply_rotary"]
 
 
 class ApplyRotaryEmb(torch.autograd.Function):
@@ -31,16 +32,21 @@ def forward(ctx, x, cos, sin, interleaved=False):
         assert rotary_dim <= headdim
         assert seqlen <= rotary_seqlen
         assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        x_ro = x[..., :rotary_dim]
+        x1, x2 = x_ro.chunk(2, dim=-1)
         out = torch.empty_like(x)
+        out_ro = out[..., :rotary_dim]
+        o1, o2 = out_ro.chunk(2, dim=-1)
         re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
         re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")
-        ext.apply_rotary(
-            out[..., :rotary_dim],
-            x[..., :rotary_dim],
+        apply_rotary(
+            x1,
+            x2,
             re_cos,
             re_sin,
+            o1,
+            o2,
             False,
-            interleaved,
         )
         if rotary_dim < headdim:
             out[..., rotary_dim:].copy_(x[..., rotary_dim:])
@@ -54,14 +60,20 @@ def backward(ctx, do):
         headdim = do.shape[-1]
         rotary_dim = re_cos.shape[-1]
         rotary_dim *= 2
+        do_ro = do[..., :rotary_dim]
+        do1, do2 = do_ro.chunk(2, dim=-1)
         dx = torch.empty_like(do)
-        ext.apply_rotary(
-            dx[..., :rotary_dim],
-            do[..., :rotary_dim],
+        dx_ro = dx[..., :rotary_dim]
+        dx1, dx2 = dx_ro.chunk(2, dim=-1)
+
+        apply_rotary(
+            do1,
+            do2,
             re_cos,
             re_sin,
+            dx1,
+            dx2,
             True,
-            ctx.interleaved,
         )
         if rotary_dim < headdim:
             dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
@@ -107,6 +119,7 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
             if len(qkv.shape) == 4
             else qkv[:, :, 0, :, :rotary_dim]
         )
+        q1, q2 = q_ro.chunk(2, dim=-1)
         re_cos = (
             rearrange(cos, "s d -> s 1 d")
             if len(qkv.shape) == 4
@@ -117,13 +130,14 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
             if len(qkv.shape) == 4
             else rearrange(sin[:seqlen], "s d -> s 1 d")
         )
-        ext.apply_rotary(
-            q_ro,
-            q_ro,
+        apply_rotary(
+            q1,
+            q2,
             re_cos,
             re_sin,
-            False,
-            interleaved,
+            q1,
+            q2,
+            False
         )
 
         k_ro = (
@@ -131,6 +145,7 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
             if len(qkv.shape) == 4
             else qkv[:, :, 1, :, :rotary_dim]
         )
+        k1, k2 = k_ro.chunk(2, dim=-1)
         re_cos_k = (
             rearrange(cos_k, "s d -> s 1 d")
             if len(qkv.shape) == 4
@@ -141,13 +156,14 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
             if len(qkv.shape) == 4
             else rearrange(sin_k[:seqlen], "s d -> s 1 d")
         )
-        ext.apply_rotary(
-            k_ro,
-            k_ro,
+        apply_rotary(
+            k1,
+            k2,
             re_cos_k,
             re_sin_k,
-            False,
-            interleaved,
+            k1,
+            k2,
+            False
         )
 
         ctx.save_for_backward(re_cos, re_sin, re_cos_k, re_sin_k)
@@ -165,13 +181,15 @@ def backward(ctx, dqkv):
             if len(dqkv.shape) == 4
             else dqkv[:, :, 0, :, :rotary_dim]
         )
-        ext.apply_rotary(
-            dq_ro,
-            dq_ro,
+        dq1, dq2 = dq_ro.chunk(2, dim=-1)
+        apply_rotary(
+            dq1,
+            dq2,
             re_cos,
             re_sin,
-            True,
-            ctx.interleaved,
+            dq1,
+            dq2,
+            True
         )
 
         dk_ro = (
@@ -179,12 +197,14 @@ def backward(ctx, dqkv):
             if len(dqkv.shape) == 4
             else dqkv[:, :, 1, :, :rotary_dim]
         )
-        ext.apply_rotary(
-            dk_ro,
-            dk_ro,
+        dk1, dk2 = dk_ro.chunk(2, dim=-1)
+        apply_rotary(
+            dk1,
+            dk2,
             re_cos_k,
             re_sin_k,
-            True,
-            ctx.interleaved,
+            dk1,
+            dk2,
+            True
         )
         return dqkv, None, None, None, None, None
diff --git a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_fallback.py b/deeplink_ext/ops/rotary_embedding/_rotary_embedding_fallback.py
similarity index 98%
rename from deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_fallback.py
rename to deeplink_ext/ops/rotary_embedding/_rotary_embedding_fallback.py
index 4d31c0c0..84cd5e30 100644
--- a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding_fallback.py
+++ b/deeplink_ext/ops/rotary_embedding/_rotary_embedding_fallback.py
@@ -5,7 +5,7 @@
 import torch
 from einops import rearrange
 
-__all__ = ["ApplyRotaryEmbTorch", "ApplyRotaryEmbQKV_Torch"]
+__all__ = ["ApplyRotaryEmbTorch", "ApplyRotaryEmbQKV_Torch","_torch_apply_rotary_func"]
 
 
 def _torch_apply_rotary_func(
diff --git a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py b/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py
deleted file mode 100644
index 764e0206..00000000
--- a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# Copyright (c) 2024, DeepLink.
-
-from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type
-
-platform_type = deeplink_ext_get_platform_type()
-if platform_type == PlatformType.TORCH_DIPU:
-    from .internevo_rotary_embedding_dipu import ApplyRotaryEmb
-else:
-    raise ImportError
-
-__all__ = ["ApplyRotaryEmb"]
diff --git a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_dipu.py b/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_dipu.py
deleted file mode 100644
index 7d5a4eb4..00000000
--- a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_dipu.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright (c) 2024, DeepLink.
- -import torch -from einops import rearrange -import deeplink_ext.cpp_extensions as ext - -assert hasattr(ext, "apply_rotary") - -__all__ = ["ApplyRotaryEmb"] - - -# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35 -class ApplyRotaryEmb(torch.autograd.Function): - """ - ApplyRotaryEmb - """ - - @staticmethod - def forward( - ctx, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False, - in_place: bool = False, - ): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. - """ - *_, seqlen, _, head_dim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - - assert rotary_dim <= head_dim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - - if in_place: - out = x - else: - out = torch.empty_like(x) - - ext.apply_rotary( - out[..., :rotary_dim], - x[..., :rotary_dim], - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - False, - interleaved, - ) - - if rotary_dim < head_dim and not in_place: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - ctx.in_place = in_place - - return out - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - *_, seqlen, _, head_dim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - - if ctx.in_place: - dx = do - else: - dx = torch.empty_like(do) - - ext.apply_rotary( - dx[..., :rotary_dim], - do[..., :rotary_dim], - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - True, - ctx.interleaved, - ) - - if rotary_dim < head_dim and not ctx.in_place: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - - return dx, None, None, None, None diff --git a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_fallback.py b/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_fallback.py deleted file mode 100644 index c8956025..00000000 --- a/deeplink_ext/ops/rotary_embedding/internevo_rotary_embedding_fallback.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) 2024, DeepLink. 
- -import torch -from einops import rearrange - -__all__ = ["ApplyRotaryEmbTorch"] - - -def _torch_apply_rotary_func( - x1: torch.Tensor, - x2: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - out1: torch.Tensor, - out2: torch.Tensor, - conj: bool = False, -): - assert ( - x1.device == x2.device == cos.device == sin.device - ), "All inputs must be on the same device" - assert ( - x1.dtype == x2.dtype == cos.dtype == sin.dtype - ), "All inputs must have the same dtype" - assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes" - assert cos.size() == sin.size(), "Input cos and sin must have the same sizes" - - x1, x2, cos, sin = x1.float(), x2.float(), cos.float(), sin.float() - - if conj: - out1.copy_(x1 * cos + x2 * sin) - out2.copy_(-x1 * sin + x2 * cos) - else: - out1.copy_(x1 * cos - x2 * sin) - out2.copy_(x1 * sin + x2 * cos) - - -# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35 -class ApplyRotaryEmbTorch(torch.autograd.Function): - """ - ApplyRotaryEmbTorch - """ - - @staticmethod - def forward( - ctx, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False, - in_place: bool = False, - ): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. - """ - *_, seqlen, _, head_dim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - - assert rotary_dim <= head_dim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - - x_ro = x[..., :rotary_dim] - x1, x2 = ( - (x_ro[..., ::2], x_ro[..., 1::2]) if interleaved else x_ro.chunk(2, dim=-1) - ) - - if in_place: - out, o1, o2 = x, x1, x2 - else: - out = torch.empty_like(x) - out_ro = out[..., :rotary_dim] - o1, o2 = ( - (out_ro[..., ::2], out_ro[..., 1::2]) - if interleaved - else out_ro.chunk(2, dim=-1) - ) - - _torch_apply_rotary_func( - x1, - x2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - o1, - o2, - False, - ) - - if rotary_dim < head_dim and not in_place: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - ctx.in_place = in_place - - return out - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - *_, seqlen, _, head_dim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - - do_ro = do[..., :rotary_dim] - do1, do2 = ( - (do_ro[..., ::2], do_ro[..., 1::2]) - if ctx.interleaved - else do_ro.chunk(2, dim=-1) - ) - - if ctx.in_place: - dx, dx1, dx2 = do, do1, do2 - else: - dx = torch.empty_like(do) - dx_ro = dx[..., :rotary_dim] - dx1, dx2 = ( - (dx_ro[..., ::2], dx_ro[..., 1::2]) - if ctx.interleaved - else dx_ro.chunk(2, dim=-1) - ) - - _torch_apply_rotary_func( - do1, - do2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - dx1, - dx2, - True, - ) - - if rotary_dim < head_dim and not ctx.in_place: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - - return dx, None, None, None, None diff --git a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py b/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py deleted file mode 100644 index 61f98c64..00000000 --- a/deeplink_ext/ops/rotary_embedding/interntrain_rotary_embedding.py +++ 
/dev/null
@@ -1,11 +0,0 @@
-# Copyright (c) 2024, DeepLink.
-
-from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type
-
-platform_type = deeplink_ext_get_platform_type()
-if platform_type == PlatformType.TORCH_DIPU:
-    from .interntrain_rotary_embedding_dipu import ApplyRotaryEmb, ApplyRotaryEmbQKV_
-else:
-    raise ImportError
-
-__all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_"]
diff --git a/deeplink_ext/ops/rotary_embedding/rotary_emb_utils.py b/deeplink_ext/ops/rotary_embedding/rotary_emb_utils.py
new file mode 100644
index 00000000..3fce626d
--- /dev/null
+++ b/deeplink_ext/ops/rotary_embedding/rotary_emb_utils.py
@@ -0,0 +1,32 @@
+_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
+
+def patch_rotary_emb_funcs(
+    apply_rotary
+):
+    try:
+        import rotary_emb
+    except Exception as e:
+        print("Unable to import rotary_emb, skip mocking rotary_emb")
+        return
+
+    rotary_emb.apply_rotary = apply_rotary
+
+def import_rotary_emb_funcs():
+    try:
+        from ._rotary_embedding_dipu import (
+            ApplyRotaryEmb,
+            ApplyRotaryEmbQKV_,
+            apply_rotary,
+        )
+    except Exception as e:
+        print(_not_impl.format(op_name="rotary embedding"))
+        from ._rotary_embedding_fallback import (
+            ApplyRotaryEmbTorch as ApplyRotaryEmb,
+            ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_,
+            _torch_apply_rotary_func as apply_rotary,
+        )
+    return (
+        ApplyRotaryEmb,
+        ApplyRotaryEmbQKV_,
+        apply_rotary,
+    )
From 998040b4e2a2222cf892eee37efa83b8675a01ec Mon Sep 17 00:00:00 2001
From: HuayiL <442488254@qq.com>
Date: Tue, 19 Nov 2024 18:00:33 +0800
Subject: [PATCH 7/7] fix a bug that missed patching rmsnorm

---
 deeplink_ext/ops/rms_norm/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/deeplink_ext/ops/rms_norm/__init__.py b/deeplink_ext/ops/rms_norm/__init__.py
index e6745aa0..b1db2ad8 100644
--- a/deeplink_ext/ops/rms_norm/__init__.py
+++ b/deeplink_ext/ops/rms_norm/__init__.py
@@ -11,3 +11,4 @@
 from .rms_norm_utils import import_RMSNorm, patch_RMSNorm
 
 MixedFusedRMSNorm = import_RMSNorm()
+patch_RMSNorm(MixedFusedRMSNorm)
\ No newline at end of file
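
A minimal usage sketch of the mocking flow set up by this series, assuming deeplink_ext plus the upstream flash_attn and rotary_emb packages are importable; the imports below only mirror the files above and are illustrative, not part of any patch:

# Illustrative only: exercising the patched entry points end to end.
import deeplink_ext.internevo_ops  # pulls in deeplink_ext.ops.*, which run the patch_* helpers

import flash_attn.modules.mha as mha
import rotary_emb

# After the mocks are applied, the upstream symbols resolve to the DeepLink
# implementations (or to the torch fallbacks when the diopi kernels are missing).
print(mha.FlashSelfAttention)
print(rotary_emb.apply_rotary)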