upgrade minimum torch version to 2.5
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Dec 13, 2024
1 parent 2c54045 commit bce9aa9
Showing 5 changed files with 22 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_ipex.yml
@@ -19,7 +19,7 @@ jobs:
fail-fast: false
matrix:
transformers-version: ["4.46.0", "4.46.3"]
torch-version: ["2.4.0", "2.5.*"]
torch-version: ["2.5.*"]

runs-on: ubuntu-22.04

20 changes: 15 additions & 5 deletions optimum/exporters/ipex/modeling_utils.py
@@ -32,8 +32,7 @@

logger = logging.getLogger(__name__)

-_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
-_IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN = "2.5.0"
+_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.5.0"


if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
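The guard above now keys off a single 2.5.0 floor: when the installed intel-extension-for-pytorch is older, all of the patching below is skipped. A minimal sketch of how such a version gate behaves, assuming packaging-style comparisons (compare_ipex_version is an illustrative stand-in, not the repository's actual helper):

# Illustrative stand-in for a version gate like is_ipex_version.
import operator

from packaging import version

_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.5.0"

def compare_ipex_version(installed: str, op: str, required: str) -> bool:
    # Map the textual operator onto packaging's Version comparisons.
    ops = {"<": operator.lt, "<=": operator.le, "==": operator.eq,
           ">": operator.gt, ">=": operator.ge, "!=": operator.ne}
    return ops[op](version.parse(installed), version.parse(required))

# An ipex 2.4.x install now fails the gate, so no model patching happens:
print(compare_ipex_version("2.4.0", "<", _IPEX_MINIMUM_VERSION_FOR_PATCHING))  # True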
@@ -213,6 +212,8 @@ def _llama_model_forward(
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

+if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
@@ -334,6 +335,8 @@ def _falcon_model_forward(
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

+if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
@@ -463,6 +466,8 @@ def _gpt2_model_forward(
hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
else:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

+if past_key_values is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(input_ids.shape[0], input_ids.shape[-1]),
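The llama, falcon, and gpt2 hunks above all add the same guard: the 4D causal SDPA mask is only prepared on the prefill step, when past_key_values is None; later decode steps read from the cache and skip it. A minimal sketch of the pattern (build_prefill_mask is an illustrative stand-in for the call to transformers' _prepare_4d_causal_attention_mask_for_sdpa):

import torch

def build_prefill_mask(past_key_values, input_ids):
    # Decode steps: keys/values come from the cache, no mask is built.
    if past_key_values is not None:
        return None
    # Prefill: additive causal mask, -inf strictly above the diagonal.
    batch, seq = input_ids.shape
    causal = torch.full((seq, seq), float("-inf")).triu(1)
    return causal.expand(batch, 1, seq, seq)

mask = build_prefill_mask(None, torch.ones(2, 5, dtype=torch.long))
print(mask.shape)  # torch.Size([2, 1, 5, 5])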
@@ -660,11 +665,16 @@ def forward(

if past_len == 0:
# prefill
-if past_key_value is None or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN):
+if past_key_value is None:
+    n_rep = query.shape[1] // key.shape[1]
 attn_output = torch.nn.functional.scaled_dot_product_attention(
     query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-    key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1]).transpose(1, 2),
-    value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1]).transpose(1, 2),
+    key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+    .transpose(1, 2)
+    .repeat_interleave(n_rep, 1),
+    value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+    .transpose(1, 2)
+    .repeat_interleave(n_rep, 1),
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=True,
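In the prefill path above, the cache-less branch now always goes through torch's scaled_dot_product_attention, and the new repeat_interleave(n_rep, 1) calls expand the grouped key/value heads so they match the number of query heads, which SDPA expects unless dedicated GQA support is enabled. A self-contained sketch of that expansion with toy shapes (all sizes here are made up):

import torch

batch, q_heads, kv_heads, seq, head_dim = 2, 8, 2, 4, 16
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# Each grouped KV head serves q_heads // kv_heads query heads; repeat it
# along the head dimension so all three tensors carry q_heads heads.
n_rep = query.shape[1] // key.shape[1]
key = key.repeat_interleave(n_rep, 1)
value = value.repeat_interleave(n_rep, 1)

out = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, is_causal=True
)
print(out.shape)  # torch.Size([2, 8, 4, 16])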
4 changes: 2 additions & 2 deletions optimum/intel/ipex/modeling_base.py
@@ -299,9 +299,9 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
return self.model.prepare_inputs_for_generation(*args, **kwargs)

def generate(self, *args, **kwargs):
if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None):
if self._add_patch and kwargs.get("assistant_model", None):
raise ValueError(
f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
# Patch functions to support ipex_paged cache
if self._add_patch:
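With the version check dropped, generate() now rejects assisted decoding for every patched (paged-cache) model rather than only on older ipex builds. A hedged usage sketch of the behavior the guard produces (the model choices and from_pretrained flow are illustrative assumptions, and an ipex-enabled environment is required):

from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model = IPEXModelForCausalLM.from_pretrained("gpt2")  # assumed entry point
draft = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

try:
    # Passing an assistant model to a patched model now always raises.
    model.generate(**inputs, assistant_model=draft)
except ValueError as err:
    print(err)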
2 changes: 1 addition & 1 deletion setup.py
@@ -66,7 +66,7 @@
"nncf": ["nncf>=2.14.0"],
"openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"],
"neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"],
"ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.45,<4.47"],
"ipex": ["intel-extension-for-pytorch>=2.5", "transformers>4.45,<4.47"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
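In practice, a fresh install of the ipex extra (for example pip install optimum-intel[ipex], assuming the usual distribution name) now resolves to intel-extension-for-pytorch 2.5 or newer, matching the 2.5-only CI matrix at the top of this commit.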
3 changes: 3 additions & 0 deletions tests/ipex/test_modeling.py
@@ -377,6 +377,9 @@ def test_compare_with_and_without_past_key_values(self):
outputs_model_without_pkv = model_without_pkv.generate(
**tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1
)
+import pdb
+
+pdb.set_trace()
self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1])
self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1])
Expand Down
