feat(internlm): implement patch for non-legacy rotary_emb (#38)
lljbash authored Jan 30, 2024
1 parent 9a80a19 commit b7be55d
Showing 6 changed files with 34 additions and 191 deletions.
5 changes: 3 additions & 2 deletions csrc/extensions.cpp
@@ -72,7 +72,7 @@ auto extRmsNormBackward(const at::Tensor& input, const at::Tensor& grad_output,

void extApplyRotary(at::Tensor output, const at::Tensor& input,
const at::Tensor& cos, const at::Tensor& sin,
const bool conj, const bool interleaved = false) {
const bool conj, const bool interleaved) {
callDiopi(diopiRotaryEmbedding, output, input, cos, sin, conj, interleaved);
}

@@ -239,7 +239,8 @@ void extRotaryEmb(at::Tensor& q, const at::Tensor& cos, const at::Tensor& sin) {
auto dim = q.size(-1);
auto cos_view = cos.view({seq_len, 1, dim / 2});
auto sin_view = sin.view({seq_len, 1, dim / 2});
callDiopi(diopiRotaryEmbedding, q, q, cos_view, sin_view, false, false);
callDiopi(diopiRotaryEmbedding, q, q, cos_view, sin_view, /*conj=*/false,
/*interleaved=*/false);
}

// Check whether a corresponding diopi implementation exists:
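For reference, the view that extRotaryEmb applies to cos and sin corresponds to the following PyTorch-side reshape. This is a minimal illustrative sketch; the sizes and the layout of q are assumptions, not taken from an actual call site.

import torch

# Hypothetical sizes, for illustration only.
seq_len, nheads, dim = 16, 8, 64

q = torch.randn(seq_len, nheads, dim)   # assumed layout of the tensor handed to extRotaryEmb
cos = torch.randn(seq_len, dim // 2)    # rotary table: one entry per pair of head dims
sin = torch.randn(seq_len, dim // 2)

# Insert a singleton head axis so cos/sin broadcast over nheads,
# mirroring cos.view({seq_len, 1, dim / 2}) in the C++ above.
cos_view = cos.view(seq_len, 1, dim // 2)
sin_view = sin.view(seq_len, 1, dim // 2)
assert cos_view.shape == (seq_len, 1, dim // 2)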
7 changes: 2 additions & 5 deletions deeplink_ext/internlm_ops/rotary/__init__.py
@@ -1,14 +1,11 @@
# Copyright (c) 2024, DeepLink.

try:
from .deeplink import DeepLinkApplyRotaryEmb, DeepLinkApplyRotaryEmbQKV_
from .deeplink import DeepLinkApplyRotaryEmbQKV_
except:
print(
"[deeplink_ext] rotary is not implemented in diopi. Falling back to the slower implementation.\n",
end="",
)
from .fallback import (
ApplyRotaryEmb as DeepLinkApplyRotaryEmb,
ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_,
)
from .fallback import ApplyRotaryEmbQKV_ as DeepLinkApplyRotaryEmbQKV_
from . import fallback
107 changes: 4 additions & 103 deletions deeplink_ext/internlm_ops/rotary/deeplink.py
@@ -28,7 +28,7 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
False,
interleaved,
)
k_ro = qkv[:, :, 1, :, :rotary_dim]
ext.apply_rotary(
@@ -37,7 +37,7 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
False,
interleaved,
)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
@@ -57,7 +57,7 @@ def backward(ctx, dqkv):
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
False,
interleaved,
)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
ext.apply_rotary(
@@ -66,105 +66,6 @@ def backward(ctx, dqkv):
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
False,
interleaved,
)
return dqkv, None, None, None, None, None


class DeepLinkApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
x1, x2 = (
x_ro.chunk(2, dim=-1)
if not interleaved
else (x_ro[..., ::2], x_ro[..., 1::2])
)
out = torch.empty_like(x) if not inplace else x
out_ro = out[..., :rotary_dim]

if inplace:
ext.apply_rotary(
out_ro,
x_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
False,
)
else:
ext.apply_rotary(
out_ro,
x_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
False,
)

if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.inplace = inplace
return out if not inplace else x

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
do1, do2 = (
do_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (do_ro[..., ::2], do_ro[..., 1::2])
)
dx = torch.empty_like(do) if not inplace else do
if inplace:
dx1, dx2 = do1, do2
else:
dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (
dx_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2])
)
if inplace:
ext.apply_rotary(
do_ro,
do_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
False,
)
else:
ext.apply_rotary(
dx_ro,
do_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
False,
)

if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
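The interleaved flag that the forward/backward calls now pass through follows the convention from the removed docstring: GPT-NeoX style rotates the 1st half of the rotary dims against the 2nd half, while GPT-J style rotates even dims against odd dims. A short sketch of how the two pairings pick x1 and x2 (plain PyTorch, illustrative shapes, no kernel call):

import torch

x_ro = torch.randn(2, 16, 8, 64)  # (batch, seqlen, nheads, rotary_dim), illustrative

# GPT-NeoX style (interleaved=False): first half vs second half
x1_neox, x2_neox = x_ro.chunk(2, dim=-1)

# GPT-J style (interleaved=True): even dims vs odd dims
x1_gptj, x2_gptj = x_ro[..., ::2], x_ro[..., 1::2]

assert x1_neox.shape == x2_neox.shape == x1_gptj.shape == (2, 16, 8, 32)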
2 changes: 1 addition & 1 deletion deeplink_ext/internlm_ops/rotary/fallback/__init__.py
@@ -1,3 +1,3 @@
# Copyright (c) 2024, DeepLink.

from .fallback import ApplyRotaryEmb, ApplyRotaryEmbQKV_
from .fallback import ApplyRotaryEmbQKV_
75 changes: 0 additions & 75 deletions deeplink_ext/internlm_ops/rotary/fallback/fallback.py
@@ -114,78 +114,3 @@ def backward(ctx, dqkv):
)
dqkv[:, :, 1, :, :rotary_dim] = torch.cat((dk1, dk2), dim=-1)
return dqkv, None, None, None, None, None


class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
x1, x2 = (
x_ro.chunk(2, dim=-1)
if not interleaved
else (x_ro[..., ::2], x_ro[..., 1::2])
)
out = torch.empty_like(x)
out_ro = out[..., :rotary_dim]

ext.apply_rotary(
out_ro,
x_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
False,
False,
)

if rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
return out

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
do_ro = do[..., :rotary_dim]
do1, do2 = (
do_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (do_ro[..., ::2], do_ro[..., 1::2])
)
dx = torch.empty_like(do)

dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (
dx_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2])
)
ext.apply_rotary(
dx_ro,
do_ro,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
True,
False,
)

if rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
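In both the deeplink and fallback paths, the boolean passed right before interleaved is the conj flag of ext.apply_rotary: False in forward, True in backward. The sketch below shows what that presumably means in terms of the standard rotary formula; whether the DIOPI kernel matches this exactly is an assumption, and the shapes are illustrative only.

import torch

def rotate(x1, x2, cos, sin, conj=False):
    # Standard rotary rotation on a pair of half tensors.
    # conj=False: forward rotation; conj=True: inverse rotation (gradient pass).
    if not conj:
        return x1 * cos - x2 * sin, x1 * sin + x2 * cos
    return x1 * cos + x2 * sin, -x1 * sin + x2 * cos

# Shapes after the "s d -> s 1 d" rearrange: cos/sin broadcast over nheads.
seqlen, nheads, half = 16, 8, 32
x1, x2 = torch.randn(seqlen, nheads, half), torch.randn(seqlen, nheads, half)
theta = torch.randn(seqlen, 1, half)
cos, sin = theta.cos(), theta.sin()

o1, o2 = rotate(x1, x2, cos, sin)             # forward (conj=False)
b1, b2 = rotate(o1, o2, cos, sin, conj=True)  # inverse rotation undoes the forward pass
assert torch.allclose(b1, x1, atol=1e-5) and torch.allclose(b2, x2, atol=1e-5)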
29 changes: 24 additions & 5 deletions deeplink_ext/patch_internlm.py
@@ -6,6 +6,7 @@ def _patch_internlm():
import os
import sys
import unittest.mock as mock
import torch
import deeplink_ext.internlm_ops as ext

def _find_or_mock_module(module_name):
@@ -47,15 +48,33 @@ def CrossEntropyLossProxy(reduction, **_):
def _patch_ops():
import internlm.model.embedding # type: ignore

# TODO(lljbash,gongqiwei): implement a module aligned with rotary_emb
def NotImplementedRotaryEnb(*args, **kwargs):
def NotImplementedLegacyRotaryEmb(*args, **kwargs):
raise NotImplementedError(
"the patch for apply_rotary_emb_qkv_ (requires rotary_emb) has not been implemented in deeplink_ext yet"
"we assume that legacy_apply_rotary_embed is not used in internlm"
)

internlm.model.embedding.apply_rotary_emb_qkv_ = NotImplementedRotaryEnb
class NonLegacyRotaryEmbQKV_(torch.autograd.Function):
"""the first 2 dims of qkv has been squeezed"""

@staticmethod
def forward(ctx, qkv: torch.Tensor, *args, **kwargs):
unsqueezed_qkv = qkv.view([1] + list(qkv.shape))
out: torch.Tensor = ext.rotary.DeepLinkApplyRotaryEmbQKV_.forward(
ctx, unsqueezed_qkv, *args, **kwargs
)
return out.view(out.shape[1:])

@staticmethod
def backward(ctx, dqkv: torch.Tensor, *args, **kwargs):
unsqueezed_dqkv = dqkv.view([1] + list(dqkv.shape))
out: tuple = ext.rotary.DeepLinkApplyRotaryEmbQKV_.backward(
ctx, unsqueezed_dqkv, *args, **kwargs
)
return (out[0].view(out[0].shape[1:]),) + out[1:]

internlm.model.embedding.apply_rotary_emb_qkv_ = NonLegacyRotaryEmbQKV_.apply
internlm.model.embedding.legacy_apply_rotary_embed = (
ext.rotary.DeepLinkApplyRotaryEmb.apply
NotImplementedLegacyRotaryEmb
)
internlm.model.embedding.legacy_apply_rotary_embed_qkv = (
ext.rotary.DeepLinkApplyRotaryEmbQKV_.apply
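The new NonLegacyRotaryEmbQKV_ wrapper only massages shapes: per its docstring the first two dims of qkv have been squeezed (presumably batch and seqlen flattened into one token axis), so the patch adds a dummy leading batch dim before delegating to DeepLinkApplyRotaryEmbQKV_ and strips it again afterwards. A shape-only sketch of that round trip (plain PyTorch, hypothetical sizes, the kernel call replaced by a stand-in):

import torch

total, nheads, headdim = 128, 8, 64            # hypothetical packed-sequence sizes
qkv = torch.randn(total, 3, nheads, headdim)   # assumed non-legacy layout: (total_tokens, 3, nheads, headdim)

unsqueezed_qkv = qkv.view([1] + list(qkv.shape))  # -> (1, total, 3, nheads, headdim), the 5-D layout the QKV_ op expects
out = unsqueezed_qkv                              # stand-in for DeepLinkApplyRotaryEmbQKV_.forward(ctx, unsqueezed_qkv, cos, sin, ...)
restored = out.view(out.shape[1:])                # drop the dummy batch dim again
assert restored.shape == qkv.shape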
