Skip to content

Commit

Permalink
update tests again
Browse files · Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Oct 29, 2024
1 parent f0e980f commit 0d6597b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ jobs:
constraints:
cluster:
- ai2/allennlp-cirrascale
- ai2/allennlp-elanding-a100-40g
# - ai2/allennlp-elanding-a100-40g
- ai2/pluto-cirrascale
- ai2/jupiter-cirrascale-2
- ai2/saturn-cirrascale
Expand Down
6 changes: 3 additions & 3 deletions src/test/nn/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
@pytest.mark.parametrize(
"n_kv_heads",
[pytest.param(None, id="MHA"), pytest.param(1, id="MQA"), pytest.param(4, id="GQA")],
[pytest.param(None, id="MHA"), pytest.param(1, id="MQA"), pytest.param(2, id="GQA")],
)
@pytest.mark.parametrize(
"use_flash",
Expand Down Expand Up @@ -50,12 +50,12 @@ def test_attention(

torch.random.manual_seed(0)

d_model = 128
d_model = 512
seq_len = 32

attention = Attention(
d_model=d_model,
n_heads=8,
n_heads=4,
n_kv_heads=n_kv_heads,
use_flash=use_flash,
init_device=device.type,
Expand Down

0 comments on commit 0d6597b

Please sign in to comment.