Skip to content

Commit

Permalink
fix dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Nov 29, 2023
1 parent c8c2fed commit 96356c4
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions optimum/intel/generation/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,20 @@
logger = logging.getLogger(__name__)


def get_float_type(torch_dtype: torch.dtype):
if torch_dtype == torch.bfloat16:
def get_float_type(model_dtype: torch.dtype):
if model_dtype == torch.bfloat16:
return "bf16"
elif torch_dtype == torch.float16:
elif model_dtype == torch.float16:
return "fp16"
elif torch_dtype == torch.float32:
return "fp32"
else:
raise ValueError("torch_dtype should be in bf16, fp16 or fp32")
return "fp32"


def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
task = _TASK_ALIASES.get(task, task)
signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
float_dtype = get_float_type(model.config.torch_dtype)
float_dtype = get_float_type(model.dtype)
if "text-generation" in task:
onnx_config = onnx_config_class(
model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
Expand Down

0 comments on commit 96356c4

Please sign in to comment.