From 3da80f6a2eb670be3dde8ef541291e8f35ce6655 Mon Sep 17 00:00:00 2001
From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com>
Date: Wed, 6 Dec 2023 18:15:34 +0800
Subject: [PATCH] Fix pkv dtype (#481)

* pkv dtype

* fix dtype
---
 optimum/intel/generation/modeling.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index bbfc3db63d..fd946ea607 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -44,12 +44,24 @@ logger = logging.getLogger(__name__)
 
 
 
+def get_float_type(model_dtype: torch.dtype):
+    if model_dtype == torch.bfloat16:
+        return "bf16"
+    elif model_dtype == torch.float16:
+        return "fp16"
+    else:
+        return "fp32"
+
+
 def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
     task = _TASK_ALIASES.get(task, task)
     signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
     onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
+    float_dtype = get_float_type(model.dtype)
     if "text-generation" in task:
-        onnx_config = onnx_config_class(model.config, use_past=use_cache, use_past_in_inputs=use_cache)
+        onnx_config = onnx_config_class(
+            model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
+        )
     else:
         onnx_config = onnx_config_class(model.config)
 