Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

full_like to full decomposition moving to decomposition.py for dynami… #3289

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,18 @@ def scaled_dot_product_cudnn_attention_decomposition(
return attn, None, None, None, 0, 0, None, None, None


@register_torch_trt_decomposition(
    torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
)
def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
    """Decompose ``aten.full_like`` into an equivalent ``torch.full`` call.

    Args:
        *args: Positional operands of ``aten.full_like``; ``args[0]`` is the
            reference tensor whose shape is reused and ``args[1]`` is the
            fill value.
        **kwargs: Keyword operands of ``aten.full_like``. Only ``dtype`` is
            consulted; other keywords are ignored.

    Returns:
        A tensor with the shape of ``args[0]`` filled with ``args[1]``,
        created on the device returned by
        ``to_torch_device(default_device())``.
    """
    like = args[0]
    fill_value = args[1]

    # An explicit ``dtype=`` on the original call must win; only fall back to
    # the reference tensor's dtype when none was requested.
    dtype = kwargs.get("dtype")
    if dtype is None:
        dtype = like.dtype

    device = to_torch_device(default_device())
    return torch.full(like.shape, fill_value, dtype=dtype, device=device)


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from .remove_detach import remove_detach
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
from .replace_full_like_with_full import replace_full_like_with_full
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .view_to_reshape import view_to_reshape

Expand All @@ -25,7 +24,6 @@
lower_linear,
fuse_prims_broadcast,
replace_max_pool_with_indices,
replace_full_like_with_full,
view_to_reshape,
remove_assert_scalar,
accumulate_fp32_matmul,
Expand Down

This file was deleted.

74 changes: 69 additions & 5 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
PLATFORM_SUPPORTS_FLASH_ATTENTION,
)
from testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo.utils import ATOL, RTOL

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing


class TestLowering(TestCase):
def test_lowering_inplace_op(self):
Expand Down Expand Up @@ -434,11 +433,13 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x):
y = torch.full_like(x, 2.0)
return y
c = torch.ops.aten.add(x, x)
y = torch.ops.aten.full_like.default(c, 2)
d = y + c
return d

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten.full.default}
expected_ops = {torch.ops.aten.add.Tensor}
unexpected_ops = {torch.ops.aten.full_like.default}

inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]
Expand Down Expand Up @@ -488,6 +489,69 @@ def forward(self, x):
f"FullLike TRT outputs don't match with the original model.",
)

def test_lowering_full_like_to_full_dynamic_module(self):
    # Lowers a module that calls aten.full_like on an input whose dim 0 is
    # marked dynamic, then compares the compiled output with the traced graph.
    class FullLike(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x):
            doubled = torch.ops.aten.add(x, x)
            filled = torch.ops.aten.full_like.default(doubled, 2)
            return filled + doubled

    sample = torch.randn(3, 3, dtype=torch.float32).cuda()
    inputs = [sample]
    torch._dynamo.mark_dynamic(sample, 0, min=1, max=3)
    fx_graph = torch.fx.symbolic_trace(FullLike())

    # add.Tensor has to be present after lowering, while full_like.default
    # must not appear in the lowered graph.
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops={torch.ops.aten.add.Tensor},
        unexpected_ops={torch.ops.aten.full_like.default},
        min_block_size=1,
    )

    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )
    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    # Compile through the torch_compile path and check numerical agreement
    # between the Torch-TRT module and the plain traced graph.
    compile_settings = {
        "min_block_size": 1,
        "truncate_double": True,
        "pass_through_build_failures": True,
    }
    optimized_model = torch_tensorrt.compile(
        fx_graph, "torch_compile", inputs, **compile_settings
    )
    trt_out = optimized_model(*inputs).detach().cpu()
    torch_out = fx_graph(*inputs).detach().cpu()

    abs_error = torch.abs(trt_out - torch_out)
    max_diff = float(torch.max(abs_error))
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        f"FullLike TRT outputs don't match with the original model.",
    )

def test_lowering_empty_like_module(self):
class emptyLike(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
Expand Down
Loading