Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into penghuic/weight_onl…
Browse files Browse the repository at this point in the history
…y_with_itrex
  • Loading branch information
PenghuiCheng committed Oct 23, 2023
2 parents f5363e7 + b7703dc commit 4011612
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
46 changes: 43 additions & 3 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,34 @@ def _from_transformers(
if use_cache:
task = task + "-with-past"

# Patch the modules to export of GPTQ models w/o GPU
do_gptq_patching = False
config_dict = config.to_dict()
quantization_config = config_dict.get("quantization_config", None)
do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
if do_gptq_patching:
torch.set_default_dtype(torch.float32)
orig_cuda_check = torch.cuda.is_available
torch.cuda.is_available = lambda: True

from optimum.gptq import GPTQQuantizer

orig_post_init_model = GPTQQuantizer.post_init_model

def post_init_model(self, model):
from auto_gptq import exllama_set_max_input_length

class StoreAttr(object):
pass

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = self.desc_act
if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
model = exllama_set_max_input_length(model, self.max_input_length)
return model

GPTQQuantizer.post_init_model = post_init_model

main_export(
model_name_or_path=model_id,
output=save_dir_path,
Expand All @@ -238,10 +266,14 @@ def _from_transformers(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
model_kwargs=kwargs,
int8=load_in_8bit,
)

# Unpatch modules after GPTQ export
if do_gptq_patching:
torch.cuda.is_available = orig_cuda_check
GPTQQuantizer.post_init_model = orig_post_init_model

config.is_decoder = True
config.is_encoder_decoder = False
config.save_pretrained(save_dir_path)
Expand Down Expand Up @@ -320,7 +352,10 @@ def forward(
input_ids = input_ids[:, -1:]

inputs = {}
past_len = 0
if past_key_values is not None:
seq_len_dim = 1 if self.model.input(self.key_value_input_names[0]).get_partial_shape()[1].is_dynamic else 2
past_len = past_key_values[0][0].shape[seq_len_dim]
if self._pkv_precision == Type.bf16:
# numpy does not support bf16, pretending f16, should change to bf16
past_key_values = tuple(
Expand Down Expand Up @@ -355,8 +390,13 @@ def forward(
inputs["input_ids"] = np.array(input_ids)

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names and attention_mask is not None:
inputs["attention_mask"] = np.array(attention_mask)
if "attention_mask" in self.input_names:
if attention_mask is not None:
inputs["attention_mask"] = np.array(attention_mask)
else:
inputs["attention_mask"] = np.ones(
(input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
)

# Run inference
self.request.start_async(inputs, shared_memory=True)
Expand Down
23 changes: 23 additions & 0 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,29 @@ def test_auto_device_loading(self):
del model
gc.collect()

def test_default_filling_attention_mask(self):
model_id = MODEL_NAMES["gpt2"]
model_with_cache = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
texts = ["this is a simple input"]
tokens = tokenizer(texts, return_tensors="pt")
self.assertTrue("attention_mask" in model_with_cache.input_names)
outs = model_with_cache(**tokens)
attention_mask = tokens.pop("attention_mask")
outs_without_attn_mask = model_with_cache(**tokens)
self.assertTrue(torch.allclose(outs.logits, outs_without_attn_mask.logits))
input_ids = torch.argmax(outs.logits, dim=2)
past_key_values = outs.past_key_values
attention_mask = torch.ones((input_ids.shape[0], tokens.input_ids.shape[1] + 1), dtype=torch.long)
outs_step2 = model_with_cache(
input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values
)
outs_without_attn_mask_step2 = model_with_cache(input_ids=input_ids, past_key_values=past_key_values)
self.assertTrue(torch.allclose(outs_step2.logits, outs_without_attn_mask_step2.logits))
del model_with_cache
gc.collect()


class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
Expand Down

0 comments on commit 4011612

Please sign in to comment.