diff --git a/requirements/requirements-llm.txt b/requirements/requirements-llm.txt
index 9bc21d251..bb8c823dc 100644
--- a/requirements/requirements-llm.txt
+++ b/requirements/requirements-llm.txt
@@ -1,3 +1,3 @@
 # optimum-amd[brevitas] @ git+https://github.com/huggingface/optimum-amd.git@main
 tqdm
-transformers[sentencepiece]==4.45.2
+transformers[sentencepiece]
diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py
index cf694d4eb..67eeb8738 100644
--- a/src/brevitas_examples/llm/llm_quant/mha_layers.py
+++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -152,6 +152,7 @@ def forward(
             attention_mask: Optional[torch.Tensor] = None,
             layer_head_mask: Optional[torch.Tensor] = None,
             output_attentions: bool = False,
+            position_ids: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if key_value_states is None:
             key_value_states = hidden_states
@@ -164,14 +165,17 @@ def forward(
         query_seq_length, batch_size = hidden_states.shape[:2]
         key_value_seq_length = key_value_states.shape[0]
         num_heads = self.num_heads
-        attention_mask = attention_mask_handler(
-            attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
+        attention_mask = (
+            attention_mask_handler(
+                attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
+            if attention_mask is not None else None)
         attn_output, attn_output_weights = self.mha(
             hidden_states,
             key_value_states,
             key_value_states,
             attn_mask=attention_mask,
             need_weights=output_attentions,
-            average_attn_weights=False)
+            average_attn_weights=False,
+        )
         past_key_value = None
         return attn_output, attn_output_weights, past_key_value
diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
index d22b2eff1..2a71546e4 100644
--- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
+++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
@@ -2,11 +2,16 @@
 
 import torch
 from transformers.models.opt.modeling_opt import OPTAttention
+from transformers.models.opt.modeling_opt import OPTSdpaAttention
 
 from brevitas.graph import ModuleToModuleByClass
 from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention
 
-QUANTIZABLE_MHA_MAP = {OPTAttention: (QuantizableOPTAttention, {'batch_first': True})}
+QUANTIZABLE_MHA_MAP = {
+    OPTAttention: (QuantizableOPTAttention, {
+        'batch_first': True}),
+    OPTSdpaAttention: (QuantizableOPTAttention, {
+        'batch_first': True}),}
 
 
 def replace_mha_with_quantizable_layers(model, dtype):
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index 576af04b1..79b8536d4 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -14,6 +14,7 @@
 import pytest
 import pytest_cases
 import torch
+import transformers
 
 from brevitas import config
 from brevitas import torch_version
@@ -40,6 +41,10 @@ def allexact(x, y):
     return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False)
 
 
+def transformers_version_ge(required_version: str):
+    return version.parse(transformers.__version__) >= version.parse(required_version)
+
+
 # Check that all args in args are used
 def validate_args(args):
     a = vars(args)
@@ -126,6 +131,7 @@ def default_run_args(request):
     args.weight_quant_granularity = "per_channel"  # "per_tensor", "per_channel", "per_group".
     args.input_bit_width = 8
     args.act_calibration = True
+    args.no_float16 = True
     return args
 
 
@@ -203,18 +209,18 @@ def test_small_models_toggle_run_args_pt_ge_2_4(
         "llama",
         "mistral",],
     params=[
-        {
-            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
-            "act_equalization": "layerwise",
-            "gptq": True,
-            "float_ppl": 31274.05078125,
-            "quant_ppl": 33139.23046875},
         {
             "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
             "act_equalization": "fx",
             "bias_corr": True,
-            "float_ppl": 33239.5,
-            "quant_ppl": 33283.75390625},])
+            "float_ppl": 33312.0 if transformers_version_ge('4.46.0') else 33239.5,
+            "quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 33283.75390625},
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "act_equalization": "layerwise",
+            "gptq": True,
+            "float_ppl": 31056.0 if transformers_version_ge('4.46.0') else 31274.05078125,
+            "quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 33139.23046875},])
 def acc_args_and_acc(default_run_args, request):
     args = default_run_args
     run_dict = request.param
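
Beyond the diff itself, a minimal usage sketch of the extended QUANTIZABLE_MHA_MAP, in case it helps review. Only the import paths and the (model, dtype) signature of replace_mha_with_quantizable_layers are taken from the diff above; the tiny OPT checkpoint name is illustrative, and the sketch assumes the function returns the rewritten model (it is built on ModuleToModuleByClass).

import torch
from transformers import AutoModelForCausalLM

from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention
from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers

# Illustrative tiny checkpoint; newer transformers releases build OPTSdpaAttention modules
# for OPT, which is exactly the case the extended map is meant to cover.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")

# Assumed to return the rewritten model, with each OPT attention module swapped for
# QuantizableOPTAttention via the ModuleToModuleByClass rewriter.
model = replace_mha_with_quantizable_layers(model, torch.float32)

num_swapped = sum(1 for m in model.modules() if isinstance(m, QuantizableOPTAttention))
print(f"Replaced attention modules: {num_swapped}")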