Commit

feat: InstanceNorm decomposition

HolyWu committed Nov 11, 2024
1 parent b4de166 commit 3c9ac96
Showing 2 changed files with 89 additions and 0 deletions.
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -400,6 +400,26 @@ def log_softmax_decomposition(
    )


@register_torch_trt_decomposition(aten.instance_norm, registry=TORCH_TRT_DECOMPOSITIONS)
def instance_norm_decomposition(
    input: torch.Tensor,
    weight: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    running_mean: Optional[torch.Tensor],
    running_var: Optional[torch.Tensor],
    use_input_stats: bool,
    momentum: float,
    eps: float,
    cudnn_enabled: bool,
) -> torch.Tensor:
    if use_input_stats:
        return torch.nn.functional.group_norm(input, input.shape[1], weight, bias, eps)
    else:
        return torch.nn.functional.batch_norm(
            input, running_mean, running_var, weight, bias, False, momentum, eps
        )


def get_decompositions(
    enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
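
The use_input_stats=True branch of the new decomposition rests on the identity that instance normalization with per-sample statistics is group normalization with one group per channel. A minimal eager-mode sketch of that identity, not part of this commit, with illustrative shapes and tolerance:

import torch
import torch.nn.functional as F

# Instance norm computed from the input's own statistics...
x = torch.randn(2, 4, 8, 8)
weight, bias = torch.randn(4), torch.randn(4)
ref = F.instance_norm(x, weight=weight, bias=bias, use_input_stats=True, eps=1e-5)

# ...matches group norm with num_groups equal to the channel count.
dec = F.group_norm(x, x.shape[1], weight, bias, eps=1e-5)
print(torch.allclose(ref, dec, atol=1e-6))  # expected: True
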
69 changes: 69 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -1587,6 +1587,75 @@ def forward(self, x):
f"Log_softmax TRT outputs don't match with the original model.",
)

    @parameterized.expand(
        [
            ((1, 3, 5), True),
            ((1, 3, 5), False),
            ((2, 4, 6, 8), True),
            ((2, 4, 6, 8), False),
            ((3, 6, 9, 12, 15), True),
            ((3, 6, 9, 12, 15), False),
        ]
    )
    def test_lowering_instance_norm(self, shape, use_input_stats):
        class TestModule(torch.nn.Module):
            def forward(self, input, weight, bias, running_mean=None, running_var=None):
                return torch.ops.aten.instance_norm.default(
                    input,
                    weight,
                    bias,
                    running_mean,
                    running_var,
                    use_input_stats,
                    0.1,
                    1e-05,
                    True,
                )

        # Operations expected to be removed in the traced graph after decompositions
        unexpected_ops = {torch.ops.aten.instance_norm.default}

        inputs = [
            torch.randn(shape, device="cuda"),
            torch.randn(shape[1], device="cuda"),
            torch.randn(shape[1], device="cuda"),
        ]
        if not use_input_stats:
            inputs += [
                torch.randn(shape[1], device="cuda"),
                torch.rand(shape[1], device="cuda"),
            ]

        fx_graph = torch.fx.symbolic_trace(TestModule())
        unexpected_ops_seen, _ = lower_graph_testing(
            fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph, "dynamo", inputs, min_block_size=1
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            "Instance_norm TRT outputs don't match with the original model.",
        )


if __name__ == "__main__":
    run_tests()
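
The use_input_stats=False cases exercised above rest on the complementary identity: when the running statistics are used, instance normalization reduces to batch normalization in inference mode. A minimal eager-mode sketch, not part of this commit, with illustrative shapes and tolerance:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, 8)
weight, bias = torch.randn(4), torch.randn(4)
running_mean, running_var = torch.randn(4), torch.rand(4)

# Instance norm normalizing with the per-channel running statistics...
ref = F.instance_norm(
    x, running_mean, running_var, weight, bias,
    use_input_stats=False, momentum=0.1, eps=1e-5,
)

# ...matches batch norm with training=False and the same statistics.
dec = F.batch_norm(x, running_mean, running_var, weight, bias, False, 0.1, 1e-5)
print(torch.allclose(ref, dec, atol=1e-6))  # expected: True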
