Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: bug fixes #3065

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
23 changes: 6 additions & 17 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2550,15 +2550,7 @@ def aten_ops_cdist_forward(


def avg_pool_param_validator(pool_node: Node) -> bool:
ceil_mode = args_bounds_check(pool_node.args, 4, False)
divisor_override = args_bounds_check(pool_node.args, 6)

if ceil_mode is not False:
_LOGGER.debug(
f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}."
)
return False

if divisor_override is not None:
_LOGGER.debug(
f"Currently we don't support divisor_override, got divisor_override={divisor_override}."
Expand Down Expand Up @@ -2694,17 +2686,14 @@ def topk_sort_validator(k: int) -> bool:

def max_pool_param_validator(pool_node: Node) -> bool:
dilation = args_bounds_check(pool_node.args, 4, 1)
ceil_mode = args_bounds_check(pool_node.args, 5, False)

if dilation != 1:
_LOGGER.debug(f"Currently we don't support dilation, got dilation={dilation}.")
return False
if not isinstance(dilation, (list, tuple)):
dilation = (dilation,)

if ceil_mode is not False:
_LOGGER.debug(
f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}."
)
return False
for dil in dilation:
if dil != 1:
_LOGGER.debug("Currently we don't support dilation > 1 at any dimension.")
return False

return True

Expand Down
15 changes: 8 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/impl/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ def avg_poolNd(
count_include_pad: bool = True,
divisor_override: Optional[int] = None,
) -> TRTTensor:
if ceil_mode is not False:
raise RuntimeError("ceil_mode is not yet supported!")
padding_mode = trt.PaddingMode.EXPLICIT_ROUND_DOWN
if ceil_mode:
padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

if divisor_override is not None:
raise RuntimeError("divisor_override is not yet supported!")
Expand All @@ -57,6 +58,7 @@ def avg_poolNd(
pool_layer.stride_nd = stride
pool_layer.padding_nd = padding
pool_layer.average_count_excludes_padding = not count_include_pad
pool_layer.padding_mode = padding_mode

set_layer_name(pool_layer, target, name, source_ir)
return pool_layer.get_output(0)
Expand All @@ -77,11 +79,9 @@ def max_poolNd(
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling."

if dilation != 1:
raise RuntimeError("dilation is not yet supported!")

if ceil_mode is not False:
raise RuntimeError("ceil_mode is not yet supported!")
padding_mode = trt.PaddingMode.EXPLICIT_ROUND_DOWN
if ceil_mode:
padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

dim = len(kernel_size)

Expand All @@ -103,6 +103,7 @@ def max_poolNd(

pool_layer.stride_nd = stride
pool_layer.padding_nd = padding
pool_layer.padding_mode = padding_mode

set_layer_name(pool_layer, target, name, source_ir)
return pool_layer.get_output(0)
Expand Down
24 changes: 23 additions & 1 deletion tests/py/dynamo/conversion/test_pool_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class TestPoolConverter(DispatchTestCase):
((4,), (1,), (1,)),
((5,), (2,), (0,)),
((7,), (2,), (1,)),
((3,), (1,), (1,), 0, True),
((7,), (2,), (1,), 0, True),
]
)
def test_avg_pool1d(
Expand Down Expand Up @@ -44,8 +46,11 @@ def forward(self, x):
(3, 1, 1),
((2, 2), [], (1, 0)),
((4, 3), (1, 1), (1, 1)),
((4, 3), (1, 1), (1, 1), True),
((5, 4), (2, 1), (1, 0)),
((5, 4), (2, 1), (1, 0), True),
((7, 7), (1, 2), (0, 1)),
((7, 7), (1, 2), (0, 1), True),
]
)
def test_avg_pool2d(
Expand All @@ -70,7 +75,7 @@ def forward(self, x):
)

inputs = [torch.randn(1, 3, 32, 32)]
self.run_test(TestModule(), inputs, use_dynamo_tracer=True)
self.run_test(TestModule(), inputs, rtol=5e-03, atol=5e-03, use_dynamo_tracer=True)

@parameterized.expand(
[
Expand All @@ -80,6 +85,8 @@ def forward(self, x):
((4, 3, 2), (1, 1, 1), (1, 1, 0)),
((5, 4, 3), (2, 1, 2), (1, 0, 1)),
((7, 7, 7), (1, 2, 1), (0, 1, 1)),
((7, 7, 7), (1, 2, 1), (0, 1, 1), True),
((5, 4, 3), (2, 1, 2), (1, 0, 1), True),
]
)
def test_avg_pool3d(
Expand Down Expand Up @@ -168,6 +175,16 @@ def forward(self, x):
(1, 1),
(1, 1),
),
(
(1, 1, 1, 1),
(2, 2, 2, 2),
(3, 3, 3, 3),
torch.float,
(3, 3),
(1, 1),
(1, 1),
True
),
]
)
def test_dynamic_shape_pool2d(
Expand Down Expand Up @@ -258,6 +275,7 @@ def forward(self, x):
((4,), (1,), (1,)),
((5,), (2,), (0,)),
((7,), (2,), (1,)),
((7,), (2,), (1,), 1, True),
]
)
def test_max_pool1d(
Expand Down Expand Up @@ -290,6 +308,9 @@ def forward(self, x):
((4, 3), (1, 1), (1, 1)),
((5, 4), (2, 1), (1, 0)),
((7, 7), (1, 2), (0, 1)),
((4, 3), (1, 1), (1, 1), 1, True),
((5, 4), (2, 1), (1, 0), 1, True),
((7, 7), (1, 2), (0, 1), 1, True),
]
)
def test_max_pool2d(
Expand Down Expand Up @@ -322,6 +343,7 @@ def forward(self, x):
((4, 3, 2), (1, 1, 1), (1, 1, 0)),
((5, 4, 3), (2, 1, 2), (1, 0, 1)),
((7, 7, 7), (1, 2, 1), (0, 1, 1)),
((7, 7, 7), (1, 2, 1), (0, 1, 1), 1, True),
]
)
def test_max_pool3d(
Expand Down
Loading