Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LayerNorm fp16 precision #3272

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

HolyWu
Copy link
Contributor

@HolyWu HolyWu commented Nov 3, 2024

Description

Setting layer.compute_precision = input.dtype causes accuracy issue in FP16 mode. https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html#inormalizationlayer said By default TensorRT will run the normalization computation in DataType.kFLOAT32 even in mixed precision mode regardless of any set builder flags to avoid overflow errors.

Also, the operator actually taking effect is only aten.native_layer_norm.default. aten.layer_norm and aten.layer_norm.default are of no use and hence redundant.

To Reproduce

import os

import torch
import torch.nn as nn
import torch_tensorrt

os.environ["CI_BUILD"] = "1"


class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.m = nn.LayerNorm([512, 224, 224])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.m(x)


with torch.inference_mode():
    model = MyModule().eval().cuda().half()
    inputs = (
        torch.randn(1, 512, 224, 224, dtype=torch.half, device="cuda"),
    )

    exported_program = torch.export.export(model, inputs)

    trt_model = torch_tensorrt.dynamo.compile(
        exported_program,
        inputs,
        enabled_precisions={torch.half},
        debug=True,
        min_block_size=1,
    )

    torch.testing.assert_close(model(*inputs), trt_model(*inputs), rtol=5e-3, atol=5e-3)

