From e3890dd87a4e29b23ba4775fde0cc82b71d88661 Mon Sep 17 00:00:00 2001 From: kelvin-zou Date: Fri, 31 Jan 2025 13:57:58 -0800 Subject: [PATCH] fix pylint --- axlearn/common/flash_attention/gpu_attention.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index dd9fdca6a..187e8f471 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -76,12 +76,7 @@ def _segment_mask( def _build_mask( - mask_fn: MaskFn, - *, - q_seq_len: int, - kv_seq_len: int, - block_q: int, - block_k: int + mask_fn: MaskFn, *, q_seq_len: int, kv_seq_len: int, block_q: int, block_k: int ) -> np.ndarray: """build the iteration map where True means the block is not empty. @@ -134,7 +129,7 @@ def _query_iterator_indices(block_mask_map: np.ndarray) -> Tuple[Tensor, Tensor] return jnp.asarray(index_offset), jnp.asarray(index_offset_size) -def _key_value_iterator_indices(block_mask_map: np.ndarray)->Tuple[Tensor, Tensor]: +def _key_value_iterator_indices(block_mask_map: np.ndarray) -> Tuple[Tensor, Tensor]: """build the iteration begin/end indices for the key/value dimension. Returns: