optimize
POI-WX committed Jan 9, 2025
1 parent ee0baeb · commit 5565fbc
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions deeplink_ext/internevo_ops/_rotary_embedding_npu.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2024, DeepLink.
 
 import torch
-import torch_npu
 from einops import repeat
 from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding
 
@@ -45,9 +44,11 @@ def forward(
         else:
             cos = repeat(cos[:seqlen], "... d -> 1 ... 1 (2 d)")
             sin = repeat(sin[:seqlen], "... d -> 1 ... 1 (2 d)")
+
         ctx.save_for_backward(cos, sin)
         ctx.interleaved = interleaved
         ctx.in_place = in_place
+
         if interleaved:
             x_ro = x[..., :rotary_dim]
             out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 1)
@@ -62,7 +63,7 @@ def forward(
                 return out_ro
         else:
             x_ro = x[..., :rotary_dim]
-            out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin)
+            out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 0)
             if in_place:
                 x[..., :rotary_dim].copy_(out_ro)
                 return x
@@ -78,6 +79,7 @@ def backward(ctx, grad_out):
         cos, sin = ctx.saved_tensors
         rotary_dim = cos.shape[-1]
         head_dim = grad_out.shape[-1]
+
         if ctx.interleaved:
             grad_out_ro = grad_out[..., :rotary_dim]
             grad_input_ro = npu_rotary_position_embedding(
@@ -94,7 +96,9 @@ def backward(ctx, grad_out):
                 return grad_input_ro, None, None, None, None
         else:
             grad_out_ro = grad_out[..., :rotary_dim]
-            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin))
+            grad_input_ro = npu_rotary_position_embedding(
+                grad_out_ro, cos, torch.neg(sin), 0
+            )
             if ctx.in_place:
                 grad_out[..., :rotary_dim].copy_(grad_input_ro)
                 return grad_out, None, None, None, None
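
Net effect of the change: both the interleaved and half-rotated paths of forward and backward now call MindSpeed's npu_rotary_position_embedding (mode 1 and mode 0 respectively), so the torch_npu import and its npu_rotary_mul calls can be dropped. The backward pass reuses the same kernel with torch.neg(sin), i.e. the inverse rotation. As a rough CPU reference for what the mode-0 call is presumed to compute (a sketch assuming mode 0 is the standard rotate_half formulation; rotate_half and rotary_mode0_reference are illustrative names, not part of the repository):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_mode0_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Presumed CPU equivalent of npu_rotary_position_embedding(x, cos, sin, 0);
    # cos/sin broadcast against x thanks to the "1 ... 1 (2 d)" repeat above.
    return x * cos + rotate_half(x) * sin

Applying the same formula with sin negated undoes the rotation, which is why backward can feed torch.neg(sin) to the identical kernel instead of needing a separate backward op.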

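For orientation, a hypothetical call through the autograd function this file defines. The class name ApplyRotaryEmb and the positional order (x, cos, sin, interleaved, in_place) are inferred from the forward/backward signatures in the diff, not confirmed by it:

import torch
from deeplink_ext.internevo_ops._rotary_embedding_npu import ApplyRotaryEmb  # assumed export name

q = torch.randn(2, 128, 8, 64)   # (batch, seqlen, nheads, head_dim)
cos = torch.randn(128, 16)       # (seqlen, d); repeated to (1, seqlen, 1, 2 * d) internally
sin = torch.randn(128, 16)
out = ApplyRotaryEmb.apply(q, cos, sin, False, False)  # interleaved=False -> mode-0 path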