Skip to content

Commit

Permalink
fix pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
kelvin-zou committed Jan 31, 2025
1 parent 14f01e5 commit e3890dd
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions axlearn/common/flash_attention/gpu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,7 @@ def _segment_mask(


def _build_mask(
mask_fn: MaskFn,
*,
q_seq_len: int,
kv_seq_len: int,
block_q: int,
block_k: int
mask_fn: MaskFn, *, q_seq_len: int, kv_seq_len: int, block_q: int, block_k: int
) -> np.ndarray:
"""build the iteration map where True means the block is not empty.
Expand Down Expand Up @@ -134,7 +129,7 @@ def _query_iterator_indices(block_mask_map: np.ndarray) -> Tuple[Tensor, Tensor]
return jnp.asarray(index_offset), jnp.asarray(index_offset_size)


def _key_value_iterator_indices(block_mask_map: np.ndarray)->Tuple[Tensor, Tensor]:
def _key_value_iterator_indices(block_mask_map: np.ndarray) -> Tuple[Tensor, Tensor]:
"""build the iteration begin/end indices for the key/value dimension.
Returns:
Expand Down

0 comments on commit e3890dd

Please sign in to comment.