Skip to content

Commit

Permalink
add format CI
Browse files Browse the repository at this point in the history
  • Loading branch information
lljbash committed Jan 24, 2024
1 parent 9c265cd commit 23993cd
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 4 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Formatting CI: run clang-format over C/C++ sources and black over Python
# sources; fail the build when either reports violations.
name: format

# Run on manual trigger, on every pull request, and on pushes to main.
on:
  workflow_dispatch:
  pull_request:
  push:
    branches:
      - main

jobs:
  # C/C++ formatting check via cpp-linter (clang-format only).
  clang-format:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: cpp-linter/cpp-linter-action@v2
        id: cpp-lint
        with:
          style: file
          tidy-checks: '-*' # disable clang tidy at this stage
          version: 17
      # cpp-linter-action does not fail the job itself; check its output.
      - name: Fail test
        if: steps.cpp-lint.outputs.checks-failed > 0
        run: echo "Some files failed the linting checks!" && exit 1

  # Python formatting check via the official black action.
  python-black:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: psf/black@stable
        with: # see: https://black.readthedocs.io/en/stable/getting_started.html
          src: "dipu"
          version: "~= 23.11.0"
1 change: 1 addition & 0 deletions deeplink_ext/internlm_ops/mha/mha_varlen_kvpacked_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import deeplink_ext.cpp_extensions as ext


class DeepLinkMultiHeadAttentionVarLenKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
Expand Down
1 change: 1 addition & 0 deletions deeplink_ext/internlm_ops/mha/mha_varlen_qkvpacked_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import deeplink_ext.cpp_extensions as ext


class DeepLinkMultiHeadAttentionVarLenQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
Expand Down
4 changes: 3 additions & 1 deletion deeplink_ext/patch_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def _patch_ops():

import internlm.model.norm

internlm.model.norm.RMSNormTorch = ext.rms_norm.DeepLinkRMSNormWithNormalizedShape
internlm.model.norm.RMSNormTorch = (
ext.rms_norm.DeepLinkRMSNormWithNormalizedShape
)

_find_or_mock_module("rotary_emb")
_find_or_mock_module("fused_dense_lib")
Expand Down
4 changes: 1 addition & 3 deletions tests/test_mha_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ def _run_cross_attention(
output_gold, dq_gold, dkv_gold = _run_cross_attention(
ext.fallback.CrossAttention, q, kv
)
output_ext, dq_ext, dkv_ext = _run_cross_attention(
ext.DeepLinkCrossAttention, q, kv
)
output_ext, dq_ext, dkv_ext = _run_cross_attention(ext.DeepLinkCrossAttention, q, kv)
assert torch.allclose(output_gold, output_ext, atol=1e-3)
print("CrossAttention forward test pass")
assert torch.allclose(dq_gold, dq_ext, atol=2e-3)
Expand Down

0 comments on commit 23993cd

Please sign in to comment.