Adapt Splash Attention from TorchPrime #8911

Open · wants to merge 7 commits into master

Conversation

@zpcore (Collaborator) commented on Mar 31, 2025:

Adapts the PR AI-Hypercomputer/torchprime#145 from TorchPrime into PyTorch/XLA (PTXLA). Also simplifies the code to use the JIT hashing from #8878.

@zpcore marked this pull request as ready for review on April 5, 2025 21:00
@zpcore requested a review from @tengyifei on April 5, 2025 21:00
@zpcore enabled auto-merge (squash) on April 5, 2025 23:28
@tengyifei (Collaborator) left a comment:

Nice!

from torch_xla.experimental.custom_kernel import requires_jax


@dataclasses.dataclass
@tengyifei (Collaborator) commented on Apr 5, 2025:

We can add eq=True and unsafe_hash=True (or frozen=True) to the dataclasses decorator call. That way any instance of this config is hashable and can be passed directly as an argument to call_jax, avoiding the lru_cache. (I left a comment with more details later.)
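For illustration, a minimal sketch of the idea (the class and field names here are hypothetical stand-ins for the config in this PR): with eq plus frozen, dataclass instances become hashable, so two equal configs hash identically and can be used as cache or static-argument keys.

import dataclasses


@dataclasses.dataclass(eq=True, frozen=True)  # frozen + eq generates __hash__
class SplashAttentionConfig:
  block_q: int = 512
  block_kv: int = 512
  attention_type_local_sliding: bool = False


cfg_a = SplashAttentionConfig()
cfg_b = SplashAttentionConfig()
assert cfg_a == cfg_b and hash(cfg_a) == hash(cfg_b)  # usable as a dict/cache key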

)

mesh = config.maybe_convert_and_get_jax_mesh()
# input q,k,v shape: [batch, #head, seq_len, kv]
Review comment:

Should the last dim be called "head" or "head_dim"?

query.shape[2] == decoder_segment_ids.q.shape[1]
), "Sharding along sequence dimension not allowed in tpu kernel attention"
block_sizes = splash_attention_kernel.BlockSizes(
block_q=min(global_block_q, query.shape[2]),
Review comment:

Could we factor out a seq_len variable?
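A sketch of the suggested refactor, purely illustrative: global_block_q, decoder_segment_ids, and splash_attention_kernel are taken from the quoted diff context, and the remaining BlockSizes fields are elided.

seq_len = query.shape[2]
assert seq_len == decoder_segment_ids.q.shape[1], (
    "Sharding along sequence dimension not allowed in tpu kernel attention")
block_sizes = splash_attention_kernel.BlockSizes(
    block_q=min(global_block_q, seq_len),
    # ... remaining block sizes unchanged from the original diff ...
)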

("data", "fsdp"),
None,
)
AttentionType_LOCAL_SLIDING: bool = False
Review comment:

nit: field names should be simple snake_case.


def maybe_reduce_kv_grad(self, hidden_state_grad):
# For GQA, the kv grad shape is [BATCH_SIZE, NUM_Q_HEADS, SEQ_LEN,
# HEAD_DIM]. We need to convert it back to [BATCH_SIZE, NUM_Q_HEADS,
Review comment:

Convert back to NUM_KV_HEADS?
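For context, an illustrative GQA gradient reduction (a sketch of the general pattern, not necessarily this PR's exact code): the forward pass repeats each kv head across its group of query heads, so the backward pass sums the per-query-head kv gradients back down to num_kv_heads.

import torch


def reduce_kv_grad(kv_grad: torch.Tensor, num_kv_heads: int) -> torch.Tensor:
  # kv_grad: [batch, num_q_heads, seq_len, head_dim]
  batch, num_q_heads, seq_len, head_dim = kv_grad.shape
  group = num_q_heads // num_kv_heads  # query heads sharing one kv head
  reduced = kv_grad.reshape(batch, num_kv_heads, group, seq_len, head_dim).sum(dim=2)
  return reduced  # [batch, num_kv_heads, seq_len, head_dim]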

self.k_grad = self.maybe_reduce_kv_grad(k_grad)
self.v_grad = self.maybe_reduce_kv_grad(v_grad)

def maybe_expend_kv(self, hidden_state):
Review comment:

nit: "expend" should be "expand" (or "repeat").
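And the matching forward-direction expansion, again as a hedged sketch of the general GQA pattern rather than this PR's code: each kv head is repeated so every query head has a corresponding kv head.

import torch


def expand_kv(hidden_state: torch.Tensor, num_q_heads: int) -> torch.Tensor:
  # hidden_state: [batch, num_kv_heads, seq_len, head_dim]
  num_kv_heads = hidden_state.shape[1]
  expanded = torch.repeat_interleave(
      hidden_state, num_q_heads // num_kv_heads, dim=1)
  return expanded  # [batch, num_q_heads, seq_len, head_dim]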

@@ -41,6 +41,7 @@ run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
python3 "$TEST_CDIR/test_pallas.py" -v
python3 "$TEST_CDIR/test_pallas_spmd.py"
XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py"
python3 "$TEST_CDIR/test_splash_attention_jax.py"
Review comment:

nit: JAX is just an implementation detail. We could simply call this file test_splash_attention.py.

@@ -0,0 +1,397 @@
import dataclasses
Review comment:

nit: I think we could call this splash_attention.py to be more specific. If there are generic JAX utilities, those could go in a custom_kernels_from_jax.py or similar.

k: torch.Tensor,
v: torch.Tensor,
config: str,
decoder_segment_ids: torch.Tensor | None = None,
Review comment:

Could we at least document the segment IDs and the soft cap in triple-quoted docstrings?

Maybe also document splash attention vs. flash attention, e.g. splash attention can be faster when the attention mask is sparse because it skips masked-out blocks.
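A possible shape for the requested docstring; the function name, the soft-cap formula, and the config description are illustrative assumptions rather than the PR's final wording, and the parameter list mirrors the quoted signature.

def splash_attention(q, k, v, config, decoder_segment_ids=None):
  """Splash (sparse flash) attention via the Pallas TPU kernel.

  Unlike dense flash attention, splash attention can skip whole blocks that
  the attention mask rules out, so it can be faster when the mask is sparse
  (e.g. causal or local/sliding-window masks).

  Args:
    q, k, v: [batch, num_heads, seq_len, head_dim] tensors.
    config: serialized attention config controlling block sizes, the attention
      type, and the logit soft cap (when set, logits are squashed as
      soft_cap * tanh(logits / soft_cap) before the softmax).
    decoder_segment_ids: optional per-token segment ids; a token only attends
      to tokens with the same segment id, which supports packing multiple
      sequences into one batch row.
  """
  ...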
