From 221c82277d1aef2fdfacb0895f7a72984f143f32 Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Mon, 18 Nov 2024 15:58:44 +0000
Subject: [PATCH] Fix (nn/sdpa): Updated argument to match qsdpa

---
 src/brevitas/nn/quant_sdpa.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py
index 9927ca41a..a7a61d56d 100644
--- a/src/brevitas/nn/quant_sdpa.py
+++ b/src/brevitas/nn/quant_sdpa.py
@@ -108,11 +108,11 @@ class QuantScaledDotProductAttention(Module):
 
     def __init__(
             self,
-            query_quant=Int8ActPerTensorFloat,
-            key_quant=Int8ActPerTensorFloat,
-            value_quant=Int8ActPerTensorFloat,
-            softmax_input_quant=Int8ActPerTensorFloat,
-            softmax_output_quant=Uint8ActPerTensorFloat,
+            softmax_input_quant=None,
+            attn_output_weights_quant=Uint8ActPerTensorFloat,
+            q_scaled_quant=Int8ActPerTensorFloat,
+            k_transposed_quant=Int8ActPerTensorFloat,
+            v_quant=Int8ActPerTensorFloat,
             attn_output_quant=None,
             **kwargs) -> None:
         super(QuantScaledDotProductAttention, self).__init__()
@@ -120,13 +120,14 @@ def __init__(
         def filter_kwargs(prefix):
             return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
 
-        self.query_quant = QuantIdentity(act_quant=query_quant, **filter_kwargs('query_'))
-        self.key_quant = QuantIdentity(act_quant=key_quant, **filter_kwargs('key_'))
-        self.value_quant = QuantIdentity(act_quant=value_quant, **filter_kwargs('value_'))
+        self.q_scaled_quant = QuantIdentity(act_quant=q_scaled_quant, **filter_kwargs('q_scaled_'))
+        self.k_transposed_quant = QuantIdentity(
+            act_quant=k_transposed_quant, **filter_kwargs('k_transposed_'))
+        self.v_quant = QuantIdentity(act_quant=v_quant, **filter_kwargs('v_'))
         self.softmax_input_quant = QuantIdentity(
             act_quant=softmax_input_quant, **filter_kwargs('softmax_input_'))
-        self.softmax_output_quant = QuantIdentity(
-            act_quant=softmax_output_quant, **filter_kwargs('softmax_output_'))
+        self.attn_output_weights_quant = QuantIdentity(
+            act_quant=attn_output_weights_quant, **filter_kwargs('attn_output_weights_'))
         self.attn_output_quant = QuantIdentity(
             act_quant=attn_output_quant, **filter_kwargs('attn_output_'))
 
@@ -187,12 +188,14 @@ def forward(
                 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
             else:
                 attn_bias += attn_mask
-        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        q_scaled = self.q_scaled_quant(query * scale_factor)
+        k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
+        attn_weight = q_scaled @ k_transpose
         attn_weight += attn_bias
         attn_weight = self.softmax_input_quant(attn_weight)
         attn_weight = torch.softmax(attn_weight, dim=-1)
         attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-        attn_weight = self.softmax_output_quant(attn_weight)
-        attn_output = attn_weight @ value
+        attn_weight = self.attn_output_weights_quant(attn_weight)
+        attn_output = attn_weight @ self.v_quant(value)
         attn_output = self.attn_output_quant(attn_output)
         return attn_output
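
Usage sketch (not part of the patch): the example below shows how the renamed
keyword arguments and the prefix-based kwargs routing through filter_kwargs()
are meant to be used after this change. It is a minimal sketch, assuming
brevitas with this patch applied, that forward() accepts (query, key, value)
positionally as in torch.nn.functional.scaled_dot_product_attention, and
illustrative tensor shapes; return_quant_tensor is a standard QuantIdentity
kwarg.

    import torch

    from brevitas.nn.quant_sdpa import QuantScaledDotProductAttention

    # Defaults follow the new signature: the scaled query, transposed key and
    # value are quantized with Int8ActPerTensorFloat, the attention weights
    # (softmax output) with Uint8ActPerTensorFloat.
    sdpa = QuantScaledDotProductAttention()

    # Prefixed kwargs are stripped by filter_kwargs() and forwarded to the
    # matching QuantIdentity submodule, e.g. 'q_scaled_' + 'return_quant_tensor'
    # below configures only self.q_scaled_quant.
    sdpa = QuantScaledDotProductAttention(q_scaled_return_quant_tensor=False)

    q = torch.randn(2, 4, 16, 64)  # (batch, heads, seq_len, head_dim), illustrative
    k = torch.randn(2, 4, 16, 64)
    v = torch.randn(2, 4, 16, 64)
    out = sdpa(q, k, v)

The prefix convention mirrors the new attribute names, so quantizer-specific
options can be set per tensor without widening the constructor signature.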