
Commit

Fix LayerNorm fp16 precision
HolyWu committed Nov 3, 2024
1 parent 8e2c82d commit 8148866
Showing 3 changed files with 28 additions and 78 deletions.
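
For reference, a minimal sketch of how the fp16 LayerNorm path touched by this commit can be exercised end to end. The module, shapes, and comparison below are illustrative assumptions rather than code from the commit, and running it requires a CUDA device with torch_tensorrt installed.

import torch
import torch_tensorrt

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.norm = torch.nn.LayerNorm(256)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x)

model = Model().half().cuda().eval()
x = torch.randn(8, 128, 256, dtype=torch.half, device="cuda")

# Compile through the dynamo frontend with fp16 enabled, which routes
# aten.native_layer_norm through the converter changed in this commit.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[x],
    enabled_precisions={torch.half},
)

# Compare the TensorRT engine against eager PyTorch in half precision.
print(torch.max(torch.abs(trt_model(x) - model(x))))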
12 changes: 3 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -134,10 +134,6 @@ def aten_ops_batch_norm_legit_no_training(
capability_validator=one_user_validator,
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
torch.ops.aten.layer_norm.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -157,11 +153,9 @@ def aten_ops_layer_norm(
name,
input=args[0],
normalized_shape=args[1],
weight=args_bounds_check(args, 2, 1.0),
bias=args_bounds_check(args, 3, 0.0),
eps=args_bounds_check(args, 4, 1e-05),
cudnn_enable=args_bounds_check(args, 5, True),
return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default),
weight=args[2],
bias=args[3],
eps=args[4],
)


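With only native_layer_norm registered in this file, graphs built from torch.nn.LayerNorm should still reach the converter, since the dynamo frontend decomposes aten.layer_norm into aten.native_layer_norm via the core ATen decompositions. A small sketch to confirm that lowering (the module and shapes are illustrative, not from the commit):

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ln = torch.nn.LayerNorm(4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ln(x)

ep = torch.export.export(M(), (torch.randn(2, 3, 4),))
ep = ep.run_decompositions()
# The decomposed graph should contain aten.native_layer_norm.default
# rather than aten.layer_norm.default.
print(ep.graph_module.graph)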
39 changes: 16 additions & 23 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -159,49 +159,42 @@ def layer_norm(
name: str,
input: TRTTensor,
normalized_shape: List[int],
weight: Optional[Union[torch.Tensor, np.ndarray]],
bias: Optional[Union[torch.Tensor, np.ndarray]],
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
eps: float,
cudnn_enable: bool,
return_mean_rstd: bool,
) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
axes = get_axes_for_reduce_op(dims)
weight = get_trt_tensor(ctx, weight, f"{name}_weight")
bias = get_trt_tensor(ctx, bias, f"{name}_bias")

weight = get_trt_tensor(
ctx, weight if weight is not None else 1.0, f"{name}_weight"
)
bias = get_trt_tensor(ctx, bias if bias is not None else 0.0, f"{name}_bias")

# Cast weight and bias to have same dtype as input
weight = cast_trt_tensor(
ctx, weight, input.dtype, f"{name}_weight_cast", target, source_ir
)
bias = cast_trt_tensor(
ctx, bias, input.dtype, f"{name}_bias_cast", target, source_ir
)

if tuple(input.shape) != tuple(weight.shape):
weight = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
)

if tuple(input.shape) != tuple(bias.shape):
bias = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
)
strongly_typed_network = False
if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
weight = cast_trt_tensor(ctx, weight, input.dtype, name)
bias = cast_trt_tensor(ctx, bias, input.dtype, name)
strongly_typed_network = True

layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
layer_norm.epsilon = eps
# compute_precision ignored for strongly typed network.
if not strongly_typed_network:
layer_norm.compute_precision = input.dtype
set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)

if return_mean_rstd:
# return fake mean and rstd for now
return layer_norm.get_output(0), None, None
layer = ctx.net.add_normalization(input, weight, bias, axes)
layer.epsilon = eps
set_layer_name(layer, target, name, source_ir)

return layer_norm.get_output(0)
# return fake mean and rstd for now
return layer.get_output(0), None, None


def native_group_norm(
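For context, this implementation builds on TensorRT's INormalizationLayer (available since TensorRT 8.6) through ctx.net.add_normalization. A stripped-down sketch of that API outside Torch-TensorRT's helpers, showing how the reduction-axes bitmask is derived from normalized_shape, might look as follows; the helper name is hypothetical and the note on accumulation precision reflects the TensorRT Python API documentation, not code from this commit.

import tensorrt as trt

def add_layer_norm(network: trt.INetworkDefinition,
                   inp: trt.ITensor,
                   scale: trt.ITensor,
                   bias: trt.ITensor,
                   normalized_shape: list,
                   eps: float = 1e-5) -> trt.ITensor:
    # Normalization runs over the trailing dims covered by normalized_shape;
    # the axes argument is a bitmask with bit d set for each reduced dim d.
    start = len(inp.shape) - len(normalized_shape)
    axes = 0
    for d in range(start, len(inp.shape)):
        axes |= 1 << d

    layer = network.add_normalization(inp, scale, bias, axes)
    layer.epsilon = eps
    # If compute_precision is left unset, TensorRT documents that the
    # normalization accumulates in fp32 even for fp16 inputs.
    return layer.get_output(0)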
55 changes: 9 additions & 46 deletions tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -6,61 +6,24 @@
from .harness import DispatchTestCase


class TestLayerNormConverter(DispatchTestCase):
@parameterized.expand(
[
(
(5, 3, 2, 4),
[
4,
],
),
((5, 3, 2, 4), [2, 4]),
((5, 3, 2, 4), [3, 2, 4]),
((5, 3, 2, 4), [5, 3, 2, 4]),
]
)
def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
class LayerNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.layer_norm.default(
x,
normalized_shape,
torch.randn(normalized_shape),
torch.randn(normalized_shape),
eps,
)

inputs = [torch.randn(input_shape)]
self.run_test(
LayerNorm(),
inputs,
)


class TestNativeLayerNormConverter(DispatchTestCase):
@parameterized.expand(
[
(
(5, 3, 2, 4),
[
4,
],
),
((5, 3, 2, 4), [4]),
((5, 3, 2, 4), [2, 4]),
((5, 3, 2, 4), [3, 2, 4]),
((5, 3, 2, 4), [5, 3, 2, 4]),
]
)
def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
def test_layer_norm(self, input_shape, normalized_shape):
class LayerNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_layer_norm.default(
x,
normalized_shape,
torch.randn(normalized_shape),
torch.randn(normalized_shape),
eps,
1e-05,
)[0]

inputs = [torch.randn(input_shape)]
@@ -74,7 +37,7 @@ class LayerNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_layer_norm.default(
x,
torch.tensor([3, 224, 224]),
[3, 224, 224],
torch.ones((3, 224, 224)),
torch.zeros((3, 224, 224)),
1e-05,
@@ -99,9 +62,9 @@ class LayerNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_layer_norm.default(
x,
torch.tensor([3]),
torch.ones((3)),
torch.zeros((3)),
[3],
torch.randn((3)),
torch.randn((3)),
1e-05,
)[0]

@@ -120,15 +83,15 @@ def forward(self, x):
)

@parameterized.expand([((5, 3, 2, 4), [2, 4])])
def test_layer_norm_without_Scaling(self, input_shape, normalized_shape, eps=1e-05):
def test_layer_norm_without_Scaling(self, input_shape, normalized_shape):
class LayerNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_layer_norm.default(
x,
normalized_shape,
None,
None,
eps,
1e-05,
)[0]

inputs = [torch.randn(input_shape)]
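Read on its own, the last parameterized case exercises native_layer_norm with no affine parameters. A minimal eager-mode sketch of what that case computes (input shape and normalized_shape copied from the parameterization above; the final checks are illustrative):

import torch

class LayerNorm(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # weight and bias are None; eps matches the test's hard-coded 1e-05.
        return torch.ops.aten.native_layer_norm.default(
            x, [2, 4], None, None, 1e-05
        )[0]

x = torch.randn(5, 3, 2, 4)
out = LayerNorm()(x)
# Each (2, 4) slice should come out with ~zero mean and ~unit variance.
print(out.mean(dim=(-2, -1)).abs().max())
print((out.var(dim=(-2, -1), unbiased=False) - 1).abs().max())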
