diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py index 40b6746465..fd946ea607 100644 --- a/optimum/intel/generation/modeling.py +++ b/optimum/intel/generation/modeling.py @@ -44,22 +44,20 @@ logger = logging.getLogger(__name__) -def get_float_type(torch_dtype: torch.dtype): - if torch_dtype == torch.bfloat16: +def get_float_type(model_dtype: torch.dtype): + if model_dtype == torch.bfloat16: return "bf16" - elif torch_dtype == torch.float16: + elif model_dtype == torch.float16: return "fp16" - elif torch_dtype == torch.float32: - return "fp32" else: - raise ValueError("torch_dtype should be in bf16, fp16 or fp32") + return "fp32" def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False): task = _TASK_ALIASES.get(task, task) signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__) onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) - float_dtype = get_float_type(model.config.torch_dtype) + float_dtype = get_float_type(model.dtype) if "text-generation" in task: onnx_config = onnx_config_class( model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype