Skip to content

Commit

Permalink
ascend speed rotary
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangzefeng92 committed Mar 29, 2024
1 parent 50a29d2 commit 29b9d8e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 28 deletions.
Empty file.
31 changes: 31 additions & 0 deletions deeplink_ext/ascend_speed/rotary/deeplink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
from typing import Optional, Union
import deeplink_ext.cpp_extensions as ext


def apply_rotary_for_ascend_speed(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
output = torch.empty_like(x)
ext.apply_rotary(output, x, cos, sin, conjugate, interleaved)
return output


class RotaryEmbedding_AscendSpeed(torch.autograd.Function):
@staticmethod
def forward(ctx, t, cos, sin):
ctx.save_for_backward(cos, sin)
return apply_rotary_for_ascend_speed(t, cos, sin)

@staticmethod
def backward(ctx, t):
cos, sin = ctx.saved_tensors
return apply_rotary_for_ascend_speed(t, cos, sin, conjugate=True), None, None
28 changes: 0 additions & 28 deletions deeplink_ext/internlm_ops/rotary/deeplink.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,31 +63,3 @@ def apply_rotary(
interleaved,
)
return output


def apply_rotary_for_ascend_speed(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
output = torch.empty_like(x)
ext.apply_rotary(output, x, cos, sin, conjugate, interleaved)
return output


class RotaryEmbedding_AscendSpeed(torch.autograd.Function):
@staticmethod
def forward(ctx, t, cos, sin):
ctx.save_for_backward(cos, sin)
return apply_rotary_for_ascend_speed(t, cos, sin)

@staticmethod
def backward(ctx, t):
cos, sin = ctx.saved_tensors
return apply_rotary_for_ascend_speed(t, cos, sin, conjugate=True), None, None

0 comments on commit 29b9d8e

Please sign in to comment.