Commit

disable use_cache when just running forward
jiqing-feng committed Nov 26, 2024
1 parent 26c8ff4 commit af369ed
Showing 2 changed files with 9 additions and 13 deletions.
9 changes: 1 addition & 8 deletions optimum/intel/ipex/modeling_base.py
@@ -308,16 +308,9 @@ def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
-        if self.use_cache and past_key_values is None and self._add_patch:
-            max_length = self.config.max_length + input_ids.shape[1]
-            batch_size = input_ids.shape[0]
-            past_key_values = IPEXPagedCache(self.config, batch_size, max_length, input_ids.device, dtype=self.dtype)
-            return self.model(
-                input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **kwargs
-            )
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
 
     def _prepare_generation_config(
         self, generation_config: Optional[GenerationConfig], **kwargs: Dict
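With this change, forward() no longer allocates an IPEXPagedCache itself; a caller that only needs logits can load the model with use_cache=False and run a plain forward pass, presumably leaving any cache setup to the generation path. A minimal sketch of that forward-only usage, not part of the commit (the checkpoint name and prompt are placeholders; the from_pretrained arguments mirror those used in the updated test below):

import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # hypothetical checkpoint, chosen only for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load without a KV cache, as the updated tests below do.
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False)

tokens = tokenizer("This is a sample", return_tensors="pt")
with torch.no_grad():
    # forward() now goes straight to self.model without building an IPEXPagedCache first.
    outputs = model(**tokens)
print(outputs.logits.shape)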
13 changes: 8 additions & 5 deletions tests/ipex/test_modeling.py
@@ -223,10 +223,11 @@ def test_compare_to_transformers(self, model_arch):
         dtype = torch.float32
         if IS_XPU:
             dtype = torch.float16
-        ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=dtype)
+        # Test model forward does not need a cache.
+        ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=dtype, use_cache=False)
         device = ipex_model.device
         self.assertIsInstance(ipex_model.config, PretrainedConfig)
-        self.assertTrue(ipex_model.use_cache)
+        self.assertFalse(ipex_model.use_cache)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokens = tokenizer(
             "This is a sample",
@@ -238,18 +239,20 @@
 
         self.assertIsInstance(outputs.logits, torch.Tensor)
 
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, use_cache=False).to(
+            device
+        )
         with torch.no_grad():
             transformers_outputs = transformers_model(**tokens)
 
         # Test re-load model
         with tempfile.TemporaryDirectory() as tmpdirname:
             ipex_model.save_pretrained(tmpdirname)
-            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype)
+            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype, use_cache=False)
             loaded_model_outputs = loaded_model(**inputs)
 
         # Test init method
-        init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True)
+        init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True, use_cache=False)
         init_model_outputs = init_model(**inputs)
 
         # Compare tensor outputs
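The hunk above is truncated before the "# Compare tensor outputs" block. A rough sketch of the kind of elementwise logits comparison that block performs; the helper name and tolerance here are assumptions for illustration, not the test's actual code:

import torch

def logits_match(a: torch.Tensor, b: torch.Tensor, atol: float = 1e-4) -> bool:
    # Compare two logits tensors elementwise within an absolute tolerance.
    return torch.allclose(a.float(), b.float(), atol=atol)

# Hypothetical usage with the tensors produced earlier in the test:
# logits_match(outputs.logits, transformers_outputs.logits)
# logits_match(loaded_model_outputs.logits, transformers_outputs.logits)
# logits_match(init_model_outputs.logits, transformers_outputs.logits)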
