diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 9acb750aed..ec3affd303 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1551,6 +1551,23 @@ def aten_ops_isnan( ) +@dynamo_tensorrt_converter(torch.ops.aten._local_scalar_dense.default) +def aten_ops_local_scalar_dense( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.local_scalar_dense( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor) @dynamo_tensorrt_converter(torch.ops.aten.add.Scalar) def aten_ops_add( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 9f2ad07612..234bd8cddd 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -554,3 +554,26 @@ def isnan( ) return nan_values_mask + + +def local_scalar_dense( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + start = [0] * len(input.shape) + shape = [1] * len(input.shape) # Get one element from each dimension + stride = [1] * len(input.shape) # Step through each dimension by 1 + + layer = ctx.net.add_slice(input=input, start=start, shape=shape, stride=stride) + set_layer_name(layer, target, f"{name}_slice", source_ir) + + reshape_layer = ctx.net.add_shuffle(layer.get_output(0)) + reshape_layer.reshape_dims = [ + 1, + ] # Reshape to a single-element tensor + set_layer_name(reshape_layer, target, f"{name}_reshape", source_ir) + + return reshape_layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py b/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py new file mode 100644 index 0000000000..4d22632789 --- /dev/null +++ b/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from harness import DispatchTestCase +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + + +class TestLocalScalarDenseConverter(DispatchTestCase): + @parameterized.expand( + [ + ( + torch.randn((5, 10, 5), dtype=torch.float32), + ), + ( + torch.randint(-10, 10, (5, 1, 15), dtype=torch.int32), + ), + ( + torch.randn((1), dtype=torch.float32), + ), + ( + (torch.tensor([-2.4])), + ), + ( + (torch.tensor([5.5, 3.5, 3.6])), + ), + ( + (torch.tensor([True])), + ), + ( + torch.tensor( + [ + float("nan"), + 1.23, + float("inf"), + ] + ), + ), + ( + torch.tensor( + [ + float("-inf"), + 1.23, + float("nan"), + ] + ), + ), + ( + (torch.tensor([float("inf")])), + ), + ] + ) + def test_local_scalar_dense(self, data): + class local_scalar_dense(nn.Module): + def forward(self, input): + return torch.ops.aten._local_scalar_dense.default(input) + + inputs = [data] + self.run_test( + local_scalar_dense(), + inputs, + ) + + +if __name__ == "__main__": + run_tests()