From 8fc3482cb57d9ca6c6fb80f431ac1bf7b765b7a3 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 12 Nov 2024 13:03:55 -0800 Subject: [PATCH] change decomposition default table due to upstream torch change --- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index dda014890d..5fcccb5c77 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -4,8 +4,8 @@ import torch from torch._decomp import register_decomposition -from torch._export.utils import _decomp_table_to_post_autograd_aten from torch._ops import OpOverload +from torch.export import default_decompositions from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim from torch_tensorrt.dynamo.utils import to_torch_device @@ -412,7 +412,8 @@ def get_decompositions( return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS} else: # changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/135080 - decomp_table = _decomp_table_to_post_autograd_aten() + # changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/140085 + decomp_table = default_decompositions() DECOMP_TABLE_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = { decomp: decomp_table[decomp] for decomp in decomp_table