diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py
index 7e0de670d..9d735a997 100644
--- a/tests/brevitas/nn/test_sdpa.py
+++ b/tests/brevitas/nn/test_sdpa.py
@@ -28,7 +28,7 @@ class TestScaledDotProductAttention:
     @pytest.mark.parametrize("scale", [None, 0.3])
     @pytest.mark.parametrize("enable_gqa", [False, True])
     @pytest.mark.parametrize("rand_attn_mask", [False, True])
-    # Sanity check, since `ScaledDotProductAttention` just called `F.scaled_dot_product_attention` in its forward function
+    # Sanity check, since `ScaledDotProductAttention` just calls `F.scaled_dot_product_attention` in its forward function
     def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask):
         extra_kwargs = {
             "dropout_p": dropout_p,
@@ -56,3 +56,45 @@ def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask)
         out = m(q, k, v, attn_mask, **extra_kwargs)
         assert torch.isclose(out, ref_out, atol=ATOL).all()
         assert torch.isclose(out, ref_out, atol=ATOL).all()
+
+    @requires_pt_ge('2.0')
+    @pytest.mark.parametrize("dropout_p", [0.0, 0.5])
+    @pytest.mark.parametrize("is_causal", [True, False])
+    @pytest.mark.parametrize("scale", [None, 0.3])
+    @pytest.mark.parametrize("enable_gqa", [False, True])
+    @pytest.mark.parametrize("rand_attn_mask", [False, True])
+    def test_sdpa_quant_disabled_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask):
+        extra_kwargs = {
+            "dropout_p": dropout_p,
+            "is_causal": is_causal,
+            "scale": scale,
+            "enable_gqa": enable_gqa,}
+        if torch_version < version.parse('2.5.0'):
+            del extra_kwargs["enable_gqa"]
+
+        kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH
+        m = ScaledDotProductAttention()
+        qm = QuantScaledDotProductAttention(
+            softmax_input_quant=None,
+            attn_output_weights_quant=None,
+            q_scaled_quant=None,
+            k_transposed_quant=None,
+            v_quant=None,
+            attn_output_quant=None,
+        )
+        q = torch.randn(BATCH_SIZE, HEAD_DIM, SEQUENCE_LENGTH, EMBED_DIM)
+        k = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM)
+        v = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM)
+        if rand_attn_mask and not is_causal:
+            attn_mask = torch.randint(
+                low=0, high=2, size=(BATCH_SIZE, 1, SEQUENCE_LENGTH, kv_length), dtype=torch.bool)
+        else:
+            attn_mask = None
+        if dropout_p > 0.0:
+            torch.manual_seed(DROPOUT_SEED)
+        ref_out = m(q, k, v, attn_mask, **extra_kwargs)
+        if dropout_p > 0.0:
+            torch.manual_seed(DROPOUT_SEED)
+        out = qm(q, k, v, attn_mask, **extra_kwargs)
+        assert torch.isclose(out, ref_out, atol=ATOL).all()
+        assert torch.isclose(out, ref_out, atol=ATOL).all()
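
A minimal standalone sketch of what the new test checks: with every quantizer argument set to None, QuantScaledDotProductAttention is expected to reproduce the plain ScaledDotProductAttention output within tolerance. The import location, tensor sizes, and tolerance below are illustrative assumptions rather than values taken from the test module.

import torch

# Assumed import location for both modules (the test file imports them from brevitas).
from brevitas.nn import QuantScaledDotProductAttention, ScaledDotProductAttention

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16  # assumed toy sizes

m = ScaledDotProductAttention()
# Passing None for each quantizer disables quantization, so the quantized module
# should follow the same float F.scaled_dot_product_attention path as the reference.
qm = QuantScaledDotProductAttention(
    softmax_input_quant=None,
    attn_output_weights_quant=None,
    q_scaled_quant=None,
    k_transposed_quant=None,
    v_quant=None,
    attn_output_quant=None)

q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
v = torch.randn(batch_size, num_heads, seq_len, head_dim)

ref_out = m(q, k, v, None)  # float reference, no attention mask
out = qm(q, k, v, None)     # quantization disabled, should match the reference
assert torch.isclose(out, ref_out, atol=1e-6).all()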