Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ipex Page attn xpu support bug fix #1053

Merged
merged 8 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions optimum/exporters/ipex/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,13 +744,13 @@ def __init__(self, module, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = next(module.parameters()).device.type
if self.module_device == "cpu":
self.module_device = next(module.parameters()).device
if self.module_device.type == "cpu":
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = LinearAdd(module.down_proj)
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
elif self.module_device == "xpu":
elif self.module_device.type == "xpu":
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = XPULinearAdd(module.down_proj)
Expand All @@ -777,15 +777,15 @@ def __init__(self, module, config) -> None:
_setattr_from_module(self, module)
self.config = config
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
self.module_device = next(module.parameters()).device.type
if self.module_device == "cpu":
self.module_device = next(module.parameters()).device
if self.module_device.type == "cpu":
self.linear_gelu = LinearGelu(module.dense_h_to_4h)
elif self.module_device == "xpu":
elif self.module_device.type == "xpu":
self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
if self.module_device == "cpu":
if self.module_device.type == "cpu":
self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
elif self.module_device == "xpu":
elif self.module_device.type == "xpu":
self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)

def forward(
Expand Down Expand Up @@ -870,7 +870,11 @@ class _IPEXIntermediate(nn.Module):
def __init__(self, module, config):
super().__init__()
_setattr_from_module(self, module)
self.linear_gelu = LinearGelu(module.dense)
self.module_device = next(module.parameters()).device
if self.module_device.type == "cpu":
self.linear_gelu = LinearGelu(module.dense)
elif self.module_device.type == "xpu":
self.linear_gelu = XPULinearGelu(module.dense)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_gelu(hidden_states)
Expand Down
9 changes: 7 additions & 2 deletions optimum/intel/pipelines/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,17 @@ def load_ipex_model(
SUPPORTED_TASKS,
hub_kwargs: Optional[Dict[str, Any]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
device_map: Optional[torch.device] = None,
):
hub_kwargs = hub_kwargs or {}
model_kwargs = model_kwargs or {}
ipex_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]

if model is None:
model_id = SUPPORTED_TASKS[targeted_task]["default"]
model = ipex_model_class.from_pretrained(model_id, export=True, **hub_kwargs, **model_kwargs)
model = ipex_model_class.from_pretrained(
model_id, export=True, **hub_kwargs, **model_kwargs, device_map=device_map
)
elif isinstance(model, str):
model_id = model
try:
Expand All @@ -262,7 +265,9 @@ def load_ipex_model(
except RuntimeError:
logger.warning("We will use IPEXModel with export=True to export the model")
export = True
model = ipex_model_class.from_pretrained(model, export=export, **hub_kwargs, **model_kwargs)
model = ipex_model_class.from_pretrained(
model, export=export, **hub_kwargs, **model_kwargs, device_map=device_map
)
elif isinstance(model, IPEXModel):
model_id = getattr(model.config, "name_or_path", None)
else:
Expand Down
Loading
Loading