From 29b9d8e2b2f16548e830f16234040d5d495c098e Mon Sep 17 00:00:00 2001 From: zhangzefeng92 Date: Fri, 29 Mar 2024 17:55:58 +0800 Subject: [PATCH] ascend speed rotary --- deeplink_ext/ascend_speed/__init__.py | 0 deeplink_ext/ascend_speed/rotary/deeplink.py | 31 ++++++++++++++++++++ deeplink_ext/internlm_ops/rotary/deeplink.py | 28 ------------------ 3 files changed, 31 insertions(+), 28 deletions(-) create mode 100644 deeplink_ext/ascend_speed/__init__.py create mode 100644 deeplink_ext/ascend_speed/rotary/deeplink.py diff --git a/deeplink_ext/ascend_speed/__init__.py b/deeplink_ext/ascend_speed/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/deeplink_ext/ascend_speed/rotary/deeplink.py b/deeplink_ext/ascend_speed/rotary/deeplink.py new file mode 100644 index 00000000..711562ec --- /dev/null +++ b/deeplink_ext/ascend_speed/rotary/deeplink.py @@ -0,0 +1,31 @@ +import torch +from typing import Optional, Union +import deeplink_ext.cpp_extensions as ext + + +def apply_rotary_for_ascend_speed( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + output = torch.empty_like(x) + ext.apply_rotary(output, x, cos, sin, conjugate, interleaved) + return output + + +class RotaryEmbedding_AscendSpeed(torch.autograd.Function): + @staticmethod + def forward(ctx, t, cos, sin): + ctx.save_for_backward(cos, sin) + return apply_rotary_for_ascend_speed(t, cos, sin) + + @staticmethod + def backward(ctx, t): + cos, sin = ctx.saved_tensors + return apply_rotary_for_ascend_speed(t, cos, sin, conjugate=True), None, None diff --git a/deeplink_ext/internlm_ops/rotary/deeplink.py b/deeplink_ext/internlm_ops/rotary/deeplink.py index d1e593d9..670a47b9 100644 --- a/deeplink_ext/internlm_ops/rotary/deeplink.py +++ b/deeplink_ext/internlm_ops/rotary/deeplink.py @@ -63,31 +63,3 @@ def apply_rotary( interleaved, ) return output - - -def apply_rotary_for_ascend_speed( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - interleaved=False, - inplace=False, - conjugate=False, -) -> torch.Tensor: - output = torch.empty_like(x) - ext.apply_rotary(output, x, cos, sin, conjugate, interleaved) - return output - - -class RotaryEmbedding_AscendSpeed(torch.autograd.Function): - @staticmethod - def forward(ctx, t, cos, sin): - ctx.save_for_backward(cos, sin) - return apply_rotary_for_ascend_speed(t, cos, sin) - - @staticmethod - def backward(ctx, t): - cos, sin = ctx.saved_tensors - return apply_rotary_for_ascend_speed(t, cos, sin, conjugate=True), None, None