
fix: cumsum add_constant bug fix (add dtype for np zeros) #3258

Open. Wants to merge 2 commits into base: main.
Conversation

chohk88 (Collaborator) commented Oct 22, 2024

Description

When compiling the roberta-base model from Hugging Face (https://huggingface.co/FacebookAI/roberta-base), a TypeError occurs in the cumsum operation. For static-shape inputs, np.zeros(new_dims) defaults to dtype np.float64, which the create_constant utility function does not handle properly.

Fixes # (issue)

  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 934, in aten_ops_cumsum
    return impl.slice.cumsum(
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 387, in cumsum
    zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/converter_utils.py", line 388, in get_trt_tensor
    return create_constant(ctx, input_val, name, dtype, min_rank)
  File "/usr/local/lib/python3.10/dist-packages/torch_tensorrt/dynamo/conversion/converter_utils.py", line 349, in create_constant
    constant = ctx.net.add_constant(
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt' raised:
TypeError: add_constant(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt.tensorrt.INetworkDefinition, shape: tensorrt.tensorrt.Dims, weights: tensorrt.tensorrt.Weights) -> tensorrt.tensorrt.IConstantLayer
Invoked with: <tensorrt.tensorrt.INetworkDefinition object at 0x7fee84ebd770>, (1,), array([0.])
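The failure is easy to see in isolation: NumPy's zeros default to float64, the dtype that trips up create_constant / add_constant in the traceback above. A minimal sketch (plain NumPy, no TensorRT needed):

```python
import numpy as np

# np.zeros without an explicit dtype produces float64, which TensorRT's
# add_constant rejects when building the Weights object.
zeros_default = np.zeros((1,))
zeros_fixed = np.zeros((1,), dtype=np.float32)  # the change in this PR

print(zeros_default.dtype)  # float64
print(zeros_fixed.dtype)    # float32
```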

Reproduction Code:

# https://huggingface.co/FacebookAI/roberta-base
import torch
from transformers import RobertaTokenizer, RobertaModel
import torch_tensorrt

backend = "torch_tensorrt"
device = "cuda:0"

# Load tokenizer and model
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
model = model.to(device)

# Tokenize input text
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
encoded_input = {k: v.to(device) for k, v in encoded_input.items()} 

# Compile model with Torch-TensorRT
model = torch.compile(
    model,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16},
    },
    dynamic=False,
)

# Run inference
output = model(**encoded_input)
print(output)

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Oct 22, 2024
@chohk88 chohk88 self-assigned this Oct 22, 2024
@chohk88 chohk88 linked an issue Oct 22, 2024 that may be closed by this pull request
@@ -370,7 +370,7 @@ def cumsum(
         )
     else:
         new_dims = tuple(data.shape)
-        zeros = np.zeros(new_dims)
+        zeros = np.zeros(new_dims, dtype=np.float32)
Collaborator:

Should this dtype be dependent on the input dtype, or always float32?
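For comparison, a hedged sketch of the dtype-dependent alternative the question raises (the helper name and the dtype mapping are assumptions, not the PR's code):

```python
import numpy as np

def zeros_for_trt(shape, input_dtype):
    # Hypothetical helper: mirror the input dtype, but narrow the 64-bit
    # types that TensorRT weights cannot take to their 32-bit counterparts.
    narrowed = {np.float64: np.float32, np.int64: np.int32}.get(input_dtype, input_dtype)
    return np.zeros(shape, dtype=narrowed)
```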

Collaborator Author:

Using np.float32 works fine regardless of the input type.

However, tracing the root cause: in this line the option is named truncate_double, but we pass truncate_long_and_double to torch.compile, as shown here. Because of this, the truncate_long_and_double argument is not recognized and is silently dropped, which leads to the error when create_constant tries to process the np.float64 default of np.zeros.

According to this section, torch_tensorrt.dynamo.compile prefers truncate_double but can also handle truncate_long_and_double.

What would be the best way to fix this issue?

Collaborator:

Change truncate_long_and_double to truncate_double in this example:

"truncate_long_and_double": True,

and that should work, right?

Collaborator:


truncate_long_and_double is deprecated, but it looks like it is not correctly handled if the user provides this argument in the torch.compile workflow. Can you add this check at

valid_attrs = {attr.name for attr in fields(settings)}

similar to https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/_compiler.py#L180-L185?
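A minimal sketch of that suggestion (the Settings dataclass stands in for the real CompilationSettings, and the function name is hypothetical): remap the deprecated key to its new name before unknown kwargs are filtered out, instead of silently dropping it:

```python
from dataclasses import dataclass, fields
import warnings

@dataclass
class Settings:  # stand-in for torch_tensorrt's CompilationSettings
    truncate_double: bool = False

def normalize_kwargs(kwargs, settings_cls=Settings):
    kwargs = dict(kwargs)
    if "truncate_long_and_double" in kwargs:
        # Warn, then carry the deprecated value over to the new key
        # (unless the caller already set truncate_double explicitly).
        warnings.warn("truncate_long_and_double is deprecated; use truncate_double")
        kwargs.setdefault("truncate_double", kwargs.pop("truncate_long_and_double"))
    # Drop any keys that are not fields of the settings dataclass.
    valid_attrs = {f.name for f in fields(settings_cls)}
    return {k: v for k, v in kwargs.items() if k in valid_attrs}
```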

Collaborator Author:


@peri044 Using float32 ensures compatibility even with varying inputs, as it is the most commonly used data type. Additionally, I added exception handling for cases where the deprecated truncate_long_and_double argument might still be used.

Successfully merging this pull request may close these issues.

[Coverage] Type Error for torch.ops.aten.cumsum.default
3 participants