full_like to full decomposition moving to decomposition.py for dynami…
apbose authored Dec 18, 2024
1 parent 1849a3c commit a47e590
Showing 4 changed files with 81 additions and 70 deletions.
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -549,6 +549,18 @@ def scaled_dot_product_cudnn_attention_decomposition(
     return attn, None, None, None, 0, 0, None, None, None


+@register_torch_trt_decomposition(
+    torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
+    input = args[0]
+    shape = args[0].shape
+    fill_value = args[1]
+    kwargs["dtype"] = input.dtype
+    kwargs["device"] = to_torch_device(default_device())
+    return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"])
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
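
The decomposition above restates aten.full_like in terms of aten.full, reading shape, dtype, and device from the input at trace time, so a symbolic (dynamic) dimension flows into torch.full instead of being frozen by a post-trace graph pass. A minimal standalone sketch of the equivalence it relies on (plain PyTorch, independent of the Torch-TensorRT registry; full_like_via_full is an illustrative name, not part of this commit):

import torch

def full_like_via_full(x: torch.Tensor, fill_value: float) -> torch.Tensor:
    # Same result as torch.full_like(x, fill_value): shape, dtype, and
    # device are all recovered from the input tensor.
    return torch.full(x.shape, fill_value, dtype=x.dtype, device=x.device)

x = torch.randn(3, 3)
assert torch.equal(full_like_via_full(x, 2.0), torch.full_like(x, 2.0))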
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -12,7 +12,6 @@
 from .remove_detach import remove_detach
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
-from .replace_full_like_with_full import replace_full_like_with_full
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
 from .view_to_reshape import view_to_reshape

@@ -23,7 +22,6 @@
         repair_input_as_output,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
-        replace_full_like_with_full,
         view_to_reshape,
         remove_assert_scalar,
         accumulate_fp32_matmul,

63 changes: 0 additions & 63 deletions py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py
This file was deleted.
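
The deleted pass did the same rewrite after tracing, as an FX transformation over already-recorded graph nodes, which is why it baked a static shape into the replacement aten.full node. A rough sketch of such a pass for comparison (reconstructed for illustration only; the signature, kwargs handling, and metadata handling here are assumptions, not the deleted file's verbatim contents):

import torch
from torch.fx import GraphModule

def replace_full_like_with_full(gm: GraphModule) -> GraphModule:
    # Swap every aten.full_like node for an aten.full node whose shape is
    # read from the input's traced tensor metadata -- a static value, which
    # is exactly what breaks once that dimension is dynamic.
    modified = False
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == torch.ops.aten.full_like.default:
            input_node, fill_value = node.args[0], node.args[1]
            shape = list(input_node.meta["tensor_meta"].shape)
            with gm.graph.inserting_after(node):
                full_node = gm.graph.call_function(
                    torch.ops.aten.full.default,
                    args=(shape, fill_value),
                    # Carry over dtype/device if they were given explicitly.
                    kwargs={k: v for k, v in node.kwargs.items() if k in ("dtype", "device")},
                )
            node.replace_all_uses_with(full_node)
            gm.graph.erase_node(node)
            modified = True
    if modified:
        gm.graph.lint()
        gm.recompile()
    return gm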

74 changes: 69 additions & 5 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -7,11 +7,10 @@
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
 )
+from testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch_tensorrt.dynamo.utils import ATOL, RTOL

-from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
-

 class TestLowering(TestCase):
     def test_lowering_inplace_op(self):
@@ -434,11 +433,13 @@ def __init__(self, *args, **kwargs) -> None:
                 super().__init__(*args, **kwargs)

             def forward(self, x):
-                y = torch.full_like(x, 2.0)
-                return y
+                c = torch.ops.aten.add(x, x)
+                y = torch.ops.aten.full_like.default(c, 2)
+                d = y + c
+                return d

         # Operations expected to be removed in the traced graph after decompositions
-        expected_ops = {torch.ops.aten.full.default}
+        expected_ops = {torch.ops.aten.add.Tensor}
         unexpected_ops = {torch.ops.aten.full_like.default}

         inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]
@@ -488,6 +489,69 @@ def forward(self, x):
             f"FullLike TRT outputs don't match with the original model.",
         )

+    def test_lowering_full_like_to_full_dynamic_module(self):
+        class FullLike(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x):
+                c = torch.ops.aten.add(x, x)
+                y = torch.ops.aten.full_like.default(c, 2)
+                d = y + c
+                return d
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.add.Tensor}
+        unexpected_ops = {torch.ops.aten.full_like.default}
+
+        inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]
+        torch._dynamo.mark_dynamic(inputs[0], 0, min=1, max=3)
+        fx_graph = torch.fx.symbolic_trace(FullLike())
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"FullLike TRT outputs don't match with the original model.",
+        )

     def test_lowering_empty_like_module(self):
         class emptyLike(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
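
Both tests route through lower_graph_testing and a full torch_tensorrt.compile run; the core assertion, that aten.full_like is gone from the lowered graph while the ops it decomposes into remain, can also be sanity-checked with stock PyTorch export APIs. A rough sketch under the assumption that PyTorch's default core-ATen decomposition table lowers full_like the same way (this exercises plain torch.export, not Torch-TensorRT):

import torch

class FullLike(torch.nn.Module):
    def forward(self, x):
        c = torch.ops.aten.add(x, x)
        y = torch.ops.aten.full_like.default(c, 2)
        return y + c

x = torch.randn(3, 3)
ep = torch.export.export(FullLike(), (x,))
ep = ep.run_decompositions()  # apply the default decomposition table

call_targets = {
    node.target
    for node in ep.graph_module.graph.nodes
    if node.op == "call_function"
}
assert torch.ops.aten.full_like.default not in call_targets
assert torch.ops.aten.full.default in call_targets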
