Skip to content

Commit

Permalink
feat: support aten._local_scalar_dense converter
Browse files Browse the repository at this point in the history
  • Loading branch information
chohk88 committed Apr 16, 2024
1 parent b76024d commit d83081e
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 0 deletions.
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,6 +1551,23 @@ def aten_ops_isnan(
)


@dynamo_tensorrt_converter(torch.ops.aten._local_scalar_dense.default)
def aten_ops_local_scalar_dense(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter entry point for ``aten._local_scalar_dense.default``.

    Forwards the single tensor operand to the unary implementation,
    which materializes the value as a one-element TRT tensor.
    """
    return impl.unary.local_scalar_dense(
        ctx, target, SourceIR.ATEN, name, input=args[0]
    )


@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
def aten_ops_add(
Expand Down
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,3 +554,26 @@ def isnan(
)

return nan_values_mask


def local_scalar_dense(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
) -> TRTTensor:
    """Return the element at index ``[0, ..., 0]`` of *input* as a shape-``[1]`` tensor.

    Implements ``aten._local_scalar_dense`` for TensorRT: a slice layer picks a
    single element from the origin of every dimension, then a shuffle layer
    flattens the rank-preserving slice result down to one dimension.
    """
    rank = len(input.shape)

    # Take one element per dimension, starting at the origin, step 1.
    slice_layer = ctx.net.add_slice(
        input=input,
        start=[0] * rank,
        shape=[1] * rank,
        stride=[1] * rank,
    )
    set_layer_name(slice_layer, target, f"{name}_slice", source_ir)

    # Collapse the all-ones shape produced by the slice into a single-element tensor.
    shuffle_layer = ctx.net.add_shuffle(slice_layer.get_output(0))
    shuffle_layer.reshape_dims = [1]
    set_layer_name(shuffle_layer, target, f"{name}_reshape", source_ir)

    return shuffle_layer.get_output(0)
65 changes: 65 additions & 0 deletions tests/py/dynamo/conversion/test_local_scalar_dense_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from harness import DispatchTestCase
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests


class TestLocalScalarDenseConverter(DispatchTestCase):
    """Conversion tests for the aten._local_scalar_dense TRT converter."""

    @parameterized.expand(
        [
            # Multi-dimensional float and int inputs.
            (torch.randn((5, 10, 5), dtype=torch.float32),),
            (torch.randint(-10, 10, (5, 1, 15), dtype=torch.int32),),
            # Single-element and small 1-D inputs across dtypes.
            (torch.randn((1), dtype=torch.float32),),
            (torch.tensor([-2.4]),),
            (torch.tensor([5.5, 3.5, 3.6]),),
            (torch.tensor([True]),),
            # Non-finite values (nan / +-inf) in the first position.
            (torch.tensor([float("nan"), 1.23, float("inf")]),),
            (torch.tensor([float("-inf"), 1.23, float("nan")]),),
            (torch.tensor([float("inf")]),),
        ]
    )
    def test_local_scalar_dense(self, data):
        # Minimal module wrapping the op under test so run_test can trace it.
        class LocalScalarDense(nn.Module):
            def forward(self, input):
                return torch.ops.aten._local_scalar_dense.default(input)

        self.run_test(LocalScalarDense(), [data])


# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    run_tests()

0 comments on commit d83081e

Please sign in to comment.