Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
catch prepack error and fallback to torch bf16 (#1526)
Browse files Browse the repository at this point in the history
Co-authored-by: VincyZhang <[email protected]>
  • Loading branch information
Spycsh and VincyZhang authored Jun 7, 2024
1 parent ba199dc commit 3f492c4
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
28 changes: 21 additions & 7 deletions intel_extension_for_transformers/neural_chat/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,13 +839,27 @@ def load_model(
import intel_extension_for_pytorch as intel_ipex

if not use_tpp:
model = intel_ipex.optimize(
model.eval(),
dtype=torch_dtype,
inplace=True,
level="O1",
auto_kernel_selection=True,
)
try:
model = intel_ipex.optimize(
model.eval(),
dtype=torch_dtype,
inplace=True,
level="O1",
auto_kernel_selection=True,
)
except AssertionError:
model = intel_ipex.optimize(
model.eval(),
dtype=torch_dtype,
inplace=True,
level="O1",
auto_kernel_selection=True,
weights_prepack=False,
)
except Exception as e:
logging.info(f"IPEX optimize failure! Skip IPEX.")
model = model.eval()

if cpu_jit and (re.search("mpt-7b", model_name, re.IGNORECASE)
or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)):
from intel_extension_for_transformers.transformers.llm.utils.mpt_trace import \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,15 @@ def __init__(self,
import torch
import intel_extension_for_pytorch as ipex
if precision == "bf16" and CpuInfo().bf16:
self.embeddings.client = ipex.optimize(
self.embeddings.client.eval(), dtype=torch.bfloat16, inplace=True)
try:
self.embeddings.client = ipex.optimize(
self.embeddings.client.eval(), dtype=torch.bfloat16, inplace=True)
except AssertionError:
self.embeddings.client = ipex.optimize(
self.embeddings.client.eval(), dtype=torch.bfloat16, inplace=True, weights_prepack=False)
except Exception as e:
logging.info(f"IPEX optimize failure! Skip IPEX.")
self.embeddings.client = self.embeddings.client.eval()
elif precision == "fp32":
self.embeddings.client = ipex.optimize(
self.embeddings.client.eval(), dtype=torch.float32, inplace=True)
Expand Down

0 comments on commit 3f492c4

Please sign in to comment.