Adapt Splash Attention from TorchPrime #8911
base: master
Conversation
Nice!
from torch_xla.experimental.custom_kernel import requires_jax

@dataclasses.dataclass
We can add eq=True and hash=True to the dataclasses decorator call. That way any instance of this config is hashable and can be passed directly as an argument to call_jax, avoiding the LRU cache. (I left a comment with more details later.)
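For example, a minimal sketch of that suggestion (the field names here are illustrative only; also note that dataclasses spells the option unsafe_hash=True, or frozen=True together with eq=True, rather than hash=True):

import dataclasses

@dataclasses.dataclass(eq=True, frozen=True)  # frozen + eq makes dataclasses generate __hash__
class SplashAttentionConfig:
  # Hypothetical fields; the real config also carries mesh / block-size info.
  attn_logits_soft_cap: float | None = None
  qkv_partition_spec: tuple = (("data", "fsdp"), None, None, None)

cfg = SplashAttentionConfig()
assert hash(cfg) == hash(SplashAttentionConfig())  # hashable, so usable as a call_jax/jit cache key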
)

mesh = config.maybe_convert_and_get_jax_mesh()
# input q,k,v shape: [batch, #head, seq_len, kv] |
Should the last dim be called "head" or "head_dim"?
query.shape[2] == decoder_segment_ids.q.shape[1]
), "Sharding along sequence dimension not allowed in tpu kernel attention"
block_sizes = splash_attention_kernel.BlockSizes(
block_q=min(global_block_q, query.shape[2]), |
Could we factor out a seq_len variable?
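Something along these lines (a sketch reusing the names from the snippet above; block_kv, global_block_kv, and key are assumed):

seq_len = query.shape[2]  # query: [batch, num_heads, seq_len, head_dim]
assert seq_len == decoder_segment_ids.q.shape[1], (
    "Sharding along sequence dimension not allowed in tpu kernel attention")
block_sizes = splash_attention_kernel.BlockSizes(
    block_q=min(global_block_q, seq_len),
    block_kv=min(global_block_kv, key.shape[2]),
)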
("data", "fsdp"),
None,
)
AttentionType_LOCAL_SLIDING: bool = False |
nit: field names should be simple snake case
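i.e. roughly:

# before
AttentionType_LOCAL_SLIDING: bool = False
# after (plain snake_case)
attention_type_local_sliding: bool = False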

def maybe_reduce_kv_grad(self, hidden_state_grad):
# For GQA, the kv grad shape is [BATCH_SIZE, NUM_Q_HEADS, SEQ_LEN,
# HEAD_DIM]. We need to convert it back to [BATCH_SIZE, NUM_Q_HEADS, |
Convert back to NUM_KV_HEADS?
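For reference, a standalone sketch of the reduction being described (names are assumed, not the PR's exact code): the grad arrives with one slice per query head, and summing over each group of query heads that shares a kv head recovers the kv-head shape.

import torch

def reduce_kv_grad(grad: torch.Tensor, num_kv_heads: int) -> torch.Tensor:
  # grad: [BATCH_SIZE, NUM_Q_HEADS, SEQ_LEN, HEAD_DIM]
  b, num_q_heads, s, d = grad.shape
  group = num_q_heads // num_kv_heads  # query heads per kv head (GQA)
  # -> [BATCH_SIZE, NUM_KV_HEADS, SEQ_LEN, HEAD_DIM]
  return grad.reshape(b, num_kv_heads, group, s, d).sum(dim=2)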
self.k_grad = self.maybe_reduce_kv_grad(k_grad)
self.v_grad = self.maybe_reduce_kv_grad(v_grad)

def maybe_expend_kv(self, hidden_state):
def maybe_expend_kv(self, hidden_state): |
nit: expand, or repeat
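The forward-side counterpart can be spelled with repeat_interleave, e.g. (again only a sketch with assumed names, consistent with the reduce_kv_grad sketch above):

def expand_kv(hidden_state: torch.Tensor, num_q_heads: int) -> torch.Tensor:
  # hidden_state: [BATCH_SIZE, NUM_KV_HEADS, SEQ_LEN, HEAD_DIM]
  num_kv_heads = hidden_state.shape[1]
  # -> [BATCH_SIZE, NUM_Q_HEADS, SEQ_LEN, HEAD_DIM]
  return hidden_state.repeat_interleave(num_q_heads // num_kv_heads, dim=1)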
@@ -41,6 +41,7 @@ run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
python3 "$TEST_CDIR/test_pallas.py" -v
python3 "$TEST_CDIR/test_pallas_spmd.py"
XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py"
python3 "$TEST_CDIR/test_splash_attention_jax.py" |
nit: Jax is just an implementation detail. We could simply call this file test_splash_attention.py
@@ -0,0 +1,397 @@
import dataclasses |
nit: I think we could call this splash_attention.py to be more specific. If there are generic Jax utilities, those could be put in a custom_kernels_from_jax.py or similar
k: torch.Tensor,
v: torch.Tensor,
config: str,
decoder_segment_ids: torch.Tensor | None = None, |
Could we at least document the segment IDs and the soft cap in triple-quoted docstrings?
Maybe also document how splash attention differs from flash attention, e.g. splash attention can be faster when the attention mask is sparse because it skips fully masked blocks.
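A possible shape for that docstring (the function name, the q parameter, and the soft-cap formula are assumptions, not text from the PR):

def splash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    config: str,
    decoder_segment_ids: torch.Tensor | None = None,
):
  """Runs splash (sparse flash) attention via the JAX Pallas TPU kernel.

  Unlike flash attention, splash attention exploits sparsity in the attention
  mask: blocks that are fully masked out (e.g. under causal or local
  sliding-window masks) are skipped entirely, so it can be faster when the
  mask is sparse.

  Args:
    q, k, v: [batch, num_heads, seq_len, head_dim] tensors.
    config: serialized attention config (mesh, block sizes, attention type,
      and an optional logit soft cap; when set, logits are squashed as
      soft_cap * tanh(logits / soft_cap) before the softmax).
    decoder_segment_ids: optional [batch, seq_len] segment ids for packed
      sequences; a token only attends to keys in the same segment.
  """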
Adapts the PR AI-Hypercomputer/torchprime#145 from TorchPrime into PTXLA, and simplifies the code to use the jit hashing from #8878.