feat: add scaled masked softmax op for ascend speed (#64)
Add scaled masked softmax op for ascend speed.
Showing 5 changed files with 56 additions and 4 deletions.
__init__.py (the package's export list; the exact path is not shown in this extract):

@@ -1,4 +1,5 @@
 from .rotary_embedding import apply_rotary, RotaryEmbedding
 from .adamw import adamw
+from .scaled_masked_softmax import ScaledMaskedSoftmax
 
-__all__ = ["apply_rotary", "RotaryEmbedding", "adamw"]
+__all__ = ["apply_rotary", "RotaryEmbedding", "adamw", "ScaledMaskedSoftmax"]
scaled_masked_softmax.py (new file, 27 lines):

@@ -0,0 +1,27 @@
import torch
import deeplink_ext.cpp_extensions as ext

# Fail fast if the C++ extension was built without the fused kernels.
assert hasattr(ext, "scaled_masked_softmax_fwd") and hasattr(
    ext, "scaled_masked_softmax_bwd"
)


class ScaledMaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, mask, scale, fixed_triu_mask):
        # The fused kernel writes its result into a preallocated output tensor.
        out = torch.empty_like(input)
        ext.scaled_masked_softmax_fwd(out, input, mask, scale, fixed_triu_mask)
        # Save the softmax output (not the input) for the backward pass.
        ctx.save_for_backward(out, mask)
        ctx.scale = scale
        ctx.fixed_triu_mask = fixed_triu_mask
        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, mask = ctx.saved_tensors
        grad_input = torch.empty_like(grad_output)
        ext.scaled_masked_softmax_bwd(
            grad_input, grad_output, out, mask, ctx.scale, ctx.fixed_triu_mask
        )
        # Gradients flow only to input; mask, scale, and fixed_triu_mask
        # are non-differentiable.
        return grad_input, None, None, None
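For context, here is a minimal sketch of how the new Function might be invoked, together with a plain-PyTorch statement of the semantics the fused kernels presumably implement (scale the logits, apply the mask, then softmax over the last dimension). The shapes, the boolean mask-fill convention, and the reference expression are illustrative assumptions, not part of this commit:

    import torch

    # Assumed shapes: (batch, heads, seq, seq) attention logits; on real
    # hardware these tensors would live on the accelerator device.
    logits = torch.randn(2, 8, 128, 128)
    mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool)  # True = masked out
    scale = 0.125  # e.g. 1 / sqrt(head_dim) for head_dim == 64

    # Fused op added by this commit (fixed_triu_mask=False here).
    out = ScaledMaskedSoftmax.apply(logits, mask, scale, False)

    # Assumed reference semantics in plain PyTorch:
    ref = torch.softmax(
        (logits * scale).masked_fill(mask, float("-inf")), dim=-1
    )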