From 2b5e0326c7acbf78cb9838acbe23468812a13d3b Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Wed, 4 Sep 2024 02:57:44 +0000 Subject: [PATCH] fix py format --- deeplink_ext/ascend_speed/_rotary_embedding_npu.py | 6 +++++- tests/conftest.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/deeplink_ext/ascend_speed/_rotary_embedding_npu.py b/deeplink_ext/ascend_speed/_rotary_embedding_npu.py index 646f7e9..2b84b3e 100644 --- a/deeplink_ext/ascend_speed/_rotary_embedding_npu.py +++ b/deeplink_ext/ascend_speed/_rotary_embedding_npu.py @@ -27,4 +27,8 @@ def forward(ctx, x, cos, sin): @staticmethod def backward(ctx, grad_output): out, cos, sin = ctx.saved_tensors - return torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0], None, None + return ( + torch_npu.npu_rotary_mul_backward(grad_output, out, cos, sin)[0], + None, + None, + ) diff --git a/tests/conftest.py b/tests/conftest.py index cd2aae5..98906e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,8 @@ from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def import_module(): platform = deeplink_ext_get_platform_type() if platform == PlatformType.TORCH_NPU: