This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Improve SQ model restored from json (#1600)
Signed-off-by: Wang, Chang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
changwangss and pre-commit-ci[bot] authored Jun 11, 2024
1 parent 14734de commit 44a24ec
Showing 1 changed file with 6 additions and 0 deletions.
@@ -462,6 +462,8 @@ def recover_model_from_json(fp32_model_name_or_path, json_file_path, trust_remot
        (object): quantized model
    """
    from transformers import AutoModelForCausalLM

    # The IPEX int8 model recovered from configure.json requires a float32 model
    # as input and runs on the CPU device.
    user_model = AutoModelForCausalLM.from_pretrained(fp32_model_name_or_path,
                                                      trust_remote_code=trust_remote_code).float()
    if user_model.config.model_type in IPEX_OPT_LLM_SUPPORTED:
@@ -496,5 +498,9 @@ def recover_model_from_json(fp32_model_name_or_path, json_file_path, trust_remot
    from intel_extension_for_transformers.transformers.llm.evaluation.models import (
        TSModelCausalLMForITREX,
    )
    # Temporarily report the model as "qwen2" so the TorchScript wrapper accepts
    # architectures it does not recognize directly, then restore the original
    # model_type on the wrapped model.
    origin_model_type = config.model_type
    if origin_model_type in ["chatglm", "qwen", "baichuan"]:
        config.model_type = "qwen2"
    user_model = TSModelCausalLMForITREX(user_model, config=config)
    user_model.config.model_type = origin_model_type
    return user_model
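The save-and-restore of `config.model_type` in this patch can be isolated as a small pattern. A minimal sketch, using a hypothetical `DummyConfig` class as a stand-in for the Transformers config object and a hypothetical `wrap_with_type_override` helper (the real code instead constructs `TSModelCausalLMForITREX` while the override is in effect):

```python
class DummyConfig:
    """Stand-in for a transformers PretrainedConfig (hypothetical)."""
    def __init__(self, model_type):
        self.model_type = model_type


def wrap_with_type_override(config, supported=("chatglm", "qwen", "baichuan")):
    """Temporarily report 'qwen2' for model types the wrapper does not
    recognize, then restore the original model_type afterwards.

    Returns (model_type seen during wrapping, model_type after restore).
    """
    origin_model_type = config.model_type
    if origin_model_type in supported:
        config.model_type = "qwen2"
    # ... in the real patch, TSModelCausalLMForITREX(user_model, config=config)
    # would be built here, while config.model_type reads "qwen2" ...
    wrapped_type = config.model_type
    config.model_type = origin_model_type
    return wrapped_type, config.model_type
```

The key point is that the override is scoped: the wrapper is constructed under the substitute type, but the config the caller keeps reports the true architecture again afterwards.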
