feat(rotary,internlm): add DeepLinkApplyRotaryEmb (fallback unimpleme… #44

Merged · 1 commit · Feb 1, 2024
7 changes: 5 additions & 2 deletions deeplink_ext/internlm_ops/rotary/__init__.py
@@ -1,11 +1,14 @@
 # Copyright (c) 2024, DeepLink.
 
 try:
-    from .deeplink import DeepLinkApplyRotaryEmbQKV_
+    from .deeplink import DeepLinkApplyRotaryEmbQKV_, DeepLinkApplyRotaryEmb
 except:
     print(
         "[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
         end="",
     )
-    from .fallback import ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_
+    from .fallback import (
+        ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
+        ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
+    )
     from . import fallback
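
Note (not part of the diff): downstream code is expected to import these names from the package rather than from .deeplink or .fallback directly, so the choice between the DIOPI kernel and the slower Python fallback is made once, at import time. A minimal consumer-side sketch, with the import path taken from the file layout above:

    from deeplink_ext.internlm_ops import rotary

    # Resolves to the DIOPI-backed class, or to fallback.ApplyRotaryEmb if the op is missing in diopi
    RotaryEmb = rotary.DeepLinkApplyRotaryEmb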
60 changes: 60 additions & 0 deletions deeplink_ext/internlm_ops/rotary/deeplink.py
@@ -69,3 +69,63 @@ def backward(ctx, dqkv):
             interleaved,
         )
         return dqkv, None, None, None, None, None
+
+
+class DeepLinkApplyRotaryEmb(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
+        """
+        x: (batch_size, seqlen, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        rotary_dim must be <= headdim
+        Apply rotary embedding to the first rotary_dim of x.
+        """
+        batch, seqlen, nheads, headdim = x.shape
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        x_ro = x[..., :rotary_dim]
+        out = torch.empty_like(x) if not inplace else x
+        out_ro = out[..., :rotary_dim]
+
+        ext.apply_rotary(
+            out_ro,
+            x_ro,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            False,
+            interleaved,
+        )
+
+        if not inplace and rotary_dim < headdim:
+            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
+        ctx.save_for_backward(cos, sin)
+        ctx.interleaved = interleaved
+        ctx.inplace = inplace
+        return out if not inplace else x
+
+    @staticmethod
+    def backward(ctx, do):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, headdim = do.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
+        inplace = ctx.inplace
+        do_ro = do[..., :rotary_dim]
+        dx = torch.empty_like(do) if not inplace else do
+        dx_ro = dx[..., :rotary_dim]
+        ext.apply_rotary(
+            dx_ro,
+            do_ro,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            True,
+            ctx.interleaved,
+        )
+        if not inplace and rotary_dim < headdim:
+            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
+        return dx, None, None, None, None
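
A minimal usage sketch of the new Function (shapes, dtype, and device below are illustrative assumptions, not taken from this PR); like the QKV_ variant above, it is driven through .apply and participates in autograd, with backward applying the inverse rotation to the incoming gradient:

    import torch
    from deeplink_ext.internlm_ops import rotary

    # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim / 2), rotary_dim <= headdim
    x = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    cos = torch.randn(128, 16, device="cuda", dtype=torch.float16)
    sin = torch.randn(128, 16, device="cuda", dtype=torch.float16)

    out = rotary.DeepLinkApplyRotaryEmb.apply(x, cos, sin, False, False)  # interleaved=False, inplace=False
    out.sum().backward()
    assert x.grad.shape == x.shape  # only the first rotary_dim of headdim is rotated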
2 changes: 1 addition & 1 deletion deeplink_ext/internlm_ops/rotary/fallback/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) 2024, DeepLink.
 
-from .fallback import ApplyRotaryEmbQKV_
+from .fallback import ApplyRotaryEmbQKV_, ApplyRotaryEmb
6 changes: 6 additions & 0 deletions deeplink_ext/internlm_ops/rotary/fallback/fallback.py
@@ -113,3 +113,9 @@ def backward(ctx, dqkv):
         )
         dqkv[:, :, 1, :, :rotary_dim] = torch.cat((dk1, dk2), dim=-1)
         return dqkv, None, None, None, None, None
+
+
+class ApplyRotaryEmb:
+    @staticmethod
+    def apply(*args, **kwargs):
+        raise NotImplementedError("fallback.ApplyRotaryEmb is not implemented yet")
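
The fallback is deliberately left as a stub in this change (hence "fallback unimpleme…" in the PR title). For orientation only, a pure-PyTorch sketch of what a non-interleaved (GPT-NeoX style) forward rotation could look like; this is an assumption about a possible future fallback, not code from this PR:

    import torch

    def apply_rotary_emb_torch(x, cos, sin):
        # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim / 2)
        seqlen = x.shape[1]
        rotary_dim = cos.shape[-1] * 2
        x_ro, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        x1, x2 = x_ro.chunk(2, dim=-1)
        c = cos[:seqlen].unsqueeze(1)  # (seqlen, 1, rotary_dim / 2), broadcasts over batch and nheads
        s = sin[:seqlen].unsqueeze(1)
        out_ro = torch.cat((x1 * c - x2 * s, x1 * s + x2 * c), dim=-1)
        return torch.cat((out_ro, x_pass), dim=-1)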
7 changes: 1 addition & 6 deletions deeplink_ext/patch_internlm.py
@@ -48,11 +48,6 @@ def CrossEntropyLossProxy(reduction, **_):
     def _patch_ops():
         import internlm.model.embedding  # type: ignore
 
-        def NotImplementedLegacyRotaryEmb(*args, **kwargs):
-            raise NotImplementedError(
-                "we assume that legacy_apply_rotary_embed is not used in internlm"
-            )
-
         class NonLegacyRotaryEmbQKV_(torch.autograd.Function):
             """the first 2 dims of qkv has been squeezed"""

@@ -74,7 +69,7 @@ def backward(ctx, dqkv: torch.Tensor, *args, **kwargs):

         internlm.model.embedding.apply_rotary_emb_qkv_ = NonLegacyRotaryEmbQKV_.apply
         internlm.model.embedding.legacy_apply_rotary_embed = (
-            NotImplementedLegacyRotaryEmb
+            ext.rotary.DeepLinkApplyRotaryEmb.apply
         )
         internlm.model.embedding.legacy_apply_rotary_embed_qkv = (
             ext.rotary.DeepLinkApplyRotaryEmbQKV_.apply
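
With the NotImplementedError proxy removed above, InternLM code paths that still call the legacy rotary helper now reach the DeepLink kernel (or its fallback) instead of raising. A rough sketch of the patched call path, assuming the patch has already been applied and with illustrative shapes:

    import torch
    import internlm.model.embedding as embedding  # after deeplink_ext's patching has run

    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    cos = torch.randn(128, 16, device="cuda", dtype=torch.float16)
    sin = torch.randn(128, 16, device="cuda", dtype=torch.float16)

    # Routed to ext.rotary.DeepLinkApplyRotaryEmb.apply by the assignment in this hunk
    q_rot = embedding.legacy_apply_rotary_embed(q, cos, sin)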