Tests (llm): update LLM tests for new transformers version (#1088)
---------

Co-authored-by: Giuseppe Franco <[email protected]>
pablomlago and Giuseppe5 authored Nov 28, 2024
1 parent 5e473c4 commit d51087c
Showing 4 changed files with 28 additions and 13 deletions.
2 changes: 1 addition & 1 deletion requirements/requirements-llm.txt
@@ -1,3 +1,3 @@
 # optimum-amd[brevitas] @ git+https://github.com/huggingface/optimum-amd.git@main
 tqdm
-transformers[sentencepiece]==4.45.2
+transformers[sentencepiece]
10 changes: 7 additions & 3 deletions src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -152,6 +152,7 @@ def forward(
             attention_mask: Optional[torch.Tensor] = None,
             layer_head_mask: Optional[torch.Tensor] = None,
             output_attentions: bool = False,
+            position_ids: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if key_value_states is None:
             key_value_states = hidden_states
@@ -164,14 +165,17 @@ def forward(
         query_seq_length, batch_size = hidden_states.shape[:2]
         key_value_seq_length = key_value_states.shape[0]
         num_heads = self.num_heads
-        attention_mask = attention_mask_handler(
-            attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
+        attention_mask = (
+            attention_mask_handler(
+                attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
+            if attention_mask is not None else None)
         attn_output, attn_output_weights = self.mha(
             hidden_states,
             key_value_states,
             key_value_states,
             attn_mask=attention_mask,
             need_weights=output_attentions,
-            average_attn_weights=False)
+            average_attn_weights=False,
+        )
         past_key_value = None
         return attn_output, attn_output_weights, past_key_value
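
Note on the mask guard above: the handler is now called only when a mask is actually supplied, so None reaches nn.MultiheadAttention directly, which applies no masking. A minimal standalone sketch of that guard follows; the expand_mask_stub below is a stand-in for attention_mask_handler (its real implementation is elsewhere in mha_layers.py and not shown in this diff), and the 2D input plus (batch_size * num_heads, q_len, kv_len) output layout are assumptions for illustration only.

from typing import Optional

import torch

def expand_mask_stub(mask: torch.Tensor, batch_size: int, num_heads: int,
                     q_len: int, kv_len: int) -> torch.Tensor:
    # Stand-in for attention_mask_handler: assumes a 2D (q_len, kv_len) mask and
    # broadcasts it to the 3D layout accepted by nn.MultiheadAttention's attn_mask.
    return mask.unsqueeze(0).expand(batch_size * num_heads, q_len, kv_len)

def guarded_mask(mask: Optional[torch.Tensor], batch_size: int, num_heads: int,
                 q_len: int, kv_len: int) -> Optional[torch.Tensor]:
    # Same pattern as the change above: expand only when a mask exists, else keep None.
    return (expand_mask_stub(mask, batch_size, num_heads, q_len, kv_len)
            if mask is not None else None)

assert guarded_mask(None, 2, 4, 5, 5) is None
assert guarded_mask(torch.zeros(5, 5), 2, 4, 5, 5).shape == (2 * 4, 5, 5)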
7 changes: 6 additions & 1 deletion src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
@@ -2,11 +2,16 @@

 import torch
 from transformers.models.opt.modeling_opt import OPTAttention
+from transformers.models.opt.modeling_opt import OPTSdpaAttention

 from brevitas.graph import ModuleToModuleByClass
 from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention

-QUANTIZABLE_MHA_MAP = {OPTAttention: (QuantizableOPTAttention, {'batch_first': True})}
+QUANTIZABLE_MHA_MAP = {
+    OPTAttention: (QuantizableOPTAttention, {'batch_first': True}),
+    OPTSdpaAttention: (QuantizableOPTAttention, {'batch_first': True}),
+}


 def replace_mha_with_quantizable_layers(model, dtype):
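
For context, newer transformers releases instantiate OPT models with the SDPA attention subclass by default, so a map keyed only on OPTAttention would no longer match any modules; adding OPTSdpaAttention keeps the replacement effective. The body of replace_mha_with_quantizable_layers is not part of this diff, so the sketch below is only an assumption about how the map might be consumed, including the guessed ModuleToModuleByClass signature (source class, target class, extra constructor kwargs) and its apply(model) method; the tiny OPT checkpoint is just an illustrative choice.

import torch
from transformers import AutoModelForCausalLM
# OPTSdpaAttention exists in the transformers versions targeted by this commit.
from transformers.models.opt.modeling_opt import OPTAttention, OPTSdpaAttention

from brevitas.graph import ModuleToModuleByClass
from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention

QUANTIZABLE_MHA_MAP = {
    OPTAttention: (QuantizableOPTAttention, {'batch_first': True}),
    OPTSdpaAttention: (QuantizableOPTAttention, {'batch_first': True}),
}

def replace_mha_sketch(model: torch.nn.Module, dtype: torch.dtype = torch.float32):
    # Swap every mapped HF attention class for the quantizable MHA wrapper
    # (assumed rewriter signature and apply() behavior).
    for src_cls, (tgt_cls, extra_kwargs) in QUANTIZABLE_MHA_MAP.items():
        rewriter = ModuleToModuleByClass(src_cls, tgt_cls, **extra_kwargs)
        model = rewriter.apply(model)
    return model.to(dtype=dtype)

# Example usage on a tiny OPT checkpoint (downloads weights from the Hub).
model = AutoModelForCausalLM.from_pretrained('hf-internal-testing/tiny-random-OPTForCausalLM')
model = replace_mha_sketch(model)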
22 changes: 14 additions & 8 deletions tests/brevitas_examples/test_llm.py
@@ -14,6 +14,7 @@
 import pytest
 import pytest_cases
 import torch
+import transformers

 from brevitas import config
 from brevitas import torch_version
@@ -40,6 +41,10 @@ def allexact(x, y):
     return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False)


+def transformers_version_ge(required_version: str):
+    return version.parse(required_version) >= version.parse(transformers.__version__)
+
+
 # Check that all args in args are used
 def validate_args(args):
     a = vars(args)
@@ -126,6 +131,7 @@ def default_run_args(request):
     args.weight_quant_granularity = "per_channel"  # "per_tensor", "per_channel", "per_group".
     args.input_bit_width = 8
     args.act_calibration = True
+    args.no_float16 = True
     return args


@@ -203,18 +209,18 @@ def test_small_models_toggle_run_args_pt_ge_2_4(
         "llama",
         "mistral",],
     params=[
-        {
-            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
-            "act_equalization": "layerwise",
-            "gptq": True,
-            "float_ppl": 31274.05078125,
-            "quant_ppl": 33139.23046875},
         {
             "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
             "act_equalization": "fx",
             "bias_corr": True,
-            "float_ppl": 33239.5,
-            "quant_ppl": 33283.75390625},])
+            "float_ppl": 33312.0 if transformers_version_ge('4.46.0') else 33239.5,
+            "quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 33283.75390625},
+        {
+            "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
+            "act_equalization": "layerwise",
+            "gptq": True,
+            "float_ppl": 31056.0 if transformers_version_ge('4.46.0') else 31274.05078125,
+            "quant_ppl": 33056.0 if transformers_version_ge('4.46.0') else 33139.23046875},])
 def acc_args_and_acc(default_run_args, request):
     args = default_run_args
     run_dict = request.param
