fix: cumsum add_constant bug fix (add dtype for np zeros) #3258

Open
wants to merge 2 commits into
base: main
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -370,7 +370,7 @@ def cumsum(
)
else:
new_dims = tuple(data.shape)
-zeros = np.zeros(new_dims)
+zeros = np.zeros(new_dims, dtype=np.float32)
Collaborator

Should this dtype depend on the input dtype, or should it always be float32?

Collaborator Author

Using np.float32 here works fine regardless of the input type.

However, tracing the root cause of the error shows that this line uses the name truncate_double, while we pass truncate_long_and_double as an argument to torch.compile, as shown here.

As a result, the truncate_long_and_double argument is not handled at this point and is then dropped, which leads to an error when processing float64, the default dtype of np.zeros.

According to this section, torch_tensorrt.dynamo.compile prefers truncate_double as its input but can also handle truncate_long_and_double.

What would be the best way to fix this issue?
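
For reference, the float64 default mentioned above can be checked directly. This is a minimal sketch, assuming only NumPy; `new_dims` here is an arbitrary shape standing in for `tuple(data.shape)` and is not taken from the converter.

```python
import numpy as np

# Hypothetical stand-in for tuple(data.shape) in the converter.
new_dims = (2, 3)

# Without an explicit dtype, np.zeros returns a float64 array.
default_zeros = np.zeros(new_dims)

# With dtype=np.float32, as in this PR, the array is float32.
float32_zeros = np.zeros(new_dims, dtype=np.float32)

print(default_zeros.dtype)
print(float32_zeros.dtype)
```

The first array is the one that depends on double truncation being enabled; the second avoids float64 altogether.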

Collaborator

If you change truncate_long_and_double to truncate_double in this example

"truncate_long_and_double": True,
that should work, right?

Collaborator

truncate_long_and_double is deprecated, but it looks like it is not handled correctly when the user provides this argument in the torch.compile workflow. Can you add this check at

valid_attrs = {attr.name for attr in fields(settings)}
? It would be similar to https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/_compiler.py#L180-L185
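
To illustrate why the deprecated key goes unnoticed, here is a minimal sketch of the same filtering pattern. `ToySettings` and `apply_kwargs` are hypothetical stand-ins, not the real `CompilationSettings` or `parse_dynamo_kwargs`; only the `valid_attrs` filtering mirrors the line quoted above.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict


@dataclass
class ToySettings:
    """Hypothetical stand-in for CompilationSettings."""

    truncate_double: bool = False


def apply_kwargs(settings: ToySettings, kwargs: Dict[str, Any]) -> ToySettings:
    # Keep only the keys that match a dataclass field; others are discarded.
    valid_attrs = {attr.name for attr in fields(settings)}
    valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
    return replace(settings, **valid_kwargs)


# The deprecated key is not a field, so it is filtered out without any error.
result = apply_kwargs(ToySettings(), {"truncate_long_and_double": True})
print(result.truncate_double)
```

Because the unknown key is filtered out rather than rejected, `truncate_double` keeps its default value in this sketch.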

Collaborator Author

@peri044 Using float32 ensures compatibility, even with varying inputs, as it is the most commonly used data type. Additionally, I added exception handling for cases where the deprecated truncate_long_and_double argument might still be used.

zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")

running_sum = loop.add_recurrence(zero_trttensor)
26 changes: 24 additions & 2 deletions py/torch_tensorrt/dynamo/utils.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import logging
+import warnings
 from dataclasses import fields, replace
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -10,6 +11,8 @@
 import tensorrt as trt
 import torch
 from torch._subclasses.fake_tensor import FakeTensor
+
+from packaging import version
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
@@ -19,8 +22,6 @@
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings

-from packaging import version
-
 from .types import TRTDataType

 logger = logging.getLogger(__name__)
@@ -494,6 +495,27 @@ def parse_dynamo_kwargs(
     if "options" in kwargs and len(kwargs) == 1:
         kwargs = kwargs["options"]

+    if "truncate_long_and_double" in kwargs:
+        if (
+            "truncate_double" in kwargs
+            and kwargs["truncate_double"] is not _defaults.TRUNCATE_DOUBLE
+        ):
+            raise ValueError(
+                'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double". '
+                'Please only use "truncate_double".'
+            )
+        else:
+            kwargs["truncate_double"] = kwargs["truncate_long_and_double"]
+            warnings.warn(
+                'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported. '
+                "This option will be removed in the next version.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            del kwargs[
+                "truncate_long_and_double"
+            ]  # Remove deprecated key after handling
+
     valid_attrs = {attr.name for attr in fields(settings)}
     valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
     settings = replace(settings, **valid_kwargs)
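
The behavior of the added block can be tried in isolation. This is a sketch under stated assumptions: `normalize_truncate_kwargs` and `DEFAULT_TRUNCATE_DOUBLE` are hypothetical names standing in for the inline logic and for `_defaults.TRUNCATE_DOUBLE` (assumed here to be `False`). Unlike the diff, it works on a copy of the dictionary.

```python
import warnings
from typing import Any, Dict

# Hypothetical stand-in for _defaults.TRUNCATE_DOUBLE (assumed to be False).
DEFAULT_TRUNCATE_DOUBLE = False


def normalize_truncate_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Map the deprecated key onto truncate_double, mirroring the diff's logic."""
    kwargs = dict(kwargs)  # work on a copy so the caller's dict is untouched
    if "truncate_long_and_double" in kwargs:
        if (
            "truncate_double" in kwargs
            and kwargs["truncate_double"] is not DEFAULT_TRUNCATE_DOUBLE
        ):
            raise ValueError(
                'Provided configuration for "truncate_double" and deprecated '
                'API "truncate_long_and_double". Please only use "truncate_double".'
            )
        # Move the value to the supported key and drop the deprecated one.
        kwargs["truncate_double"] = kwargs.pop("truncate_long_and_double")
        warnings.warn(
            'Compiler option "truncate_long_and_double" is deprecated in favor '
            'of "truncate_double".',
            DeprecationWarning,
            stacklevel=2,
        )
    return kwargs


# Record warnings so the deprecation notice can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    normalized = normalize_truncate_kwargs({"truncate_long_and_double": True})

print(normalized)
print([w.category.__name__ for w in caught])
```

Passing both keys with a non-default `truncate_double` raises `ValueError` in this sketch, matching the conflict check in the diff.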