From b275ef6de354d9a740d84687f14328965dea2fcf Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 6 Aug 2024 15:37:17 -0700 Subject: [PATCH 1/8] chore: bug fix --- .../dynamo/lowering/passes/replace_max_pool_with_indices.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py index 6e3762e73c..81fdff8060 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py @@ -44,6 +44,8 @@ def replace_max_pool_with_indices( kwargs=node.kwargs, ) maxpool_fused.meta = node.meta + # The metadata for this node should exclude the indices metadata + maxpool_fused.meta["val"] = maxpool_fused.meta["val"][0] logger.debug( f"Replacing all uses of nodes {node}, {getitem_node} with fused maxpool node {maxpool_fused} " From a73d1477db6ef1e7ee43fa41e176deb036cd91ca Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sat, 17 Aug 2024 20:16:36 +0000 Subject: [PATCH 2/8] chore: updates --- .../dynamo/conversion/_conversion.py | 4 +- .../dynamo/conversion/aten_ops_converters.py | 17 +++--- .../dynamo/conversion/impl/pool.py | 16 +++--- .../dynamo/partitioning/common.py | 4 +- py/torch_tensorrt/dynamo/utils.py | 53 +++++++++++++++---- 5 files changed, 64 insertions(+), 30 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 8f22a6c993..60ab5dba12 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -4,7 +4,7 @@ from typing import List, Sequence import torch -from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -28,7 +28,7 @@ def infer_module_output_dtypes( device: Device, truncate_double: bool = False, ) -> List[dtype]: - with maybe_disable_fake_tensor_mode(): + with unset_fake_temporarily(): torch_inputs = get_torch_inputs(inputs, device) module = module.to(device.to(torch.device)) module_outputs = module(*torch_inputs) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 36a5596d96..499832ce1f 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2694,18 +2694,17 @@ def topk_sort_validator(k: int) -> bool: def max_pool_param_validator(pool_node: Node) -> bool: + # breakpoint() dilation = args_bounds_check(pool_node.args, 4, 1) ceil_mode = args_bounds_check(pool_node.args, 5, False) - if dilation != 1: - _LOGGER.debug(f"Currently we don't support dilation, got dilation={dilation}.") - return False - - if ceil_mode is not False: - _LOGGER.debug( - f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}." - ) - return False + if not isinstance(dilation, (list, tuple)): + dilation = (dilation,) + + for dil in dilation: + if dil != 1: + _LOGGER.debug("Currently we don't support dilation > 1 at any dimension.") + return False return True diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py index 4e18aaaef2..fa27c0265c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pool.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py @@ -30,8 +30,10 @@ def avg_poolNd( count_include_pad: bool = True, divisor_override: Optional[int] = None, ) -> TRTTensor: - if ceil_mode is not False: - raise RuntimeError("ceil_mode is not yet supported!") + + padding_mode = trt.PaddingMode.EXPLICIT_ROUND_DOWN + if ceil_mode: + padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP if divisor_override is not None: raise RuntimeError("divisor_override is not yet supported!") @@ -57,6 +59,7 @@ def avg_poolNd( pool_layer.stride_nd = stride pool_layer.padding_nd = padding pool_layer.average_count_excludes_padding = not count_include_pad + pool_layer.padding_mode = padding_mode set_layer_name(pool_layer, target, name, source_ir) return pool_layer.get_output(0) @@ -77,11 +80,9 @@ def max_poolNd( if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling." - if dilation != 1: - raise RuntimeError("dilation is not yet supported!") - - if ceil_mode is not False: - raise RuntimeError("ceil_mode is not yet supported!") + padding_mode = trt.PaddingMode.EXPLICIT_ROUND_DOWN + if ceil_mode: + padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP dim = len(kernel_size) @@ -103,6 +104,7 @@ def max_poolNd( pool_layer.stride_nd = stride pool_layer.padding_nd = padding + pool_layer.padding_mode = padding_mode set_layer_name(pool_layer, target, name, source_ir) return pool_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index fdc55126ee..af78a72e03 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -3,7 +3,7 @@ import torch from torch._subclasses.fake_tensor import FakeTensor -from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._defaults import DEBUG @@ -90,7 +90,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: Returns: Sequence of torch_tensorrt.Input's representing inputs to given module """ - with maybe_disable_fake_tensor_mode(): + with unset_fake_temporarily(): torchtrt_inputs = [] module_inputs = [ node for node in module.graph.nodes if node.op == "placeholder" diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index acfb2b0094..bdf7b0d087 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -136,16 +136,49 @@ def get_torch_inputs( on the mode requested. """ device = to_torch_device(device) - if mode: - return [ - input.example_tensor(mode).to(device) - for input in inputs - if isinstance(input, Input) - ] - return [ - input.torch_tensor.to(device) if isinstance(input, Input) else input - for input in inputs - ] + + if isinstance(inputs, dict): + result = {} + for k, v in inputs.items(): + if isinstance(v, (list, tuple, dict)): + result[k] = get_torch_inputs(v, device) + elif isinstance(v, Input): + if len(mode) > 0: + result[k] = v.example_tensor(mode).to(device) + else: + result[k] = v.torch_tensor.to(device) + else: + result = [] + for input in inputs: + if isinstance(input, Input): + if len(mode) > 0: + result.append(input.example_tensor(mode).to(device)) + else: + result.append(input.torch_tensor.to(device)) + elif isinstance(input, torch.Tensor): + result.append(input.to(device)) + else: + raise AssertionError(f"Input type {type(input)} is not a valid type") + + return result + + +def get_model_device(module: torch.fx.GraphModule) -> Union[Device, torch.device, str]: + """ + Returns the device on which the module parameters exist. + """ + device = None + for parameter in list(module.parameters()): + if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)): + device = parameter.device + break + + if device is None: + device = torch.device("cpu") + logger.warning( + "Could not detect the device on which the model exists. Assuming the model is on CPU" + ) + return device def set_log_level(parent_logger: Any, level: Any) -> None: From 847459d4c75a770ef383d6128e03d2377d3c3dbe Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Tue, 13 Aug 2024 02:29:19 +0900 Subject: [PATCH 3/8] feat: lowering replace aten.full_like with aten.full --- .../dynamo/lowering/_decompositions.py | 14 ++-- .../lowering/passes/_aten_lowering_pass.py | 2 + .../passes/replace_full_like_with_full.py | 43 ++++++++++++ .../py/dynamo/lowering/test_decompositions.py | 68 +++++++++++++++++-- 4 files changed, 115 insertions(+), 12 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 2729e38ff5..378d407416 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -168,7 +168,7 @@ def var_decomposition( @register_torch_trt_decomposition( torch.ops.aten.empty_permuted.default, registry=TORCH_TRT_DECOMPOSITIONS ) -def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: +def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: # type: ignore empty_size = args[0] empty_permute = args[1] perm = [0] * len(empty_size) @@ -188,7 +188,7 @@ def slice_scatter_decomposition( start: Optional[int] = None, end: Optional[int] = None, step: Optional[int] = None, -): +) -> torch.Tensor: dim_size = input_tensor.shape[dim] start = get_positive_dim(start, input_tensor.shape[dim]) if end is None: @@ -197,6 +197,11 @@ def slice_scatter_decomposition( if step is None: step = 1 + # Ensure start, end, and step are all integers + assert isinstance(start, int), "start must be an integer" + assert isinstance(end, int), "end must be an integer" + assert isinstance(step, int), "step must be an integer" + src_dim = src_tensor.shape # step == 0 is not a valid torch case # also src_dim should be equal to slice dimension @@ -233,7 +238,7 @@ def select_scatter_decomposition( @register_torch_trt_decomposition( torch.ops.aten.empty_strided.default, registry=TORCH_TRT_DECOMPOSITIONS ) -def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor: +def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor: # type: ignore empty_size = args[0] empty_stride = args[1] return torch.as_strided( @@ -256,8 +261,7 @@ def scatter_add_decomposition( src_shape = list(src_tensor.shape) src_dim = src_shape[dim] for i in range(0, src_dim): - to_scatter_tensor = torch.zeros_like(input_tensor) - + to_scatter_tensor = torch.zeros(input_tensor.shape, dtype=input_tensor.dtype) # index and src slice src_slice = torch.select(src_tensor, dim, i) index_slice = torch.select(index, dim, i) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 3d1663fe0b..958fd5305a 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -11,6 +11,7 @@ from .remove_detach import remove_detach from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones from .repair_input_as_output import repair_input_as_output +from .replace_full_like_with_full import replace_full_like_with_full from .replace_max_pool_with_indices import replace_max_pool_with_indices from .view_to_reshape import view_to_reshape @@ -23,6 +24,7 @@ lower_linear, fuse_prims_broadcast, replace_max_pool_with_indices, + replace_full_like_with_full, view_to_reshape, ] ) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py new file mode 100644 index 0000000000..35f9b1cd3f --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py @@ -0,0 +1,43 @@ +import logging +from typing import Sequence + +import torch +import torch.fx +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +def replace_full_like_with_full( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: + """Replace full_like nodes with equivalent full nodes""" + modified_graph = False + + for node in gm.graph.nodes: + if node.target == torch.ops.aten.full_like.default: + modified_graph = True + + # Extract arguments from full_like + input_tensor = node.args[0] + fill_value = node.args[1] + shape = list(input_tensor.meta["tensor_meta"].shape) + + # Replace full_like with full, using the shape as a list + with gm.graph.inserting_after(node): + full_node = gm.graph.call_function( + torch.ops.aten.full.default, + args=(shape, fill_value), + kwargs=node.kwargs, + ) + full_node.meta = node.meta + + node.replace_all_uses_with(full_node) + gm.graph.erase_node(node) + + if modified_graph: + gm = clean_up_graph_after_modifications(gm) + + return gm diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index a1416c00db..74ac6cde62 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -421,6 +421,66 @@ def forward(self, x): f"MaxPool3d TRT outputs don't match with the original model.", ) + def test_lowering_full_like_module(self): + class FullLike(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x): + y = torch.full_like(x, 2.0) + return y + + # Operations expected to be removed in the traced graph after decompositions + expected_ops = {torch.ops.aten.full.default} + unexpected_ops = {torch.ops.aten.full_like.default} + + inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()] + + fx_graph = torch.fx.symbolic_trace(FullLike()) + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEqual( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + truncate_double=True, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + f"FullLike TRT outputs don't match with the original model.", + ) + def test_lowering_empty_like_module(self): class emptyLike(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: @@ -976,7 +1036,7 @@ def forward(self, input): 0, torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), - {torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src}, + {torch.ops.aten.add.Tensor}, ), ( "scatter_add_one_dim_indexOne_constant", @@ -985,8 +1045,6 @@ def forward(self, input): torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), { torch.ops.aten.add.Tensor, - torch.ops.aten.scatter.src, - torch.ops.aten.full_like.default, }, ), ( @@ -996,8 +1054,6 @@ def forward(self, input): torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), { torch.ops.aten.add.Tensor, - torch.ops.aten.scatter.src, - torch.ops.aten.full_like.default, }, ), ( @@ -1009,8 +1065,6 @@ def forward(self, input): ).cuda(), { torch.ops.aten.add.Tensor, - torch.ops.aten.scatter.src, - torch.ops.aten.full_like.default, }, ), ] From 70540a2882f0f35fe38dc367abeb272520640879 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Tue, 13 Aug 2024 02:40:46 +0900 Subject: [PATCH 4/8] chore: minor linting --- .../dynamo/lowering/passes/replace_full_like_with_full.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py index 35f9b1cd3f..d09778f3c6 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py @@ -11,7 +11,7 @@ def replace_full_like_with_full( - gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] + gm: torch.fx.GraphModule, ) -> torch.fx.GraphModule: """Replace full_like nodes with equivalent full nodes""" modified_graph = False From 728801b1996c5c93b30414ea8314a24ab5e33c07 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sat, 17 Aug 2024 21:37:19 +0000 Subject: [PATCH 5/8] chore: updates --- py/torch_tensorrt/dynamo/backend/backends.py | 2 -- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 4 ++-- .../dynamo/lowering/passes/replace_full_like_with_full.py | 8 ++++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index ae3cb38f2d..9d0df74e87 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -85,7 +85,6 @@ def _pretraced_backend( # Remove detach nodes remove_detach(gm) - # Invoke AOTAutograd to translate operators to aten gm = aot_export_joint_simple( gm, @@ -95,7 +94,6 @@ def _pretraced_backend( settings.enable_experimental_decompositions ), ) - logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) gm = post_lowering(gm) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 378d407416..7fe0032d80 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -216,7 +216,7 @@ def slice_scatter_decomposition( index_tensor_shape.append(src_each_dim) for index in range(start, end, step): cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64)) - index_tensor = torch.stack(cat_tensors, dim).cuda() + index_tensor = torch.stack(cat_tensors, dim).to(input_tensor.device) index_tensor_64 = index_tensor.to(torch.int64) output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor) return output_tensor @@ -271,7 +271,7 @@ def scatter_add_decomposition( index_slice = torch.unsqueeze(index_slice, dim) # moving tensor to default device - device = to_torch_device(default_device()) + device = input_tensor.device scatter_add_tensor = scatter_add_tensor.to(device) to_scatter_tensor = to_scatter_tensor.to(device) index_slice = index_slice.to(device) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py index d09778f3c6..de78e4b7b6 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py @@ -24,13 +24,17 @@ def replace_full_like_with_full( input_tensor = node.args[0] fill_value = node.args[1] shape = list(input_tensor.meta["tensor_meta"].shape) - + + new_kwargs = {} + for key, val in node.kwargs.items(): + if key != "memory_format": + new_kwargs[key] = val # Replace full_like with full, using the shape as a list with gm.graph.inserting_after(node): full_node = gm.graph.call_function( torch.ops.aten.full.default, args=(shape, fill_value), - kwargs=node.kwargs, + kwargs=new_kwargs, ) full_node.meta = node.meta From be3088b0bcd609f652d3d62ac5752ae8819bf850 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sun, 18 Aug 2024 03:34:18 +0000 Subject: [PATCH 6/8] chore: updates --- MODULE.bazel | 28 +++++++++---------- .../dynamo/conversion/aten_ops_converters.py | 7 ++--- .../dynamo/conversion/impl/full.py | 13 ++++++--- .../dynamo/conversion/impl/pool.py | 1 - .../dynamo/conversion/impl/slice/ops.py | 3 +- .../passes/replace_full_like_with_full.py | 20 +++++++++++-- pyproject.toml | 4 +-- 7 files changed, 48 insertions(+), 28 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index 958ea92f1b..9d7710e856 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -76,15 +76,15 @@ http_archive( # Either place them in the distdir directory in third_party and use the --distdir flag # or modify the urls to "file:////.tar.gz -http_archive( - name = "tensorrt", - build_file = "@//third_party/tensorrt/archive:BUILD", - sha256 = "606436ed219c72d1a25a889b2b0ae5cb5a68499dd6f944da4cabb3c34c067d55", - strip_prefix = "TensorRT-10.1.0.27", - urls = [ - "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.1.0/tars/TensorRT-10.1.0.27.Linux.x86_64-gnu.cuda-12.4.tar.gz", - ], -) +# http_archive( +# name = "tensorrt", +# build_file = "@//third_party/tensorrt/archive:BUILD", +# sha256 = "606436ed219c72d1a25a889b2b0ae5cb5a68499dd6f944da4cabb3c34c067d55", +# strip_prefix = "TensorRT-10.1.0.27", +# urls = [ +# "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.1.0/tars/TensorRT-10.1.0.27.Linux.x86_64-gnu.cuda-12.4.tar.gz", +# ], +# ) http_archive( name = "tensorrt_win", @@ -119,8 +119,8 @@ http_archive( # build_file = "third_party/libtorch/BUILD" #) -#new_local_repository( -# name = "tensorrt", -# path = "/usr/", -# build_file = "@//third_party/tensorrt/local:BUILD" -#) +new_local_repository( + name = "tensorrt", + path = "/usr/", + build_file = "@//third_party/tensorrt/local:BUILD" +) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 9c029dd0b7..6a63734572 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2695,18 +2695,16 @@ def topk_sort_validator(k: int) -> bool: def max_pool_param_validator(pool_node: Node) -> bool: - # breakpoint() dilation = args_bounds_check(pool_node.args, 4, 1) - ceil_mode = args_bounds_check(pool_node.args, 5, False) if not isinstance(dilation, (list, tuple)): dilation = (dilation,) - + for dil in dilation: if dil != 1: _LOGGER.debug("Currently we don't support dilation > 1 at any dimension.") return False - + return True @@ -3860,4 +3858,5 @@ def aten_ops_full( name, shape=args[0], fill_value=args[1], + dtype=kwargs["dtype"] ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/full.py b/py/torch_tensorrt/dynamo/conversion/impl/full.py index d211cef532..b989515592 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/full.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/full.py @@ -1,6 +1,7 @@ from typing import List, Optional, Union import numpy as np +import torch import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo.conversion import impl @@ -10,6 +11,7 @@ cast_trt_tensor, get_trt_tensor, ) +from torch_tensorrt import _enums from torch_tensorrt.fx.types import TRTTensor @@ -20,12 +22,14 @@ def full( name: str, shape: Union[List[int], TRTTensor], fill_value: Union[int, float, bool], + dtype: Union[torch.dtype, trt.DataType] ) -> TRTTensor: - # in static shape scenario, shape is a list of int + output_dtype = _enums.dtype._from(dtype) if isinstance(shape, List): # in static shape scenario, shape is a list of int if all(isinstance(dim, int) for dim in shape): - return np.full(shape, fill_value) + output_np_dtype = output_dtype.try_to(np.dtype, use_default=True) + return np.full(shape, fill_value, dtype=output_np_dtype) else: shape = impl.cat.cat( ctx, target, source_ir, name + "_concat_shape", shape, 0 @@ -33,7 +37,8 @@ def full( # in dynamic shape scenario, shape is a shap tensor # use IFillLayer to fill the shape tensor with LINSPACE value - layer = ctx.net.add_fill(shape.shape, trt.FillOperation.LINSPACE, shape.dtype) + output_trt_dtype = output_dtype.to(trt.DataType, use_default=True) + layer = ctx.net.add_fill(shape.shape, trt.FillOperation.LINSPACE, output_trt_dtype) layer.set_input(0, shape) layer.set_input(1, get_trt_tensor(ctx, 0, name + "_start", min_rank=0)) delta = get_trt_tensor(ctx, 1, name + "_delta") @@ -62,5 +67,5 @@ def full( output = impl.elementwise.logical_or( ctx, target, source_ir, name + "_add", output, fill_value ) - + breakpoint() return output diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py index fa27c0265c..bc70d59527 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pool.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py @@ -30,7 +30,6 @@ def avg_poolNd( count_include_pad: bool = True, divisor_override: Optional[int] = None, ) -> TRTTensor: - padding_mode = trt.PaddingMode.EXPLICIT_ROUND_DOWN if ceil_mode: padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index eae0e24dcb..58d66b3c05 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -211,9 +211,10 @@ def slice_op( # TODO: This should be slice not whatever is in base return layer.get_output(0) output_shape[dim] = math.ceil((stop - start) / step) - return slice( + out = slice( ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice ) + return out def expand( diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py index de78e4b7b6..2a1c529c2f 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py @@ -3,6 +3,8 @@ import torch import torch.fx +from torch_tensorrt.dynamo._defaults import default_device +from torch_tensorrt.dynamo.utils import to_torch_device from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) @@ -23,17 +25,31 @@ def replace_full_like_with_full( # Extract arguments from full_like input_tensor = node.args[0] fill_value = node.args[1] + input_dtype = None + input_device = to_torch_device(default_device()) + if "val" in input_tensor.meta: + input_dtype = input_tensor.meta["val"].dtype + input_device = input_tensor.meta["val"].device + elif "tensor_meta" in input_tensor.meta: + input_dtype = input_tensor.meta["tensor_meta"].dtype + input_device = input_tensor.meta["tensor_meta"].device + shape = list(input_tensor.meta["tensor_meta"].shape) - + + # There's no memory format argument for torch.full. + # Set the input_device and dtype correspondingly. new_kwargs = {} for key, val in node.kwargs.items(): if key != "memory_format": new_kwargs[key] = val + new_kwargs["device"] = input_device + new_kwargs["dtype"] = input_dtype # Replace full_like with full, using the shape as a list + input_nodes = (shape, fill_value) with gm.graph.inserting_after(node): full_node = gm.graph.call_function( torch.ops.aten.full.default, - args=(shape, fill_value), + args=input_nodes, kwargs=new_kwargs, ) full_node.meta = node.meta diff --git a/pyproject.toml b/pyproject.toml index 1d6570db90..f6673d211b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,8 +43,8 @@ keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligenc dependencies = [ "torch >=2.5.0.dev,<2.6.0", "tensorrt==10.1.0", - "tensorrt-cu12_bindings==10.1.0", - "tensorrt-cu12_libs==10.1.0", + #"tensorrt-cu12_bindings==10.1.0", + #"tensorrt-cu12_libs==10.1.0", "packaging>=23", "numpy", "typing-extensions>=4.7.0", From cf390fa55869c9844f5589823a263660a2be26de Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sun, 18 Aug 2024 05:59:34 +0000 Subject: [PATCH 7/8] chore: updates --- .../dynamo/conversion/aten_ops_converters.py | 8 ------- tests/py/dynamo/conversion/test_pool_aten.py | 24 ++++++++++++++++++- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 6a63734572..f55e34591d 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2552,15 +2552,7 @@ def aten_ops_cdist_forward( def avg_pool_param_validator(pool_node: Node) -> bool: - ceil_mode = args_bounds_check(pool_node.args, 4, False) divisor_override = args_bounds_check(pool_node.args, 6) - - if ceil_mode is not False: - _LOGGER.debug( - f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}." - ) - return False - if divisor_override is not None: _LOGGER.debug( f"Currently we don't support divisor_override, got divisor_override={divisor_override}." diff --git a/tests/py/dynamo/conversion/test_pool_aten.py b/tests/py/dynamo/conversion/test_pool_aten.py index 29fdf30480..38746f23b3 100644 --- a/tests/py/dynamo/conversion/test_pool_aten.py +++ b/tests/py/dynamo/conversion/test_pool_aten.py @@ -15,6 +15,8 @@ class TestPoolConverter(DispatchTestCase): ((4,), (1,), (1,)), ((5,), (2,), (0,)), ((7,), (2,), (1,)), + ((3,), (1,), (1,), 0, True), + ((7,), (2,), (1,), 0, True), ] ) def test_avg_pool1d( @@ -44,8 +46,11 @@ def forward(self, x): (3, 1, 1), ((2, 2), [], (1, 0)), ((4, 3), (1, 1), (1, 1)), + ((4, 3), (1, 1), (1, 1), True), ((5, 4), (2, 1), (1, 0)), + ((5, 4), (2, 1), (1, 0), True), ((7, 7), (1, 2), (0, 1)), + ((7, 7), (1, 2), (0, 1), True), ] ) def test_avg_pool2d( @@ -70,7 +75,7 @@ def forward(self, x): ) inputs = [torch.randn(1, 3, 32, 32)] - self.run_test(TestModule(), inputs, use_dynamo_tracer=True) + self.run_test(TestModule(), inputs, rtol=5e-03, atol=5e-03, use_dynamo_tracer=True) @parameterized.expand( [ @@ -80,6 +85,8 @@ def forward(self, x): ((4, 3, 2), (1, 1, 1), (1, 1, 0)), ((5, 4, 3), (2, 1, 2), (1, 0, 1)), ((7, 7, 7), (1, 2, 1), (0, 1, 1)), + ((7, 7, 7), (1, 2, 1), (0, 1, 1), True), + ((5, 4, 3), (2, 1, 2), (1, 0, 1), True), ] ) def test_avg_pool3d( @@ -168,6 +175,16 @@ def forward(self, x): (1, 1), (1, 1), ), + ( + (1, 1, 1, 1), + (2, 2, 2, 2), + (3, 3, 3, 3), + torch.float, + (3, 3), + (1, 1), + (1, 1), + True + ), ] ) def test_dynamic_shape_pool2d( @@ -258,6 +275,7 @@ def forward(self, x): ((4,), (1,), (1,)), ((5,), (2,), (0,)), ((7,), (2,), (1,)), + ((7,), (2,), (1,), 1, True), ] ) def test_max_pool1d( @@ -290,6 +308,9 @@ def forward(self, x): ((4, 3), (1, 1), (1, 1)), ((5, 4), (2, 1), (1, 0)), ((7, 7), (1, 2), (0, 1)), + ((4, 3), (1, 1), (1, 1), 1, True), + ((5, 4), (2, 1), (1, 0), 1, True), + ((7, 7), (1, 2), (0, 1), 1, True), ] ) def test_max_pool2d( @@ -322,6 +343,7 @@ def forward(self, x): ((4, 3, 2), (1, 1, 1), (1, 1, 0)), ((5, 4, 3), (2, 1, 2), (1, 0, 1)), ((7, 7, 7), (1, 2, 1), (0, 1, 1)), + ((7, 7, 7), (1, 2, 1), (0, 1, 1), 1, True), ] ) def test_max_pool3d( From ec9f744c4e6ac187a121e0425674c1989f6e0ec7 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 20 Aug 2024 17:29:46 -0700 Subject: [PATCH 8/8] chore: updates --- MODULE.bazel | 28 +++++++++---------- py/torch_tensorrt/dynamo/backend/backends.py | 2 ++ .../dynamo/conversion/impl/slice/ops.py | 3 +- pyproject.toml | 4 +-- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index 9d7710e856..958ea92f1b 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -76,15 +76,15 @@ http_archive( # Either place them in the distdir directory in third_party and use the --distdir flag # or modify the urls to "file:////.tar.gz -# http_archive( -# name = "tensorrt", -# build_file = "@//third_party/tensorrt/archive:BUILD", -# sha256 = "606436ed219c72d1a25a889b2b0ae5cb5a68499dd6f944da4cabb3c34c067d55", -# strip_prefix = "TensorRT-10.1.0.27", -# urls = [ -# "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.1.0/tars/TensorRT-10.1.0.27.Linux.x86_64-gnu.cuda-12.4.tar.gz", -# ], -# ) +http_archive( + name = "tensorrt", + build_file = "@//third_party/tensorrt/archive:BUILD", + sha256 = "606436ed219c72d1a25a889b2b0ae5cb5a68499dd6f944da4cabb3c34c067d55", + strip_prefix = "TensorRT-10.1.0.27", + urls = [ + "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.1.0/tars/TensorRT-10.1.0.27.Linux.x86_64-gnu.cuda-12.4.tar.gz", + ], +) http_archive( name = "tensorrt_win", @@ -119,8 +119,8 @@ http_archive( # build_file = "third_party/libtorch/BUILD" #) -new_local_repository( - name = "tensorrt", - path = "/usr/", - build_file = "@//third_party/tensorrt/local:BUILD" -) +#new_local_repository( +# name = "tensorrt", +# path = "/usr/", +# build_file = "@//third_party/tensorrt/local:BUILD" +#) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 9d0df74e87..ae3cb38f2d 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -85,6 +85,7 @@ def _pretraced_backend( # Remove detach nodes remove_detach(gm) + # Invoke AOTAutograd to translate operators to aten gm = aot_export_joint_simple( gm, @@ -94,6 +95,7 @@ def _pretraced_backend( settings.enable_experimental_decompositions ), ) + logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) gm = post_lowering(gm) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 58d66b3c05..eae0e24dcb 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -211,10 +211,9 @@ def slice_op( # TODO: This should be slice not whatever is in base return layer.get_output(0) output_shape[dim] = math.ceil((stop - start) / step) - out = slice( + return slice( ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice ) - return out def expand( diff --git a/pyproject.toml b/pyproject.toml index f6673d211b..1d6570db90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,8 +43,8 @@ keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligenc dependencies = [ "torch >=2.5.0.dev,<2.6.0", "tensorrt==10.1.0", - #"tensorrt-cu12_bindings==10.1.0", - #"tensorrt-cu12_libs==10.1.0", + "tensorrt-cu12_bindings==10.1.0", + "tensorrt-cu12_libs==10.1.0", "packaging>=23", "numpy", "typing-extensions>=4.7.0",