feat: InstanceNorm decomposition #3288

Merged 2 commits on Nov 15, 2024
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -400,6 +400,26 @@ def log_softmax_decomposition(
)


@register_torch_trt_decomposition(aten.instance_norm, registry=TORCH_TRT_DECOMPOSITIONS)
def instance_norm_decomposition(
    input: torch.Tensor,
    weight: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    running_mean: Optional[torch.Tensor],
    running_var: Optional[torch.Tensor],
    use_input_stats: bool,
    momentum: float,
    eps: float,
    cudnn_enabled: bool,
) -> torch.Tensor:
    if use_input_stats:
        # Normalizing each sample over per-channel spatial statistics is
        # equivalent to group_norm with one group per channel (num_groups == C).
        return torch.nn.functional.group_norm(input, input.shape[1], weight, bias, eps)
    else:
        # With running statistics, instance norm reduces to batch_norm in
        # evaluation mode (training=False) using the stored mean and variance.
        return torch.nn.functional.batch_norm(
            input, running_mean, running_var, weight, bias, False, momentum, eps
        )


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
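For reference, the two branches of this decomposition can be sanity-checked directly against aten.instance_norm. The snippet below is a standalone sketch, not part of this PR; the tensor shapes and tolerance are illustrative. With use_input_stats=True, instance norm matches group_norm with one group per channel; with use_input_stats=False, it reduces to batch_norm in eval mode over the running statistics.

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, 8)
weight, bias = torch.randn(4), torch.randn(4)
running_mean, running_var = torch.zeros(4), torch.ones(4)

# use_input_stats=True: per-sample, per-channel stats == group_norm with C groups
a = torch.ops.aten.instance_norm.default(x, weight, bias, None, None, True, 0.1, 1e-5, True)
b = F.group_norm(x, x.shape[1], weight, bias, 1e-5)
assert torch.allclose(a, b, atol=1e-6)

# use_input_stats=False: running stats applied per channel == batch_norm in eval mode
c = torch.ops.aten.instance_norm.default(
    x, weight, bias, running_mean, running_var, False, 0.1, 1e-5, True
)
d = F.batch_norm(x, running_mean, running_var, weight, bias, False, 0.1, 1e-5)
assert torch.allclose(c, d, atol=1e-6)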
69 changes: 69 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -1587,6 +1587,75 @@ def forward(self, x):
f"Log_softmax TRT outputs don't match with the original model.",
)

    @parameterized.expand(
        [
            ((1, 3, 5), True),
            ((1, 3, 5), False),
            ((2, 4, 6, 8), True),
            ((2, 4, 6, 8), False),
            ((3, 6, 9, 12, 15), True),
            ((3, 6, 9, 12, 15), False),
        ]
    )
    def test_lowering_instance_norm(self, shape, use_input_stats):
        class TestModule(torch.nn.Module):
            def forward(self, input, weight, bias, running_mean=None, running_var=None):
                return torch.ops.aten.instance_norm.default(
                    input,
                    weight,
                    bias,
                    running_mean,
                    running_var,
                    use_input_stats,
                    0.1,
                    1e-05,
                    True,
                )

        # Operations expected to be removed in the traced graph after decompositions
        unexpected_ops = {torch.ops.aten.instance_norm.default}

        inputs = [
            torch.randn(shape, device="cuda"),
            torch.randn(shape[1], device="cuda"),
            torch.randn(shape[1], device="cuda"),
        ]
        if not use_input_stats:
            inputs += [
                torch.randn(shape[1], device="cuda"),
                torch.rand(shape[1], device="cuda"),
            ]

        fx_graph = torch.fx.symbolic_trace(TestModule())
        unexpected_ops_seen, _ = lower_graph_testing(
            fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph, "dynamo", inputs, min_block_size=1
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            "Instance_norm TRT outputs don't match with the original model.",
        )


if __name__ == "__main__":
    run_tests()
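As a usage note, an ordinary module containing nn.InstanceNorm2d should now lower through the dynamo path without needing a dedicated aten.instance_norm converter. The sketch below is illustrative only; the model, input shape, and comparison are assumptions rather than part of this PR, and it requires a CUDA device with torch_tensorrt installed.

import torch
import torch_tensorrt

# InstanceNorm2d (no running stats) traces to aten.instance_norm, which the new
# decomposition rewrites to group_norm before conversion to TensorRT.
model = (
    torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.InstanceNorm2d(8, affine=True),
        torch.nn.ReLU(),
    )
    .eval()
    .cuda()
)

inputs = [torch.randn(1, 3, 224, 224, device="cuda")]
trt_model = torch_tensorrt.compile(model, "dynamo", inputs, min_block_size=1)
print(torch.max(torch.abs(trt_model(*inputs) - model(*inputs))))  # expected to be ~0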