reimpl rotary embedding for npu
jingguo-st committed Sep 10, 2024
1 parent ad193f8 commit cf21627
Showing 2 changed files with 39 additions and 298 deletions.
117 changes: 20 additions & 97 deletions deeplink_ext/internevo_ops/_rotary_embedding_npu.py
@@ -2,110 +2,33 @@
 
 import torch
 import torch_npu
-from einops import rearrange
 
-__all__ = ["ApplyRotaryEmb"]
+__all__ = ["RotaryEmbedding"]
 
 
-def _unsqueeze_to_4d(x: torch.Tensor):
-    while x.dim() < 4:
-        x = x.unsqueeze(0)
-    return x
-
-
-def _apply_rotary(x: torch.Tensor, cos, sin, confj, interleaved):
-    assert interleaved == False, "interleaved not support by torch_npu"
-
-    x_view = _unsqueeze_to_4d(x)
-    cos_view = _unsqueeze_to_4d(cos)
-    sin_view = _unsqueeze_to_4d(sin)
-
-    cos_cat = torch.cat([cos_view, cos_view], -1)
-    sin_cat = torch.cat([sin_view, sin_view], -1)
-
-    if confj:
-        sin_cat.neg_()
-
-    x_view_chunks = x_view.chunk(2, -1)
-    x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1)
-
-    cos_x = torch.mul(cos_cat, x_view)
-    sin_x = torch.mul(sin_cat, x_view_new)
-    out = cos_x + sin_x
-
-    return out
-
-
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35
-class ApplyRotaryEmb(torch.autograd.Function):
+class RotaryEmbedding(torch.autograd.Function):
     """
-    ApplyRotaryEmb
+    Apply rotary positional embedding to input tensor x.
+    Args:
+        x (Tensor): Input tensor x is of shape [seq_length, ... , dim]
+        cos (Tensor): Input tensor cos is of shape [seq_length, ..., dim]
+        sin (Tensor): Input tensor sin is of shape [seq_length, ..., dim]
+    Returns:
+        Tensor: The input tensor after applying RoPE
     """
 
     @staticmethod
-    def forward(
-        ctx,
-        x: torch.Tensor,
-        cos: torch.Tensor,
-        sin: torch.Tensor,
-        interleaved: bool = False,
-        in_place: bool = False,
-    ):
-        """
-        x: (batch_size, seqlen, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding to the first rotary_dim of x.
-        """
-        *_, seqlen, _, head_dim = x.shape
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-
-        assert rotary_dim <= head_dim
-        assert seqlen <= rotary_seqlen
-        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
-
-        out = _apply_rotary(
-            x[..., :rotary_dim],
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            False,
-            interleaved,
-        )
-
-        ctx.save_for_backward(cos, sin)
-        ctx.interleaved = interleaved
-        ctx.in_place = in_place
-
-        if in_place:
-            x[..., :rotary_dim].copy_(out[..., :rotary_dim])
-            return x
-        else:
-            if rotary_dim < head_dim:
-                out[..., rotary_dim:].copy_(x[..., rotary_dim:])
-            return out
+    def forward(ctx, x, cos, sin):
+        out = torch_npu.npu_rotary_mul(x, cos, sin)
+        ctx.save_for_backward(out, cos, sin)
+        return out
 
     @staticmethod
-    def backward(ctx, do):
-        cos, sin = ctx.saved_tensors
-        *_, seqlen, _, head_dim = do.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
-
-        out = _apply_rotary(
-            do[..., :rotary_dim],
-            rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"),
-            True,
-            ctx.interleaved,
-        )
-
-        if ctx.in_place:
-            do[..., :rotary_dim].copy_(out[..., :rotary_dim])
-            return do, None, None, None, None
-        else:
-            if rotary_dim < head_dim:
-                out[..., rotary_dim:].copy(do[..., rotary_dim:])
-            return out, None, None, None, None
+    def backward(ctx, grad_output):
+        out, cos, sin = ctx.saved_tensors
+        return (
+            torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0],
+            None,
+            None,
+        )
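
A minimal usage sketch of the reimplemented autograd Function above (not part of the commit), assuming an Ascend host with torch_npu available and cos/sin already expanded to a 4-D layout that torch_npu.npu_rotary_mul can broadcast against x; the shapes below are illustrative, not taken from the diff:

import torch
import torch_npu  # registers the NPU backend and the fused rotary kernels

from deeplink_ext.internevo_ops._rotary_embedding_npu import RotaryEmbedding

# Illustrative shapes only: (batch, seqlen, nheads, headdim), with cos/sin
# pre-expanded so the fused op can broadcast them over batch and heads.
x = torch.randn(2, 128, 8, 64).npu().requires_grad_(True)
cos = torch.randn(1, 128, 1, 64).npu()
sin = torch.randn(1, 128, 1, 64).npu()

out = RotaryEmbedding.apply(x, cos, sin)  # forward -> torch_npu.npu_rotary_mul
out.sum().backward()                      # backward -> torch_npu.npu_rotary_mul_backward
print(x.grad.shape)                       # gradient flows to x only; cos and sin get None
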
220 changes: 19 additions & 201 deletions deeplink_ext/interntrain_ops/_rotary_embedding_npu.py
@@ -1,216 +1,34 @@
 # Copyright (c) 2024, DeepLink.
 # Copyright (c) 2024, InternEvo.
 
 import torch
 import torch_npu
-from einops import rearrange
 
-__all__ = ["ApplyRotaryEmb", "ApplyRotaryEmbQKV_"]
+__all__ = ["RotaryEmbedding"]
 
 
-def _unsqueeze_to_4d(x: torch.Tensor):
-    while x.dim() < 4:
-        x = x.unsqueeze(0)
-    return x
-
-
-def _apply_rotary(x: torch.Tensor, cos, sin, confj, interleaved):
-    assert interleaved == False, "interleaved not support by torch_npu"
-
-    x_view = _unsqueeze_to_4d(x)
-    cos_view = _unsqueeze_to_4d(cos)
-    sin_view = _unsqueeze_to_4d(sin)
-
-    cos_cat = torch.cat([cos_view, cos_view], -1)
-    sin_cat = torch.cat([sin_view, sin_view], -1)
-
-    if confj:
-        sin_cat.neg_()
-
-    x_view_chunks = x_view.chunk(2, -1)
-    x_view_new = torch.cat([-x_view_chunks[1], x_view_chunks[0]], -1)
-
-    cos_x = torch.mul(cos_cat, x_view)
-    sin_x = torch.mul(sin_cat, x_view_new)
-    out = cos_x + sin_x
-
-    return out
-
-
-class ApplyRotaryEmb(torch.autograd.Function):
+class RotaryEmbedding(torch.autograd.Function):
     """
-    ApplyRotaryEmb
+    Apply rotary positional embedding to input tensor x.
+    Args:
+        x (Tensor): Input tensor x is of shape [seq_length, ... , dim]
+        cos (Tensor): Input tensor cos is of shape [seq_length, ..., dim]
+        sin (Tensor): Input tensor sin is of shape [seq_length, ..., dim]
+    Returns:
+        Tensor: The input tensor after applying RoPE
     """
 
     @staticmethod
-    def forward(ctx, x, cos, sin, interleaved=False):
-        """
-        x: (batch_size, seqlen, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding to the first rotary_dim of x.
-        """
-        _, seqlen, _, headdim = x.shape
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        assert seqlen <= rotary_seqlen
-        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
-        out = torch.empty_like(x)
-        re_cos = rearrange(cos[:seqlen], "s d -> s 1 d")
-        re_sin = rearrange(sin[:seqlen], "s d -> s 1 d")
-        out = _apply_rotary(
-            x[..., :rotary_dim],
-            re_cos,
-            re_sin,
-            False,
-            interleaved,
-        )
-        if rotary_dim < headdim:
-            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
-        ctx.save_for_backward(re_cos, re_sin)
-        ctx.interleaved = interleaved
+    def forward(ctx, x, cos, sin):
+        out = torch_npu.npu_rotary_mul(x, cos, sin)
+        ctx.save_for_backward(out, cos, sin)
         return out
 
     @staticmethod
-    def backward(ctx, do):
-        re_cos, re_sin = ctx.saved_tensors
-        headdim = do.shape[-1]
-        rotary_dim = re_cos.shape[-1]
-        rotary_dim *= 2
-        dx = _apply_rotary(
-            do[..., :rotary_dim],
-            re_cos,
-            re_sin,
-            True,
-            ctx.interleaved,
-        )
-        if rotary_dim < headdim:
-            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
-        return dx, None, None, None
-
-
-class ApplyRotaryEmbQKV_(torch.autograd.Function):
-    """
-    ApplyRotaryEmbQKV_
-    """
-
-    @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
-        """
-        qkv: (total, 3, nheads, headdim) / (batch_size, seqlen, 3, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
-            1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
-        """
-        # len(qkv.shape) == 4 means the format of qkv is (total, 3, nheads, headdim) which is packed,
-        # otherwise the format of qkv is (batch_size, seqlen, 3, nheads, headdim) which is unpacked.
-        # We handle both packed qkv and unpacked qkv scenario in this class.
-        three = qkv.shape[1] if len(qkv.shape) == 4 else qkv.shape[2]
-        assert three == 3
-        seqlen = None if len(qkv.shape) == 4 else qkv.shape[1]
-        rotary_seqlen, rotary_dim = cos.shape
-        if len(qkv.shape) != 4:
-            assert seqlen <= rotary_seqlen
-        headdim = qkv.shape[-1]
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        cos_k = cos if cos_k is None else cos_k
-        sin_k = sin if sin_k is None else sin_k
-        assert (
-            sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
-        )
-
-        q_ro = (
-            qkv[:, 0, :, :rotary_dim]
-            if len(qkv.shape) == 4
-            else qkv[:, :, 0, :, :rotary_dim]
-        )
-        re_cos = (
-            rearrange(cos, "s d -> s 1 d")
-            if len(qkv.shape) == 4
-            else rearrange(cos[:seqlen], "s d -> s 1 d")
-        )
-        re_sin = (
-            rearrange(sin, "s d -> s 1 d")
-            if len(qkv.shape) == 4
-            else rearrange(sin[:seqlen], "s d -> s 1 d")
-        )
-
-        # qro
-        out = _apply_rotary(
-            q_ro,
-            re_cos,
-            re_sin,
-            False,
-            interleaved,
-        )
-        q_ro.copy_(out)
-
-        k_ro = (
-            qkv[:, 1, :, :rotary_dim]
-            if len(qkv.shape) == 4
-            else qkv[:, :, 1, :, :rotary_dim]
-        )
-        re_cos_k = (
-            rearrange(cos_k, "s d -> s 1 d")
-            if len(qkv.shape) == 4
-            else rearrange(cos_k[:seqlen], "s d -> s 1 d")
-        )
-        re_sin_k = (
-            rearrange(sin_k, "s d -> s 1 d")
-            if len(qkv.shape) == 4
-            else rearrange(sin_k[:seqlen], "s d -> s 1 d")
-        )
-        out = _apply_rotary(
-            k_ro,
-            re_cos_k,
-            re_sin_k,
-            False,
-            interleaved,
-        )
-        k_ro.copy_(out)
-
-        ctx.save_for_backward(re_cos, re_sin, re_cos_k, re_sin_k)
-        ctx.interleaved = interleaved
-        return qkv
-
-    @staticmethod
-    def backward(ctx, dqkv):
-        re_cos, re_sin, re_cos_k, re_sin_k = ctx.saved_tensors
-        rotary_dim = re_cos.shape[-1]
-        rotary_dim *= 2
-
-        dq_ro = (
-            dqkv[:, 0, :, :rotary_dim]
-            if len(dqkv.shape) == 4
-            else dqkv[:, :, 0, :, :rotary_dim]
-        )
-        out = _apply_rotary(
-            dq_ro,
-            re_cos,
-            re_sin,
-            True,
-            ctx.interleaved,
-        )
-        dq_ro.copy_(out)
-
-        dk_ro = (
-            dqkv[:, 1, :, :rotary_dim]
-            if len(dqkv.shape) == 4
-            else dqkv[:, :, 1, :, :rotary_dim]
-        )
-        out = _apply_rotary(
-            dk_ro,
-            re_cos_k,
-            re_sin_k,
-            True,
-            ctx.interleaved,
-        )
-        dk_ro.copy_(out)
-        return dqkv, None, None, None, None, None
+    def backward(ctx, grad_output):
+        out, cos, sin = ctx.saved_tensors
+        return (
+            torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0],
+            None,
+            None,
+        )
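
For context, the fused call replaces the hand-written rotate-half math deleted above (out = x * [cos, cos] + rotate_half(x) * [sin, sin]). Below is a minimal parity-check sketch (not part of the commit), assuming npu_rotary_mul consumes full-width cos/sin tables, i.e. the half-width tables duplicated along the last dimension as the deleted _apply_rotary did internally, and that these tolerances suit the device dtype:

import torch
import torch_npu  # provides tensor.npu() and the fused rotary kernels


def rotate_half_reference(x, cos_half, sin_half):
    # Same math as the deleted _apply_rotary(..., confj=False, interleaved=False).
    cos = torch.cat([cos_half, cos_half], -1)
    sin = torch.cat([sin_half, sin_half], -1)
    x1, x2 = x.chunk(2, -1)
    return x * cos + torch.cat([-x2, x1], -1) * sin


x = torch.randn(2, 16, 4, 32)         # (batch, seqlen, nheads, headdim)
cos_half = torch.randn(1, 16, 1, 16)  # half-width tables, broadcast over batch/heads
sin_half = torch.randn(1, 16, 1, 16)

ref = rotate_half_reference(x, cos_half, sin_half)
fused = torch_npu.npu_rotary_mul(
    x.npu(),
    torch.cat([cos_half, cos_half], -1).npu(),  # assumed full-width layout
    torch.cat([sin_half, sin_half], -1).npu(),
)
print(torch.allclose(ref, fused.cpu(), atol=1e-4))
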
