feat: add scaled masked softmax op for ascend speed (#64)
Add scaled masked softmax op for ascend speed.
POI-WX authored Apr 2, 2024
1 parent 220857e commit ab149fc
Showing 5 changed files with 56 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/static.yml
@@ -9,7 +9,7 @@ on:

env:
  DEEPLINK_PATH: /mnt/cache/share/deeplinkci/github/${{ github.repository }}
-  ENV_SOURCE: /mnt/cache/share/platform/env/dipu_latest
+  ENV_SOURCE: /mnt/cache/share/platform/env/dipu_latest_ci
  PROXY_SOURCE: /mnt/cache/share/platform/env/proxy
  CLANGD_EXEC: /mnt/cache/share/platform/dep/clang-17/bin/clangd

23 changes: 23 additions & 0 deletions csrc/extensions.cpp
@@ -164,6 +164,21 @@ auto extMultiHeadAttentionVarLenBackward(
std::move(grad_v));
}

void extScaledMaskedSoftmax(at::Tensor& out, const at::Tensor& input,
                            const at::Tensor& mask, double scale,
                            bool fixed_triu_mask) {
  callDiopi(diopiScaledMaskedSoftmax, out, input, mask, scale, fixed_triu_mask);
}

void extScaledMaskedSoftmaxBackward(at::Tensor& grad_input,
                                    const at::Tensor& grad_output,
                                    const at::Tensor& out,
                                    const at::Tensor& mask, double scale,
                                    bool fixed_triu_mask) {
  callDiopi(diopiScaledMaskedSoftmaxBackward, grad_input, grad_output, out,
            mask, scale, fixed_triu_mask);
}

void extDestIndexCopyKV(const at::Tensor& k, const at::Tensor& dest_loc,
                        at::Tensor& out) {
  callDiopi(diopiDestIndexCopyKV, out, k, dest_loc);
@@ -274,6 +289,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  if (&diopiApplyPenalty != nullptr) {
    m.def("apply_penalty", &extApplyPenalty, "deeplink ext_apply_penalty");
  }
  if (&diopiScaledMaskedSoftmax != nullptr) {
    m.def("scaled_masked_softmax_fwd", &extScaledMaskedSoftmax,
          "deeplink ext_scaled_masked_softmax_fwd");
  }
  if (&diopiScaledMaskedSoftmaxBackward != nullptr) {
    m.def("scaled_masked_softmax_bwd", &extScaledMaskedSoftmaxBackward,
          "deeplink ext_scaled_masked_softmax_bwd");
  }
}

} // namespace dipu::dipu_ext
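
The two bindings above follow the extension's out-parameter convention: the forward kernel writes its result into `out`, and the backward kernel writes into `grad_input`. As a rough eager-mode sketch of the semantics they are expected to implement (the boolean mask convention, the `-inf` fill, and the behavior of `fixed_triu_mask` are assumptions, not spelled out in this diff):

import torch


def scaled_masked_softmax_ref(input, mask, scale):
    # Assumed semantics: scale the scores, suppress positions where mask is
    # True, then softmax over the last dimension.
    scores = input * scale
    scores = scores.masked_fill(mask.bool(), float("-inf"))
    return torch.softmax(scores, dim=-1)


def scaled_masked_softmax_backward_ref(grad_output, out, scale):
    # Standard softmax backward followed by the chain-rule factor for the
    # scaling; masked positions have out == 0, so their gradient is zero.
    grad_scores = out * (grad_output - (grad_output * out).sum(-1, keepdim=True))
    return grad_scores * scale

A `fixed_triu_mask=True` call presumably lets the kernel build a causal (upper-triangular) mask internally instead of reading the one passed in; the sketch only models the explicit-mask path.
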
3 changes: 2 additions & 1 deletion deeplink_ext/ascend_speed/__init__.py
@@ -1,4 +1,5 @@
from .rotary_embedding import apply_rotary, RotaryEmbedding
from .adamw import adamw
+from .scaled_masked_softmax import ScaledMaskedSoftmax

-__all__ = ["apply_rotary", "RotaryEmbedding", "adamw"]
+__all__ = ["apply_rotary", "RotaryEmbedding", "adamw", "ScaledMaskedSoftmax"]
5 changes: 3 additions & 2 deletions deeplink_ext/ascend_speed/adamw.py
@@ -27,8 +27,9 @@ def adamw(
    See :class:`~torch.optim.AdamW` for details.
    """

-    assert maximize == False, "ascend diopiAdamW only support False 'maximize'."
-    assert amsgrad == False, "ascend diopiAdamW only support False 'amsgrad'."
+    assert (
+        maximize == False
+    ), "The maximize parameter is not supported by diopiAdamW yet"

    for i, param in enumerate(params):
        if norm_coeff_scale is not None:
27 changes: 27 additions & 0 deletions deeplink_ext/ascend_speed/scaled_masked_softmax.py
@@ -0,0 +1,27 @@
import torch
import deeplink_ext.cpp_extensions as ext


assert hasattr(ext, "scaled_masked_softmax_fwd") and hasattr(
    ext, "scaled_masked_softmax_bwd"
)


class ScaledMaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, mask, scale, fixed_triu_mask):
        out = torch.empty_like(input)
        ext.scaled_masked_softmax_fwd(out, input, mask, scale, fixed_triu_mask)
        ctx.save_for_backward(out, mask)
        ctx.scale = scale
        ctx.fixed_triu_mask = fixed_triu_mask
        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, mask = ctx.saved_tensors
        grad_input = torch.empty_like(grad_output)
        ext.scaled_masked_softmax_bwd(
            grad_input, grad_output, out, mask, ctx.scale, ctx.fixed_triu_mask
        )
        return grad_input, None, None, None
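
A hypothetical usage sketch for the new wrapper; the shapes, dtype, scale, device string, and mask construction below are illustrative assumptions, and the call only works where `deeplink_ext.cpp_extensions` was built with the DIOPI scaled-masked-softmax kernels:

import torch
from deeplink_ext.ascend_speed import ScaledMaskedSoftmax

device = "cuda"  # assumption: the DIPU/Ascend device is exposed via torch's CUDA device key
batch, heads, seq = 2, 8, 128
head_dim = 64  # hypothetical per-head width, used only to form the scale

# Attention scores of shape (batch, heads, seq_q, seq_k).
scores = torch.randn(batch, heads, seq, seq, dtype=torch.half,
                     device=device, requires_grad=True)

# Example causal mask; True marks positions to suppress (assumed convention).
mask = torch.ones(seq, seq, device=device).triu(diagonal=1).bool()
mask = mask.expand(batch, heads, seq, seq)

probs = ScaledMaskedSoftmax.apply(scores, mask, 1.0 / head_dim**0.5, False)
probs.sum().backward()  # gradient flows through the custom backward above
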
