From 81488668f9454320fda7fa32ce60bafc5e5348fe Mon Sep 17 00:00:00 2001
From: HolyWu
Date: Sun, 3 Nov 2024 18:05:40 +0800
Subject: [PATCH] Fix LayerNorm fp16 precision

---
 .../dynamo/conversion/aten_ops_converters.py  | 12 +---
 .../conversion/impl/normalization/ops.py      | 39 ++++++-------
 .../dynamo/conversion/test_layer_norm_aten.py | 55 +++----------------
 3 files changed, 28 insertions(+), 78 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 07c8c03697..c52248eaea 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -134,10 +134,6 @@ def aten_ops_batch_norm_legit_no_training(
     capability_validator=one_user_validator,
     supports_dynamic_shapes=True,
 )
-@dynamo_tensorrt_converter(
-    torch.ops.aten.layer_norm.default, supports_dynamic_shapes=True
-)
-@dynamo_tensorrt_converter(torch.ops.aten.layer_norm, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -157,11 +153,9 @@ def aten_ops_layer_norm(
         name,
         input=args[0],
         normalized_shape=args[1],
-        weight=args_bounds_check(args, 2, 1.0),
-        bias=args_bounds_check(args, 3, 0.0),
-        eps=args_bounds_check(args, 4, 1e-05),
-        cudnn_enable=args_bounds_check(args, 5, True),
-        return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default),
+        weight=args[2],
+        bias=args[3],
+        eps=args[4],
     )
 
 
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 4f39a6d5d9..9d69daa1e8 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -159,16 +159,18 @@ def layer_norm(
     name: str,
     input: TRTTensor,
     normalized_shape: List[int],
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     eps: float,
-    cudnn_enable: bool,
-    return_mean_rstd: bool,
-) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
     dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
     axes = get_axes_for_reduce_op(dims)
-    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+
+    weight = get_trt_tensor(
+        ctx, weight if weight is not None else 1.0, f"{name}_weight"
+    )
+    bias = get_trt_tensor(ctx, bias if bias is not None else 0.0, f"{name}_bias")
+
     # Cast weight and bias to have same dtype as input
     weight = cast_trt_tensor(
         ctx, weight, input.dtype, f"{name}_weight_cast", target, source_ir
@@ -176,32 +178,23 @@
     bias = cast_trt_tensor(
         ctx, bias, input.dtype, f"{name}_bias_cast", target, source_ir
     )
+
     if tuple(input.shape) != tuple(weight.shape):
         weight = impl.slice.expand(
             ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
         )
+
     if tuple(input.shape) != tuple(bias.shape):
         bias = impl.slice.expand(
             ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
         )
 
-    strongly_typed_network = False
-    if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
-        weight = cast_trt_tensor(ctx, weight, input.dtype, name)
-        bias = cast_trt_tensor(ctx, bias, input.dtype, name)
-        strongly_typed_network = True
-
-    layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
-    layer_norm.epsilon = eps
-    # compute_precision ignored for strongly typed network.
-    if not strongly_typed_network:
-        layer_norm.compute_precision = input.dtype
-    set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)
-    if return_mean_rstd:
-        # return fake mean and rstd for now
-        return layer_norm.get_output(0), None, None
+    layer = ctx.net.add_normalization(input, weight, bias, axes)
+    layer.epsilon = eps
+    set_layer_name(layer, target, name, source_ir)
 
-    return layer_norm.get_output(0)
+    # return fake mean and rstd for now
+    return layer.get_output(0), None, None
 
 
 def native_group_norm(
diff --git a/tests/py/dynamo/conversion/test_layer_norm_aten.py b/tests/py/dynamo/conversion/test_layer_norm_aten.py
index c6cfc430ba..d5658ff93a 100644
--- a/tests/py/dynamo/conversion/test_layer_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -6,53 +6,16 @@
 from .harness import DispatchTestCase
 
 
-class TestLayerNormConverter(DispatchTestCase):
-    @parameterized.expand(
-        [
-            (
-                (5, 3, 2, 4),
-                [
-                    4,
-                ],
-            ),
-            ((5, 3, 2, 4), [2, 4]),
-            ((5, 3, 2, 4), [3, 2, 4]),
-            ((5, 3, 2, 4), [5, 3, 2, 4]),
-        ]
-    )
-    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
-        class LayerNorm(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.layer_norm.default(
-                    x,
-                    normalized_shape,
-                    torch.randn(normalized_shape),
-                    torch.randn(normalized_shape),
-                    eps,
-                )
-
-        inputs = [torch.randn(input_shape)]
-        self.run_test(
-            LayerNorm(),
-            inputs,
-        )
-
-
 class TestNativeLayerNormConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            (
-                (5, 3, 2, 4),
-                [
-                    4,
-                ],
-            ),
+            ((5, 3, 2, 4), [4]),
             ((5, 3, 2, 4), [2, 4]),
             ((5, 3, 2, 4), [3, 2, 4]),
             ((5, 3, 2, 4), [5, 3, 2, 4]),
         ]
     )
-    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
+    def test_layer_norm(self, input_shape, normalized_shape):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
@@ -60,7 +23,7 @@ def forward(self, x):
                     normalized_shape,
                     torch.randn(normalized_shape),
                     torch.randn(normalized_shape),
-                    eps,
+                    1e-05,
                 )[0]
 
         inputs = [torch.randn(input_shape)]
@@ -74,7 +37,7 @@ class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
                     x,
-                    torch.tensor([3, 224, 224]),
+                    [3, 224, 224],
                     torch.ones((3, 224, 224)),
                     torch.zeros((3, 224, 224)),
                     1e-05,
@@ -99,9 +62,9 @@ class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
                     x,
-                    torch.tensor([3]),
-                    torch.ones((3)),
-                    torch.zeros((3)),
+                    [3],
+                    torch.randn((3)),
+                    torch.randn((3)),
                     1e-05,
                 )[0]
 
@@ -120,7 +83,7 @@ def forward(self, x):
         )
 
     @parameterized.expand([((5, 3, 2, 4), [2, 4])])
-    def test_layer_norm_without_Scaling(self, input_shape, normalized_shape, eps=1e-05):
+    def test_layer_norm_without_Scaling(self, input_shape, normalized_shape):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
                     x,
                     normalized_shape,
                     None,
                     None,
-                    eps,
+                    1e-05,
                 )[0]
 
         inputs = [torch.randn(input_shape)]
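
Note (not part of the patch): the converter registrations for torch.ops.aten.layer_norm are dropped above, presumably because the op can be expressed through torch.ops.aten.native_layer_norm, whose first output is the normalized tensor. The sketch below, a standalone illustration using only stock PyTorch, checks that equivalence; the shapes and eps value are illustrative assumptions, not values taken from the patch.

# Standalone sketch: compare aten.layer_norm with the first output of
# aten.native_layer_norm. Not part of the patch; shapes/eps are arbitrary.
import torch

x = torch.randn(5, 3, 2, 4)
normalized_shape = [2, 4]
weight = torch.randn(normalized_shape)
bias = torch.randn(normalized_shape)
eps = 1e-05

# Composite op: returns only the normalized tensor.
out_composite = torch.ops.aten.layer_norm.default(x, normalized_shape, weight, bias, eps)

# Native op: returns (output, mean, rstd); the converter above keeps only output
# and returns placeholder mean/rstd.
out_native, mean, rstd = torch.ops.aten.native_layer_norm.default(
    x, normalized_shape, weight, bias, eps
)

# The normalized outputs should agree.
assert torch.allclose(out_composite, out_native)
print(out_native.shape, mean.shape, rstd.shape)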