From ca4aa2e369a278362220b148ef2bb72ad9a1aa31 Mon Sep 17 00:00:00 2001 From: changwangss Date: Wed, 26 Jun 2024 01:59:59 -0700 Subject: [PATCH] fix bloom Signed-off-by: changwangss --- .../transformers/llm/evaluation/models.py | 3 +-- .../transformers/modeling/modeling_auto.py | 6 +----- .../transformers/utils/utility.py | 6 +----- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/intel_extension_for_transformers/transformers/llm/evaluation/models.py b/intel_extension_for_transformers/transformers/llm/evaluation/models.py index 98dc24e3673..61b301a380a 100644 --- a/intel_extension_for_transformers/transformers/llm/evaluation/models.py +++ b/intel_extension_for_transformers/transformers/llm/evaluation/models.py @@ -38,8 +38,7 @@ def _reorder_cache( This is required to match `past_key_values` with the correct beam_idx at every generation step. """ - if self.config.model_type == "bloom": - return self._reorder_cache_bloom(past_key_values, beam_idx) + if self.config.model_type == "chatglm": return tuple( tuple( diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index c0a9925494a..263e4784d92 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -946,11 +946,7 @@ def collate_batch(batch): ) last_ind.append(input_ids.shape[0] - 1) - if model_type in ["bloom"]: - attention_mask = torch.ones(len(input_ids) + 1) - attention_mask[0] = 0 - else: - attention_mask = torch.ones(len(input_ids)) + attention_mask = torch.ones(len(input_ids)) position_ids = torch.arange(len(input_ids)) input_ids_padded.append(input_ids) attention_mask_padded.append(attention_mask) diff --git a/intel_extension_for_transformers/transformers/utils/utility.py b/intel_extension_for_transformers/transformers/utils/utility.py index 2467531fab2..78fe5f2063d 100644 --- a/intel_extension_for_transformers/transformers/utils/utility.py +++ b/intel_extension_for_transformers/transformers/utils/utility.py @@ -375,11 +375,7 @@ def get_example_inputs(model_config, batch_size=1, tokenizer=None, num_beams=4): past_key_values = generate_dummy_past_key_values(config=model_config, input_bs=batch_size) input_ids = input_ids[:, :512] - if model_type in ["bloom", "qwen"]: - attention_mask = torch.ones(input_ids.shape[0], input_ids.shape[1] + 1) - attention_mask[:,0] = 0 - else: - attention_mask = torch.ones(input_ids.shape) + attention_mask = torch.ones(input_ids.shape) position_ids = torch.arange(input_ids.shape[1]).repeat(batch_size, 1) if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: