Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix when attention_mask=None #1067

Merged
merged 4 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions optimum/exporters/ipex/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _llama_model_forward(
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
Expand Down Expand Up @@ -297,7 +297,7 @@ def _falcon_model_forward(
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)
position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down Expand Up @@ -419,7 +419,7 @@ def _gpt2_model_forward(
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)

if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
Expand Down
2 changes: 2 additions & 0 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
if self.add_patch and input_ids is not None and attention_mask is None:
attention_mask = torch.ones_like(input_ids)
return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

def _prepare_generation_config(
Expand Down
33 changes: 32 additions & 1 deletion tests/ipex/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ def test_compare_to_transformers(self, model_arch):
model_id = MODEL_NAMES[model_arch]
set_seed(SEED)
dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
# Test model forward do not need cache.
ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
self.assertIsInstance(ipex_model.config, PretrainedConfig)
tokenizer = AutoTokenizer.from_pretrained(model_id)
Expand Down Expand Up @@ -275,6 +274,38 @@ def test_compare_to_transformers(self, model_arch):
self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_forward(self, model_arch):
model_id = MODEL_NAMES[model_arch]
set_seed(SEED)
dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
self.assertIsInstance(ipex_model.config, PretrainedConfig)
input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long)
outputs = ipex_model(input_ids)

self.assertIsInstance(outputs.logits, torch.Tensor)

transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
with torch.no_grad():
transformers_outputs = transformers_model(input_ids)

# Test re-load model
with tempfile.TemporaryDirectory() as tmpdirname:
ipex_model.save_pretrained(tmpdirname)
loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype, device_map=DEVICE)
loaded_model_outputs = loaded_model(input_ids)

# Test init method
init_model = self.IPEX_MODEL_CLASS(transformers_model)
init_model_outputs = init_model(input_ids)

# Compare tensor outputs
self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
# To avoid float pointing error
self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_pipeline(self, model_arch):
dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
Expand Down