
Commit

add _load_model in pipeline
jiqing-feng committed Jan 9, 2024
1 parent 8394d41 commit 39b7804
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions optimum/intel/pipelines/__init__.py
@@ -141,6 +141,11 @@ def clean_custom_task(task_info):
     return task_info, None
 
 
+def _load_model(task, model, **kwargs):
+    if task == "text-generation":
+        return TSModelForCausalLM.from_pretrained(model, **kwargs)
+
+
 def pipeline(
     task: str = None,
     model: Optional[Union[str, "PreTrainedModel"]] = None,
@@ -262,8 +267,8 @@ def pipeline(
             "Please provide a task class or a model"
         )
 
-    if task != "text-generation":
-        raise ValueError("Optimum-intel ipex optimization only supports text-generation task for now.")
+    if task not in SUPPORTED_TASKS.keys():
+        raise ValueError(f"Optimum-intel ipex optimization only supports {SUPPORTED_TASKS.keys()} task for now.")
 
     if model is None and tokenizer is not None:
         raise RuntimeError(
@@ -358,7 +363,7 @@ def pipeline(
     # Load the correct model if possible
     # Infer the framework from the model if not already defined
     if isinstance(model, str):
-        model = TSModelForCausalLM.from_pretrained(model, config=config, export=True, **model_kwargs)
+        model = _load_model(task, model, config=config, export=True, **model_kwargs)
 
     model_config = model.config
     hub_kwargs["_commit_hash"] = model.config._commit_hash
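
For context, a minimal usage sketch of the code path this commit changes. The import mirrors the file being edited (optimum/intel/pipelines/__init__.py); the model id "gpt2", the prompt, and the generation kwargs are illustrative assumptions, not part of the diff:

    # Hypothetical usage sketch: pipeline() validates the task against SUPPORTED_TASKS,
    # then, because `model` is given as a string, delegates loading to the new
    # _load_model(), which calls TSModelForCausalLM.from_pretrained(model, export=True, ...).
    from optimum.intel.pipelines import pipeline

    pipe = pipeline("text-generation", "gpt2")
    print(pipe("Hello, my name is", max_new_tokens=16))

Because the model is passed as a string, execution reaches the isinstance(model, str) branch shown above, which is exactly the line the commit rewires to go through _load_model().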
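
One plausible reading of the design is that _load_model() becomes the single place where a task is mapped to a TorchScript model class, so that future additions to SUPPORTED_TASKS only need to extend that mapping. A hedged sketch of the same dispatch written as a lookup table (illustration only, not code from the repository; behaviour-equivalent to the committed helper for "text-generation"):

    # Illustrative sketch only: the committed _load_model() uses an if-branch,
    # but the same dispatch can be expressed as a table. TSModelForCausalLM is
    # the class this module already uses; other entries are placeholders.
    _TASK_TO_MODEL_CLASS = {
        "text-generation": TSModelForCausalLM,
        # future SUPPORTED_TASKS entries would map to their model classes here
    }

    def _load_model(task, model, **kwargs):
        model_class = _TASK_TO_MODEL_CLASS.get(task)
        if model_class is None:
            raise ValueError(f"Task {task} is not supported by the ipex pipeline yet.")
        return model_class.from_pretrained(model, **kwargs)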
