diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 85aa663809..95e5f30e4d 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import warnings from dataclasses import fields, replace from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -10,6 +11,8 @@ import tensorrt as trt import torch from torch._subclasses.fake_tensor import FakeTensor + +from packaging import version from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -19,8 +22,6 @@ from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings -from packaging import version - from .types import TRTDataType logger = logging.getLogger(__name__) @@ -494,6 +495,27 @@ def parse_dynamo_kwargs( if "options" in kwargs and len(kwargs) == 1: kwargs = kwargs["options"] + if "truncate_long_and_double" in kwargs: + if ( + "truncate_double" in kwargs + and kwargs["truncate_double"] is not _defaults.TRUNCATE_DOUBLE + ): + raise ValueError( + 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double". ' + 'Please only use "truncate_double".' + ) + else: + kwargs["truncate_double"] = kwargs["truncate_long_and_double"] + warnings.warn( + 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported. ' + "This option will be removed in the next version.", + DeprecationWarning, + stacklevel=2, + ) + del kwargs[ + "truncate_long_and_double" + ] # Remove deprecated key after handling + valid_attrs = {attr.name for attr in fields(settings)} valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs} settings = replace(settings, **valid_kwargs)