Before Patch

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %p_m_weight : [num_users=1] = placeholder[target=p_m_weight]
    %p_m_bias : [num_users=1] = placeholder[target=p_m_bias]
    %x : [num_users=1] = placeholder[target=x]
    %layer_norm : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%x, [512, 224, 224], %p_m_weight, %p_m_bias), kwargs = {})
    return (layer_norm,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %m_weight : [num_users=1] = get_attr[target=m.weight]
    %m_bias : [num_users=1] = get_attr[target=m.bias]
    %x : [num_users=1] = placeholder[target=x]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%x, [512, 224, 224], %m_weight, %m_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %m_weight : [num_users=1] = get_attr[target=m.weight]
    %m_bias : [num_users=1] = get_attr[target=m.bias]
    %x : [num_users=1] = placeholder[target=x]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%x, [512, 224, 224], %m_weight, %m_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %m_weight : [num_users=1] = get_attr[target=m.weight]
    %m_bias : [num_users=1] = get_attr[target=m.bias]
    %x : [num_users=1] = placeholder[target=x]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%x, [512, 224, 224], %m_weight, %m_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %m_weight : [num_users=1] = get_attr[target=m.weight]
    %m_bias : [num_users=1] = get_attr[target=m.bias]
    %x : [num_users=1] = placeholder[target=x]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%x, [512, 224, 224], %m_weight, %m_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.native_layer_norm.default + Operator Count: 1
- _operator.getitem + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 2 operators out of 2 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.native_layer_norm.default + Operator Count: 1
- _operator.getitem + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(1, 512, 224, 224)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %m_weight : [num_users=1] = get_attr[target=m.weight]
    %m_bias : [num_users=1] = get_attr[target=m.bias]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%x, [512, 224, 224], %m_weight, %m_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return getitem
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 512, 224, 224], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 512, 224, 224)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m_weight (kind: m.weight, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m_weight [m.weight] (Inputs: () | Outputs: (m_weight: (512, 224, 224)@float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m_bias (kind: m.bias, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m_bias [m.bias] (Inputs: () | Outputs: (m_bias: (512, 224, 224)@float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/native_layer_norm (kind: aten.native_layer_norm.default, args: ('x <Node>', ['512 <int>', '224 <int>', '224 <int>'], 'm_weight <Node>', 'm_bias <Node>', '1e-05 <float>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/native_layer_norm [aten.native_layer_norm.default] (Inputs: (x: (1, 512, 224, 224)@torch.float16, [512, 224, 224], m_weight: (512, 224, 224)@float16, m_bias: (512, 224, 224)@float16, 1e-05) | Outputs: (native_layer_norm: ((1, 512, 224, 224)@torch.float16, (1, 1, 1, 1)@torch.float32, (1, 1, 1, 1)@torch.float32)))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/getitem (kind: <built-in function getitem>, args: ('native_layer_norm <Node>', '0 <int>'))
DEBUG:torch_tensorrt.dynamo.conversion.ops_evaluators:Evaluating _operator.getitem on object with name: m/getitem
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/getitem [<built-in function getitem>] (Inputs: (native_layer_norm: ((1, 512, 224, 224)@torch.float16, (1, 1, 1, 1)@torch.float32, (1, 1, 1, 1)@torch.float32), 0) | Outputs: (getitem: (1, 512, 224, 224)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('getitem <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 512, 224, 224), dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (getitem: (1, 512, 224, 224)@torch.float16) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.034229
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.474792
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 102798204 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 98 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 15759 microseconds.
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 32
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 2507264
DEBUG: [Torch-TensorRT] - - Runner scratch: 2507264 bytes
DEBUG: [Torch-TensorRT] - [runner] Allocating resources for 1 graphs.
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +2, now: CPU 0, GPU 100 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: x has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
  Name: _run_on_acc_0_engine
  Inputs: [
    id: 0
      name: x
      shape: [1, 512, 224, 224]
      dtype: Half
  ]
  Outputs: [
    id: 0
      name: output0
      shape: [1, 512, 224, 224]
      dtype: Half
  ]
  Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
  Hardware Compatibility: Disabled
  Target Platform: windows_x86_64

DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 2 Total Operators, of which 2 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, enable_weight_streaming=False)

  Graph Structure:

   Inputs: List[Tensor: (1, 512, 224, 224)@float16]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 512, 224, 224)@float16]
     Number of Operators in Engine: 2
     Engine Outputs: List[Tensor: (1, 512, 224, 224)@float16]
    ...
   Outputs: List[Tensor: (1, 512, 224, 224)@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 2.0
   Most Operators in a TRT Engine: 2

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=2 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=2 which generates 1 TRT engine(s)
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 0
DEBUG: [Torch-TensorRT] - Input Name: x Shape: [1, 512, 224, 224]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [1, 512, 224, 224]
Traceback (most recent call last):
  File "C:\Users\HolyWu\Downloads\test.py", line 35, in <module>
    torch.testing.assert_close(model(*inputs), trt_model(*inputs), rtol=5e-3, atol=5e-3)
  File "C:\Python312\Lib\site-packages\torch\testing\_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 25587877 / 25690112 (99.6%)
Greatest absolute difference: 5.39453125 at index (0, 97, 172, 140) (up to 0.005 allowed)
Greatest relative difference: inf at index (0, 0, 0, 0) (up to 0.005 allowed)

After Patch

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %p_m_weight : [num_users=1] = placeholder[target=p_m_weight]
    %p_m_bias : [num_users=1] = placeholder[target=p_m_bias]
    %x : [num_users=1] = placeholder[target=x]
    %layer_norm : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%x, [512, 224, 224], %p_m_weight, %p_m_bias), kwargs = {})
    return (layer_norm,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %m_weight : [num_users=1] = get_attr[target=m.weight]
    %m_bias : [num_users=1] = get_attr[target=m.bias]
    %x : [num_users=1] = placeholder[target=x]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%x, [512, 224, 224], %m_weight, %m_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %m_weight : [num_users=1] = get_attr[target=m.weight]
    %m_bias : [num_users=1] = get_attr[target=m.bias]
    %x : [num_users=1] = placeholder[target=x]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%x, [512, 224, 224], %m_weight, %m_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %m_weight : [num_users=1] = get_attr[target=m.weight]
    %m_bias : [num_users=1] = get_attr[target=m.bias]
    %x : [num_users=1] = placeholder[target=x]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%x, [512, 224, 224], %m_weight, %m_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %m_weight : [num_users=1] = get_attr[target=m.weight]
    %m_bias : [num_users=1] = get_attr[target=m.bias]
    %x : [num_users=1] = placeholder[target=x]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%x, [512, 224, 224], %m_weight, %m_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.native_layer_norm.default + Operator Count: 1
- _operator.getitem + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 2 operators out of 2 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.native_layer_norm.default + Operator Count: 1
- _operator.getitem + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(1, 512, 224, 224)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %m_weight : [num_users=1] = get_attr[target=m.weight]
    %m_bias : [num_users=1] = get_attr[target=m.bias]
    %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%x, [512, 224, 224], %m_weight, %m_bias, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%native_layer_norm, 0), kwargs = {})
    return getitem
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 512, 224, 224], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 512, 224, 224)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m_weight (kind: m.weight, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m_weight [m.weight] (Inputs: () | Outputs: (m_weight: (512, 224, 224)@float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m_bias (kind: m.bias, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m_bias [m.bias] (Inputs: () | Outputs: (m_bias: (512, 224, 224)@float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/native_layer_norm (kind: aten.native_layer_norm.default, args: ('x <Node>', ['512 <int>', '224 <int>', '224 <int>'], 'm_weight <Node>', 'm_bias <Node>', '1e-05 <float>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/native_layer_norm [aten.native_layer_norm.default] (Inputs: (x: (1, 512, 224, 224)@torch.float16, [512, 224, 224], m_weight: (512, 224, 224)@float16, m_bias: (512, 224, 224)@float16, 1e-05) | Outputs: (native_layer_norm: ((1, 512, 224, 224)@torch.float16, (1, 1, 1, 1)@torch.float32, (1, 1, 1, 1)@torch.float32)))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node m/getitem (kind: <built-in function getitem>, args: ('native_layer_norm <Node>', '0 <int>'))
DEBUG:torch_tensorrt.dynamo.conversion.ops_evaluators:Evaluating _operator.getitem on object with name: m/getitem
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node m/getitem [<built-in function getitem>] (Inputs: (native_layer_norm: ((1, 512, 224, 224)@torch.float16, (1, 1, 1, 1)@torch.float32, (1, 1, 1, 1)@torch.float32), 0) | Outputs: (getitem: (1, 512, 224, 224)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('getitem <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 512, 224, 224), dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (getitem: (1, 512, 224, 224)@torch.float16) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.034771
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.477200
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 102810764 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 98 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 11601 microseconds.
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 32
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 107774464
DEBUG: [Torch-TensorRT] - - Runner scratch: 107774464 bytes
DEBUG: [Torch-TensorRT] - [runner] Allocating resources for 1 graphs.
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +102, now: CPU 0, GPU 200 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: x has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
  Name: _run_on_acc_0_engine
  Inputs: [
    id: 0
      name: x
      shape: [1, 512, 224, 224]
      dtype: Half
  ]
  Outputs: [
    id: 0
      name: output0
      shape: [1, 512, 224, 224]
      dtype: Half
  ]
  Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
  Hardware Compatibility: Disabled
  Target Platform: windows_x86_64

DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 2 Total Operators, of which 2 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, enable_weight_streaming=False)

  Graph Structure:

   Inputs: List[Tensor: (1, 512, 224, 224)@float16]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 512, 224, 224)@float16]
     Number of Operators in Engine: 2
     Engine Outputs: List[Tensor: (1, 512, 224, 224)@float16]
    ...
   Outputs: List[Tensor: (1, 512, 224, 224)@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 2.0
   Most Operators in a TRT Engine: 2

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=2 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=2 which generates 1 TRT engine(s)
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 0
DEBUG: [Torch-TensorRT] - Input Name: x Shape: [1, 512, 224, 224]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [1, 512, 224, 224]

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Nov 3, 2024
@github-actions github-actions bot requested a review from apbose November 3, 2024 10:20
@HolyWu HolyWu force-pushed the fix_layer_norm_fp16 branch 4 times, most recently from 8148866 to 57fc8e9 Compare November 4, 2024 12:19
return layer_norm.get_output(0), None, None
layer = ctx.net.add_normalization(input, weight, bias, axes)
layer.epsilon = eps
set_layer_name(layer, target, name, source_ir)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall. Looks like you do not need to explicitly set the ILayer.precision or ILayer.set_output_type to set the output type of this layer with fp16 inputs

Copy link
Collaborator

@peri044 peri044 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants