Skip to content

Commit

Permalink
update ipex
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Dec 11, 2024
1 parent d1bba1f commit 71f1100
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 23 deletions.
26 changes: 7 additions & 19 deletions optimum_benchmark/backends/ipex/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,41 +45,29 @@ def load(self) -> None:

self.tmpdir.cleanup()

def _load_automodel_from_pretrained(self) -> None:
self.pretrained_model = self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs)

def _load_automodel_with_no_weights(self) -> None:
original_model, self.config.model = self.config.model, self.no_weights_model

with fast_weights_init():
self._load_automodel_from_pretrained()

self.logger.info("\t+ Tying model weights")
self.pretrained_model.tie_weights()

self.config.model = original_model

def _load_ipexmodel_from_pretrained(self) -> None:
self.pretrained_model = self.ipexmodel_class.from_pretrained(
self.config.model,
export=self.config.export,
**self.config.model_kwargs,
**self.automodel_kwargs,
**self.ipexmodel_kwargs,
)

def _load_ipexmodel_with_no_weights(self) -> None:
with fast_weights_init():
self.logger.info("\t+ Loading no weights IPEXModel")
original_model, self.config.model = self.config.model, self.no_weights_model
original_export, self.config.export = self.config.export, True
self.logger.info("\t+ Loading no weights IPEXModel")
self._load_ipexmodel_from_pretrained()
self.config.export = original_export
self.config.model = original_model

@property
def automodel_kwargs(self) -> Dict[str, Any]:
def ipexmodel_kwargs(self) -> Dict[str, Any]:
kwargs = {}

if self.config.export:
kwargs["export"] = self.config.export

if self.config.torch_dtype is not None:
kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)

Expand All @@ -89,7 +77,7 @@ def automodel_kwargs(self) -> Dict[str, Any]:
def split_between_processes(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs
Expand Down
8 changes: 4 additions & 4 deletions optimum_benchmark/backends/ipex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ class IPEXConfig(BackendConfig):
version: Optional[str] = ipex_version()
_target_: str = "optimum_benchmark.backends.ipex.backend.IPEXBackend"

# load options
no_weights: bool = False
torch_dtype: Optional[str] = None

# export options
export: bool = True
# ipexmodel kwargs
export: Optional[bool] = None
torch_dtype: Optional[str] = None

def __post_init__(self):
super().__post_init__()

self.device = self.device.lower()

if self.device not in ["cpu", "gpu"]:
raise ValueError(f"IPEXBackend only supports CPU devices, got {self.device}")

Expand Down

0 comments on commit 71f1100

Please sign in to comment.