diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index bbfc3db63d..fd946ea607 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -44,12 +44,24 @@
 logger = logging.getLogger(__name__)
 
 
+def get_float_type(model_dtype: torch.dtype):
+    if model_dtype == torch.bfloat16:
+        return "bf16"
+    elif model_dtype == torch.float16:
+        return "fp16"
+    else:
+        return "fp32"
+
+
 def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
     task = _TASK_ALIASES.get(task, task)
     signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
     onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
+    float_dtype = get_float_type(model.dtype)
     if "text-generation" in task:
-        onnx_config = onnx_config_class(model.config, use_past=use_cache, use_past_in_inputs=use_cache)
+        onnx_config = onnx_config_class(
+            model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
+        )
     else:
         onnx_config = onnx_config_class(model.config)
 