Skip to content

Commit

Permalink
Handling deprecated truncated_long_and_double
Browse files Browse the repository at this point in the history
  • Loading branch information
Hoonkyung Cho committed Nov 21, 2024
1 parent cc2016a commit 59c05af
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import warnings
from dataclasses import fields, replace
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
Expand All @@ -10,6 +11,8 @@
import tensorrt as trt
import torch
from torch._subclasses.fake_tensor import FakeTensor

from packaging import version
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
Expand All @@ -19,8 +22,6 @@
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings

from packaging import version

from .types import TRTDataType

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -494,6 +495,27 @@ def parse_dynamo_kwargs(
if "options" in kwargs and len(kwargs) == 1:
kwargs = kwargs["options"]

if "truncate_long_and_double" in kwargs:
if (
"truncate_double" in kwargs
and kwargs["truncate_double"] is not _defaults.TRUNCATE_DOUBLE
):
raise ValueError(
'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double". '
'Please only use "truncate_double".'
)
else:
kwargs["truncate_double"] = kwargs["truncate_long_and_double"]
warnings.warn(
'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported. '
"This option will be removed in the next version.",
DeprecationWarning,
stacklevel=2,
)
del kwargs[
"truncate_long_and_double"
] # Remove deprecated key after handling

valid_attrs = {attr.name for attr in fields(settings)}
valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
settings = replace(settings, **valid_kwargs)
Expand Down

0 comments on commit 59c05af

Please sign in to comment.