From 7c69c3147459d3f8141cdf6cb7070c1d328cc675 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 13 Jun 2024 08:27:10 +0900 Subject: [PATCH 1/3] feat: support aten._local_scalar_dense converter --- .../dynamo/conversion/aten_ops_converters.py | 17 +++++ .../dynamo/conversion/impl/unary/ops.py | 23 +++++++ .../test_local_scalar_dense_aten.py | 65 +++++++++++++++++++ 3 files changed, 105 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_local_scalar_dense_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 14c25ec8ab..066d0b2588 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1714,6 +1714,23 @@ def aten_ops_isnan( ) +@dynamo_tensorrt_converter(torch.ops.aten._local_scalar_dense.default) +def aten_ops_local_scalar_dense( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.local_scalar_dense( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(operator.add, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.add.Scalar, supports_dynamic_shapes=True) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index beb13fca9b..dbdbb332c5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -571,3 +571,26 @@ def isnan( ) return nan_values_mask + + +def local_scalar_dense( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + start = [0] * len(input.shape) + shape = [1] * len(input.shape) # Get one element from each dimension + stride = [1] * len(input.shape) # Step through each dimension by 1 + + layer = ctx.net.add_slice(input=input, start=start, shape=shape, stride=stride) + set_layer_name(layer, target, f"{name}_slice", source_ir) + + reshape_layer = ctx.net.add_shuffle(layer.get_output(0)) + reshape_layer.reshape_dims = [ + 1, + ] # Reshape to a single-element tensor + set_layer_name(reshape_layer, target, f"{name}_reshape", source_ir) + + return reshape_layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py b/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py new file mode 100644 index 0000000000..4d22632789 --- /dev/null +++ b/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from harness import DispatchTestCase +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + + +class TestLocalScalarDenseConverter(DispatchTestCase): + @parameterized.expand( + [ + ( + torch.randn((5, 10, 5), dtype=torch.float32), + ), + ( + torch.randint(-10, 10, (5, 1, 15), dtype=torch.int32), + ), + ( + torch.randn((1), dtype=torch.float32), + ), + ( + (torch.tensor([-2.4])), + ), + ( + (torch.tensor([5.5, 3.5, 3.6])), + ), + ( + (torch.tensor([True])), + ), + ( + torch.tensor( + [ + float("nan"), + 1.23, + float("inf"), + ] + ), + ), + ( + torch.tensor( + [ + float("-inf"), + 1.23, + float("nan"), + ] + ), + ), + ( + (torch.tensor([float("inf")])), + ), + ] + ) + def test_local_scalar_dense(self, data): + class local_scalar_dense(nn.Module): + def forward(self, input): + return torch.ops.aten._local_scalar_dense.default(input) + + inputs = [data] + self.run_test( + local_scalar_dense(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From 8831462f3bc964c0daa4425dcbcc0635d1470959 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 13 Jun 2024 08:27:10 +0900 Subject: [PATCH 2/3] chore: small linting --- .../test_local_scalar_dense_aten.py | 28 +++++-------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py b/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py index 4d22632789..3576c5c9fa 100644 --- a/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py +++ b/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py @@ -8,24 +8,12 @@ class TestLocalScalarDenseConverter(DispatchTestCase): @parameterized.expand( [ - ( - torch.randn((5, 10, 5), dtype=torch.float32), - ), - ( - torch.randint(-10, 10, (5, 1, 15), dtype=torch.int32), - ), - ( - torch.randn((1), dtype=torch.float32), - ), - ( - (torch.tensor([-2.4])), - ), - ( - (torch.tensor([5.5, 3.5, 3.6])), - ), - ( - (torch.tensor([True])), - ), + (torch.randn((5, 10, 5), dtype=torch.float32),), + (torch.randint(-10, 10, (5, 1, 15), dtype=torch.int32),), + (torch.randn((1), dtype=torch.float32),), + ((torch.tensor([-2.4])),), + ((torch.tensor([5.5, 3.5, 3.6])),), + ((torch.tensor([True])),), ( torch.tensor( [ @@ -44,9 +32,7 @@ class TestLocalScalarDenseConverter(DispatchTestCase): ] ), ), - ( - (torch.tensor([float("inf")])), - ), + ((torch.tensor([float("inf")])),), ] ) def test_local_scalar_dense(self, data): From 040ecd8f43f25bfd61db9f1c3393a2f5e41426f9 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 13 Jun 2024 13:02:23 +0900 Subject: [PATCH 3/3] chore: a minor linting issue --- tests/py/dynamo/conversion/test_local_scalar_dense_aten.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py b/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py index 3576c5c9fa..7817fc0ab7 100644 --- a/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py +++ b/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py @@ -1,9 +1,10 @@ import torch import torch.nn as nn -from harness import DispatchTestCase from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from .harness import DispatchTestCase + class TestLocalScalarDenseConverter(DispatchTestCase): @parameterized.expand(