Commit

run clang format
Wrench-Git committed Nov 5, 2024
1 parent 0b6718c commit 8beeed6
Showing 13 changed files with 20 additions and 35 deletions.
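The edits follow two mechanical patterns: the blank line directly after each class header (e.g. class RMSNorm(torch.autograd.Function):) is removed, and import lines that exceed the line-length limit are wrapped into parenthesized multi-line form; the only other change is an added blank line before __all__ in bert_padding/__init__.py. A minimal sketch of the resulting style, using an illustrative Example class and placeholder import names rather than code from this repository:

import torch

# Over-long imports are wrapped like this (placeholder names, commented out so
# the sketch stays runnable):
# from some_long_module_name import (
#     SomeLongClassName as PublicName,
# )


class Example(torch.autograd.Function):
    # No blank line between the class header and the first decorator.
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output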
1 change: 0 additions & 1 deletion deeplink_ext/ascend_speed/_flash_attention_dipu.py
@@ -9,7 +9,6 @@


class FlashSelfAttention(torch.autograd.Function):

@staticmethod
def forward(
ctx, q, k, v, attention_mask, dropout_p, softmax_scale, head_num, input_layout
1 change: 0 additions & 1 deletion deeplink_ext/ascend_speed/_rms_norm_dipu.py
@@ -9,7 +9,6 @@


class RMSNorm(torch.autograd.Function):

@staticmethod
def forward(ctx, hidden_states, weight, eps):
output = torch.empty_like(hidden_states)
1 change: 0 additions & 1 deletion deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py
@@ -11,7 +11,6 @@


class ScaledMaskedSoftmax(torch.autograd.Function):

@staticmethod
def forward(ctx, input, mask, scale, fixed_triu_mask):
out = torch.empty_like(input)
1 change: 0 additions & 1 deletion deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py
@@ -7,7 +7,6 @@


class ScaledMaskedSoftmax(torch.autograd.Function):

@staticmethod
def forward(ctx, input, mask, scale, fixed_triu_mask):
out = torch_npu.npu_scaled_masked_softmax(input, mask, scale, fixed_triu_mask)
1 change: 0 additions & 1 deletion deeplink_ext/ops/adamw/_adamw_dipu.py
@@ -60,7 +60,6 @@ def fused_adamw(


class AdamW(Optimizer):

def __init__(
self,
params,
1 change: 1 addition & 0 deletions deeplink_ext/ops/bert_padding/__init__.py
@@ -1,2 +1,3 @@
from .bert_padding import pad_input, unpad_input, index_first_axis

__all__ = ["pad_input", "unpad_input", "index_first_axis"]
8 changes: 6 additions & 2 deletions deeplink_ext/ops/flash_attention/__init__.py
@@ -24,5 +24,9 @@
from .interntrain_flash_attention import FlashSelfAttention, FlashCrossAttention
except Exception as e:
print(_not_impl.format(op_name="flash attention"))
from .interntrain_flash_attention_fallback import SelfAttention as FlashSelfAttention
from .interntrain_flash_attention_fallback import CrossAttention as FlashCrossAttention
from .interntrain_flash_attention_fallback import (
SelfAttention as FlashSelfAttention,
)
from .interntrain_flash_attention_fallback import (
CrossAttention as FlashCrossAttention,
)
12 changes: 0 additions & 12 deletions deeplink_ext/ops/flash_attention/internevo_flash_attention_dipu.py
@@ -22,7 +22,6 @@


class FlashAttnQKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -108,7 +107,6 @@ def backward(ctx, dout, *args):


class CustomizedFlashAttnQKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -254,7 +252,6 @@ def flash_attn_qkvpacked_func(


class FlashAttnKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -344,7 +341,6 @@ def backward(ctx, dout, *args):


class CustomizedFlashAttnKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -498,7 +494,6 @@ def flash_attn_kvpacked_func(


class FlashAttnFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -592,7 +587,6 @@ def backward(ctx, dout, *args):


class CustomizedFlashAttnFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -753,7 +747,6 @@ def flash_attn_func(


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -850,7 +843,6 @@ def backward(ctx, dout, *args):


class CustomizedFlashAttnVarlenQKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -1007,7 +999,6 @@ def flash_attn_varlen_qkvpacked_func(


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -1112,7 +1103,6 @@ def backward(ctx, dout, *args):


class CustomizedFlashAttnVarlenKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -1288,7 +1278,6 @@ def flash_attn_varlen_kvpacked_func(


class FlashAttnVarlenFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -1414,7 +1403,6 @@ def backward(ctx, dout, *args):


class CustomizedFlashAttnVarlenFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -4,7 +4,10 @@

platform_type = deeplink_ext_get_platform_type()
if platform_type == PlatformType.TORCH_DIPU:
from .interntrain_flash_attention_dipu import FlashSelfAttention, FlashCrossAttention
from .interntrain_flash_attention_dipu import (
FlashSelfAttention,
FlashCrossAttention,
)
else:
raise ImportError

@@ -16,7 +16,6 @@


class CustomizedFlashAttentionQKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -206,7 +205,6 @@ def backward(ctx, dout):


class FlashAttentionQKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -359,7 +357,6 @@ def backward(ctx, dout):


class CustomizedFlashAttentionVarlenQKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -560,7 +557,6 @@ def backward(ctx, dout):


class FlashAttentionVarlenQKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -738,7 +734,6 @@ def backward(ctx, dout):


class CustomizedFlashAttentionKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(ctx, q, kv, dropout_p, softmax_scale, causal):
assert q.device == kv.device, "the devices of q and kv should be same"
@@ -842,7 +837,6 @@ def backward(ctx, dout):


class FlashAttentionKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(ctx, q, kv, dropout_p, softmax_scale, causal):
assert q.device == kv.device, "the devices of q and kv should be same"
@@ -920,7 +914,6 @@ def backward(ctx, dout):


class CustomizedFlashAttentionVarlenKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -1045,7 +1038,6 @@ def backward(ctx, dout):


class FlashAttentionVarlenKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
@@ -1266,7 +1258,6 @@ def forward(


class FlashCrossAttention(nn.Module):

def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.causal = causal
1 change: 0 additions & 1 deletion deeplink_ext/ops/rms_norm/easyllm_rms_norm_dipu.py
@@ -9,7 +9,6 @@


class RMSNorm(torch.autograd.Function):

@staticmethod
def forward(ctx, hidden_states, weight, eps):
output = torch.empty_like(hidden_states)
2 changes: 0 additions & 2 deletions deeplink_ext/ops/rms_norm/internevo_mixed_rms_norm_dipu.py
@@ -14,7 +14,6 @@
# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp"
class _MixedFusedRMSNormFunction(torch.autograd.Function):

@staticmethod
def forward(ctx, hidden_states, weight, eps, normalized_shape):
# ascend currently does not support dtype of hidden_states with higher precision than weight.
@@ -94,7 +93,6 @@ def backward(ctx, grad_output):


class MixedFusedRMSNorm(torch.nn.Module):

def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False):
# TODO: Further optimization when there are device and dtype available.
# factory_kwargs = {"device": device, "dtype": dtype}
12 changes: 9 additions & 3 deletions deeplink_ext/ops/rotary_embedding/__init__.py
@@ -4,11 +4,17 @@
from .internevo_rotary_embedding import ApplyRotaryEmb
except:
print(_not_impl.format(op_name="rotary embedding"))
from .internevo_rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
from .internevo_rotary_embedding_fallback import (
ApplyRotaryEmbTorch as ApplyRotaryEmb,
)

try:
from .interntrain_rotary_embedding import ApplyRotaryEmb, ApplyRotaryEmbQKV_
except:
print(_not_impl.format(op_name="rotary embedding"))
from .interntrain_rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
from .interntrain_rotary_embedding_fallback import ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_
from .interntrain_rotary_embedding_fallback import (
ApplyRotaryEmbTorch as ApplyRotaryEmb,
)
from .interntrain_rotary_embedding_fallback import (
ApplyRotaryEmbQKV_Torch as ApplyRotaryEmbQKV_,
)
