Replace scaled_dot_product_attention lowering pass with decomposition #3296

Open · wants to merge 8 commits into main

Conversation

@HolyWu (Contributor) commented Nov 17, 2024

Description

The current scaled_dot_product_attention lowering pass doesn't properly handle attn_mask, which causes large output differences in some Transformer models. In addition, PyTorch 2.5 added a new enable_gqa argument that the lowering pass doesn't handle.
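
For context, scaled_dot_product_attention can be expressed entirely in primitive ops, which is what a decomposition does. Below is a minimal Python sketch that closely follows the reference implementation in the PyTorch scaled_dot_product_attention documentation (it is not the code in this PR); it shows where attn_mask and enable_gqa enter the computation:

import math

import torch


def sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
                   is_causal=False, scale=None, enable_gqa=False):
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    if enable_gqa:
        # Grouped-query attention: repeat K/V heads to match the query heads.
        key = key.repeat_interleave(query.size(-3) // key.size(-3), dim=-3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), dim=-3)

    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        causal_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        attn_bias = attn_bias.masked_fill(~causal_mask, float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean mask: False positions are masked out.
            attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
        else:
            # Float mask: added to the attention scores.
            attn_bias = attn_bias + attn_mask

    attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # dropout_p is ignored here; TRT compilation targets inference.
    return attn_weight @ value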

Using code modified from #3252:

import torch
import torch_tensorrt
from transformers import BartModel, BartTokenizer

dtype = torch.float

with torch.inference_mode():
    # Load tokenizer and model
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartModel.from_pretrained("facebook/bart-base")
    model.eval().to("cuda", dtype)

    # Prepare inputs
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    inputs = {k: v.cuda() for k, v in inputs.items()}

    # Run inference before Torch-TensorRT
    outputs_before = model(**inputs)

    # Apply Torch-TensorRT optimization
    model = torch_tensorrt.compile(model, "torch_compile", enabled_precisions={dtype}, min_block_size=1)

    # Run inference after Torch-TensorRT
    outputs_after = model(**inputs)

    # Compare outputs
    last_hidden_states_before = outputs_before.last_hidden_state
    last_hidden_states_after = outputs_after.last_hidden_state

    # Calculate the maximum absolute difference
    max_abs_diff = torch.max(torch.abs(last_hidden_states_before - last_hidden_states_after)).item()

    # Calculate the mean absolute difference
    mean_abs_diff = torch.mean(torch.abs(last_hidden_states_before - last_hidden_states_after)).item()

    # Print the outputs, maximum absolute difference, and mean absolute difference
    print("Outputs before Torch-TensorRT:")
    print(last_hidden_states_before)
    print("\nOutputs after Torch-TensorRT:")
    print(last_hidden_states_after)

    print(f"\nMaximum absolute difference: {max_abs_diff}")
    print(f"Mean absolute difference: {mean_abs_diff}")

Before Patch

Outputs before Torch-TensorRT:
tensor([[[ 2.4118,  2.3732,  1.1981,  ...,  1.8372, -0.1712, -0.7264],
         [-1.4809, -0.5843, -3.3371,  ...,  1.1434, -1.9142,  1.5422],
         [ 0.8170,  1.5384, -1.3417,  ...,  0.5091, -0.9715,  1.4299],
         ...,
         [-1.5440,  0.2834, -1.0513,  ...,  0.7554, -0.3832, -0.0514],
         [ 1.0442, -0.1567,  2.8073,  ...,  1.2079, -1.3359,  0.0742],
         [-0.0903, -0.2080,  0.1134,  ...,  1.1163, -1.0827,  0.3815]]],
       device='cuda:0')

Outputs after Torch-TensorRT:
tensor([[[ 2.1479,  0.2571,  0.9633,  ...,  1.5387,  0.4488, -0.5997],
         [ 0.0253, -0.2320, -0.5439,  ...,  1.0243, -2.3883,  0.0268],
         [ 0.3912,  1.0459, -0.8934,  ...,  0.6819, -0.6737, -0.1619],
         ...,
         [-0.0732,  0.5784, -0.8106,  ...,  0.3836, -1.0898, -0.0990],
         [ 0.4508, -0.2233,  1.3816,  ...,  0.9082, -1.3156, -0.0508],
         [-0.1438, -0.4138,  0.0810,  ...,  0.7141, -0.9133,  0.3111]]],
       device='cuda:0')

Maximum absolute difference: 5.748598575592041
Mean absolute difference: 0.7984389066696167

After Patch

Outputs before Torch-TensorRT:
tensor([[[ 2.4118,  2.3732,  1.1981,  ...,  1.8372, -0.1712, -0.7264],
         [-1.4809, -0.5843, -3.3371,  ...,  1.1434, -1.9142,  1.5422],
         [ 0.8170,  1.5384, -1.3417,  ...,  0.5091, -0.9715,  1.4299],
         ...,
         [-1.5440,  0.2834, -1.0513,  ...,  0.7554, -0.3832, -0.0514],
         [ 1.0442, -0.1567,  2.8073,  ...,  1.2079, -1.3359,  0.0742],
         [-0.0903, -0.2080,  0.1134,  ...,  1.1163, -1.0827,  0.3815]]],
       device='cuda:0')

Outputs after Torch-TensorRT:
tensor([[[ 2.4126,  2.3732,  1.1995,  ...,  1.8375, -0.1714, -0.7263],
         [-1.4797, -0.5851, -3.3367,  ...,  1.1449, -1.9142,  1.5414],
         [ 0.8186,  1.5377, -1.3421,  ...,  0.5104, -0.9731,  1.4298],
         ...,
         [-1.5419,  0.2820, -1.0494,  ...,  0.7564, -0.3840, -0.0513],
         [ 1.0435, -0.1564,  2.8093,  ...,  1.2080, -1.3349,  0.0739],
         [-0.0902, -0.2082,  0.1131,  ...,  1.1156, -1.0839,  0.3816]]],
       device='cuda:0')

Maximum absolute difference: 0.01119232177734375
Mean absolute difference: 0.000758049194701016

The difference is still large when running the above code with dtype = torch.half, due to the way LayerNorm's compute_precision is set. That can be resolved after #3272 is merged.
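
Pending #3272, one hypothetical mitigation (a sketch, untested here) is to keep LayerNorm out of the TRT engine via torch_executed_ops, or to accumulate matmuls in FP32 via use_fp32_acc; both settings appear in the CompilationSettings dumps below. The exact op name passed to torch_executed_ops is an assumption and may differ after decomposition:

# Hypothetical, not validated in this PR.
model = torch_tensorrt.compile(
    model,
    "torch_compile",
    enabled_precisions={torch.half},
    min_block_size=1,
    # Run LayerNorm in PyTorch instead of TRT (op name is an assumption).
    torch_executed_ops={"torch.ops.aten.layer_norm.default"},
    # use_fp32_acc=True,  # alternatively, FP32 accumulation for matmuls
)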

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions bot added the following labels on Nov 17, 2024: component: tests, component: lowering, component: conversion, component: converters, component: api [Python], component: dynamo
@github-actions bot requested a review from apbose on November 17, 2024 09:50
@peri044 (Collaborator) left a comment:

Thanks for the contribution. It seems like you're replacing the converter with a decomposition. Did you observe any perf gains/regressions with this change?

Review threads (outdated, resolved):
  • py/torch_tensorrt/dynamo/lowering/_decompositions.py
  • tests/py/dynamo/lowering/test_decompositions.py
  • tests/py/dynamo/lowering/test_decompositions.py

@HolyWu (Contributor, Author) commented Nov 20, 2024

Did you observe any perf gains/regressions with this change?

The performance is very close in my test.

Benchmarking with:

from __future__ import annotations

import os

import numpy as np
import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"

times = 20


@torch.inference_mode()
def benchmark(model: torch.nn.Module, inputs: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]) -> np.ndarray:
    # Warm up
    for i in range(3):
        model(*inputs[i])

    torch.cuda.synchronize()

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]

    for i in range(times):
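        # torch.cuda._sleep spins the GPU so the timed kernels queue up behind it,
        # decoupling CPU-side launch overhead from the measured GPU time.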
        torch.cuda._sleep(1_000_000)

        start_events[i].record()
        model(*inputs[i])
        end_events[i].record()

    torch.cuda.synchronize()

    timings = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    return np.array(timings)


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)


torch.manual_seed(12345)
model = MyModule().eval().cuda().half()

inputs = [
    torch_tensorrt.Input((32, 64, 128, 256), dtype=torch.half),
    torch_tensorrt.Input((32, 64, 128, 256), dtype=torch.half),
    torch_tensorrt.Input((32, 64, 128, 256), dtype=torch.half),
]

trt_model = torch_tensorrt.compile(
    model, "dynamo", inputs, enabled_precisions={torch.half}, debug=True, min_block_size=1
)

inputs = [
    (
        torch.rand(32, 64, 128, 256, dtype=torch.half, device="cuda"),
        torch.rand(32, 64, 128, 256, dtype=torch.half, device="cuda"),
        torch.rand(32, 64, 128, 256, dtype=torch.half, device="cuda"),
    )
    for _ in range(times)
]
timing = benchmark(trt_model, inputs)

print("")
print("Timing:")
print(f"Min={timing.min()} ms, Mean={timing.mean()} ms, Max={timing.max()} ms")
print("")

with torch.inference_mode():
    for i in range(times):
        torch.testing.assert_close(trt_model(*inputs[i]), model(*inputs[i]), rtol=5e-3, atol=5e-3)
print("assert_close passed")

torch._dynamo.reset()

Converter

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %scaled_dot_product_attention : [num_users=1] = call_function[target=torch.ops.aten.scaled_dot_product_attention.default](args = (%query, %key, %value), kwargs = {})
    return (scaled_dot_product_attention,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %_scaled_dot_product_efficient_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_efficient_attention.default](args = (%query, %key, %value, None, False), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_scaled_dot_product_efficient_attention, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %_scaled_dot_product_efficient_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_efficient_attention.default](args = (%query, %key, %value, None, False), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_scaled_dot_product_efficient_attention, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.lower_scaled_dot_product_attention:Graph after lowering scaled dot product attention:
graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %scaled_dot_product_attention : [num_users=1] = call_function[target=torch._C._nn.scaled_dot_product_attention](args = (%query, %key, %value, None, 0.0, False), kwargs = {})
    return (scaled_dot_product_attention,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %scaled_dot_product_attention : [num_users=1] = call_function[target=torch._C._nn.scaled_dot_product_attention](args = (%query, %key, %value, None, 0.0, False), kwargs = {})
    return (scaled_dot_product_attention,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %scaled_dot_product_attention : [num_users=1] = call_function[target=torch._C._nn.scaled_dot_product_attention](args = (%query, %key, %value, None, 0.0, False), kwargs = {})
    return (scaled_dot_product_attention,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch._C._nn.scaled_dot_product_attention + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch._C._nn.scaled_dot_product_attention + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(32, 64, 128, 256), (32, 64, 128, 256), (32, 64, 128, 256)]
 graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %scaled_dot_product_attention : [num_users=1] = call_function[target=torch._C._nn.scaled_dot_product_attention](args = (%query, %key, %value, None, 0.0, False), kwargs = {})
    return scaled_dot_product_attention
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node query (kind: query, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: query [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node query [query] (Inputs: () | Outputs: (query: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node key (kind: key, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: key [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node key [key] (Inputs: () | Outputs: (key: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node value (kind: value, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: value [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node value [value] (Inputs: () | Outputs: (value: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node scaled_dot_product_attention (kind: <built-in function scaled_dot_product_attention>, args: ('query <Node>', 'key <Node>', 'value <Node>', 'None <NoneType>', '0.0 <float>', 'False <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node scaled_dot_product_attention [<built-in function scaled_dot_product_attention>] (Inputs: (query: (32, 64, 128, 256)@torch.float16, key: (32, 64, 128, 256)@torch.float16, value: (32, 64, 128, 256)@torch.float16, None, 0.0, False) | Outputs: (scaled_dot_product_attention: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('scaled_dot_product_attention <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(32, 64, 128, 256), dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (scaled_dot_product_attention: (32, 64, 128, 256)@torch.float16) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.004000
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.370738
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 62148 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 106 microseconds.
INFO: [Torch-TensorRT] - [MS] Running engine with multi stream info
INFO: [Torch-TensorRT] - [MS] Number of aux streams is 2
INFO: [Torch-TensorRT] - [MS] Number of total worker streams is 3
INFO: [Torch-TensorRT] - [MS] The main stream provided by execute/enqueue calls is the first worker stream
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 80
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 536870912
DEBUG: [Torch-TensorRT] - - Runner scratch: 536870912 bytes
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +512, now: CPU 0, GPU 512 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: query has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Input binding name: key has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Input binding name: value has TensorRT binding index: 2, Torch binding index: 2
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 3, Torch binding index: 3
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
  Name: _run_on_acc_0_engine
  Inputs: [
    id: 0
      name: query
      shape: [32, 64, 128, 256]
      dtype: Half
    id: 1
      name: key
      shape: [32, 64, 128, 256]
      dtype: Half
    id: 2
      name: value
      shape: [32, 64, 128, 256]
      dtype: Half
  ]
  Outputs: [
    id: 0
      name: output0
      shape: [32, 64, 128, 256]
      dtype: Half
  ]
  Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
  Hardware Compatibility: Disabled
  Target Platform: windows_x86_64

DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, enable_weight_streaming=False, enable_cross_compile_for_windows=False)

  Graph Structure:

   Inputs: List[Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16]
     Number of Operators in Engine: 1
     Engine Outputs: List[Tensor: (32, 64, 128, 256)@float16]
    ...
   Outputs: List[Tensor: (32, 64, 128, 256)@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)

Timing:
Min=9.552767753601074 ms, Mean=9.709788846969605 ms, Max=10.579968452453613 ms

assert_close passed

Decomposition

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %scaled_dot_product_attention : [num_users=1] = call_function[target=torch.ops.aten.scaled_dot_product_attention.default](args = (%query, %key, %value), kwargs = {})
    return (scaled_dot_product_attention,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([128, 128], 0), kwargs = {dtype: torch.float16, layout: torch.strided, device: cuda:0, pin_memory: False})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
    %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand, [2048, 128, 256]), kwargs = {})
    %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
    %bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view, %view_1), kwargs = {})
    %view_2 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (256,), kwargs = {dtype: torch.int32, device: cpu, pin_memory: False})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%scalar_tensor,), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%view_2, %sqrt), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %full), kwargs = {})
    %_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
    %expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
    %view_3 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
    %expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
    %view_4 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
    %bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_3, %view_4), kwargs = {})
    %view_5 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
    return (view_5,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
    %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand, [2048, 128, 256]), kwargs = {})
    %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
    %bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view, %view_1), kwargs = {})
    %view_2 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
    %_frozen_param1 : [num_users=1] = get_attr[target=_frozen_param1]
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%view_2, %_frozen_param1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %_frozen_param0), kwargs = {})
    %_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
    %expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
    %view_3 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
    %expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
    %view_4 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
    %bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_3, %view_4), kwargs = {})
    %view_5 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
    return (view_5,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.view_to_reshape:Graph after replacing view with reshape:
graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
    %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
    %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
    %reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2048, 128, 256]), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
    %bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default, %reshape_default_1), kwargs = {})
    %_frozen_param1 : [num_users=1] = get_attr[target=_frozen_param1]
    %reshape_default_2 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%reshape_default_2, %_frozen_param1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %_frozen_param0), kwargs = {})
    %_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
    %expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
    %expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
    %reshape_default_3 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
    %reshape_default_4 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
    %bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default_3, %reshape_default_4), kwargs = {})
    %reshape_default_5 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
    return (reshape_default_5,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
    %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
    %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
    %reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2048, 128, 256]), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
    %bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default, %reshape_default_1), kwargs = {})
    %_frozen_param1 : [num_users=1] = get_attr[target=_frozen_param1]
    %reshape_default_2 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%reshape_default_2, %_frozen_param1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %_frozen_param0), kwargs = {})
    %_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
    %expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
    %expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
    %reshape_default_3 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
    %reshape_default_4 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
    %bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default_3, %reshape_default_4), kwargs = {})
    %reshape_default_5 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
    return (reshape_default_5,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %query : [num_users=1] = placeholder[target=query]
    %key : [num_users=1] = placeholder[target=key]
    %value : [num_users=1] = placeholder[target=value]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
    %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
    %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
    %reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2048, 128, 256]), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
    %bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default, %reshape_default_1), kwargs = {})
    %_frozen_param1 : [num_users=1] = get_attr[target=_frozen_param1]
    %reshape_default_2 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%reshape_default_2, %_frozen_param1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %_frozen_param0), kwargs = {})
    %_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
    %expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
    %expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
    %reshape_default_3 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
    %reshape_default_4 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
    %bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default_3, %reshape_default_4), kwargs = {})
    %reshape_default_5 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
    return (reshape_default_5,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.permute.default + Operator Count: 1
- torch.ops.aten.expand.default + Operator Count: 4
- torch.ops.aten.reshape.default + Operator Count: 6
- torch.ops.aten.bmm.default + Operator Count: 2
- torch.ops.aten.div.Tensor + Operator Count: 1
- torch.ops.aten.add.Tensor + Operator Count: 1
- torch.ops.aten._softmax.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 16 operators out of 16 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.permute.default + Operator Count: 1
- torch.ops.aten.expand.default + Operator Count: 4
- torch.ops.aten.reshape.default + Operator Count: 6
- torch.ops.aten.bmm.default + Operator Count: 2
- torch.ops.aten.div.Tensor + Operator Count: 1
- torch.ops.aten.add.Tensor + Operator Count: 1
- torch.ops.aten._softmax.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(32, 64, 128, 256), (32, 64, 128, 256), (32, 64, 128, 256)]
 graph():
    %key : [num_users=1] = placeholder[target=key]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
    %query : [num_users=1] = placeholder[target=query]
    %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
    %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
    %reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2048, 128, 256]), kwargs = {})
    %reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
    %bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default, %reshape_default_1), kwargs = {})
    %reshape_default_2 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
    %_frozen_param1 : [num_users=1] = get_attr[target=_frozen_param1]
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%reshape_default_2, %_frozen_param1), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %_frozen_param0), kwargs = {})
    %_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
    %expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
    %value : [num_users=1] = placeholder[target=value]
    %expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
    %reshape_default_3 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
    %reshape_default_4 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
    %bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default_3, %reshape_default_4), kwargs = {})
    %reshape_default_5 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
    return reshape_default_5
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node key (kind: key, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: key [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node key [key] (Inputs: () | Outputs: (key: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /permute (kind: aten.permute.default, args: ('key <Node>', ['0 <int>', '1 <int>', '3 <int>', '2 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /permute [aten.permute.default] (Inputs: (key: (32, 64, 128, 256)@torch.float16, [0, 1, 3, 2]) | Outputs: (permute: (32, 64, 256, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node query (kind: query, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: query [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node query [query] (Inputs: () | Outputs: (query: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /expand (kind: aten.expand.default, args: ('query <Node>', ['32 <int>', '64 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /expand [aten.expand.default] (Inputs: (query: (32, 64, 128, 256)@torch.float16, [32, 64, 128, 256]) | Outputs: (expand: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /expand_1 (kind: aten.expand.default, args: ('permute <Node>', ['32 <int>', '64 <int>', '256 <int>', '128 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /expand_1 [aten.expand.default] (Inputs: (permute: (32, 64, 256, 128)@torch.float16, [32, 64, 256, 128]) | Outputs: (expand_1: (32, 64, 256, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default (kind: aten.reshape.default, args: ('expand <Node>', ['2048 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default [aten.reshape.default] (Inputs: (expand: (32, 64, 128, 256)@torch.float16, [2048, 128, 256]) | Outputs: (reshape_default: (2048, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default_1 (kind: aten.reshape.default, args: ('expand_1 <Node>', ['2048 <int>', '256 <int>', '128 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default_1 [aten.reshape.default] (Inputs: (expand_1: (32, 64, 256, 128)@torch.float16, [2048, 256, 128]) | Outputs: (reshape_default_1: (2048, 256, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /bmm (kind: aten.bmm.default, args: ('reshape_default <Node>', 'reshape_default_1 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /bmm [aten.bmm.default] (Inputs: (reshape_default: (2048, 128, 256)@torch.float16, reshape_default_1: (2048, 256, 128)@torch.float16) | Outputs: (bmm: (2048, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default_2 (kind: aten.reshape.default, args: ('bmm <Node>', ['32 <int>', '64 <int>', '128 <int>', '128 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default_2 [aten.reshape.default] (Inputs: (bmm: (2048, 128, 128)@torch.float16, [32, 64, 128, 128]) | Outputs: (reshape_default_2: (32, 64, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node _frozen_param1 (kind: _frozen_param1, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _frozen_param1 [_frozen_param1] (Inputs: () | Outputs: (_frozen_param1: ()@float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /div (kind: aten.div.Tensor, args: ('reshape_default_2 <Node>', '_frozen_param1 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /div [aten.div.Tensor] (Inputs: (reshape_default_2: (32, 64, 128, 128)@torch.float16, _frozen_param1: ()@float32) | Outputs: (div: (32, 64, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node _frozen_param0 (kind: _frozen_param0, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _frozen_param0 [_frozen_param0] (Inputs: () | Outputs: (_frozen_param0: (128, 128)@float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /add (kind: aten.add.Tensor, args: ('div <Node>', '_frozen_param0 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /add [aten.add.Tensor] (Inputs: (div: (32, 64, 128, 128)@torch.float16, _frozen_param0: (128, 128)@float16) | Outputs: (add: (32, 64, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /_softmax (kind: aten._softmax.default, args: ('add <Node>', '-1 <int>', 'False <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /_softmax [aten._softmax.default] (Inputs: (add: (32, 64, 128, 128)@torch.float16, -1, False) | Outputs: (_softmax: (32, 64, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /expand_2 (kind: aten.expand.default, args: ('_softmax <Node>', ['32 <int>', '64 <int>', '128 <int>', '128 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /expand_2 [aten.expand.default] (Inputs: (_softmax: (32, 64, 128, 128)@torch.float16, [32, 64, 128, 128]) | Outputs: (expand_2: (32, 64, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node value (kind: value, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: value [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node value [value] (Inputs: () | Outputs: (value: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /expand_3 (kind: aten.expand.default, args: ('value <Node>', ['32 <int>', '64 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /expand_3 [aten.expand.default] (Inputs: (value: (32, 64, 128, 256)@torch.float16, [32, 64, 128, 256]) | Outputs: (expand_3: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default_3 (kind: aten.reshape.default, args: ('expand_2 <Node>', ['2048 <int>', '128 <int>', '128 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default_3 [aten.reshape.default] (Inputs: (expand_2: (32, 64, 128, 128)@torch.float16, [2048, 128, 128]) | Outputs: (reshape_default_3: (2048, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default_4 (kind: aten.reshape.default, args: ('expand_3 <Node>', ['2048 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default_4 [aten.reshape.default] (Inputs: (expand_3: (32, 64, 128, 256)@torch.float16, [2048, 128, 256]) | Outputs: (reshape_default_4: (2048, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /bmm_1 (kind: aten.bmm.default, args: ('reshape_default_3 <Node>', 'reshape_default_4 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /bmm_1 [aten.bmm.default] (Inputs: (reshape_default_3: (2048, 128, 128)@torch.float16, reshape_default_4: (2048, 128, 256)@torch.float16) | Outputs: (bmm_1: (2048, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default_5 (kind: aten.reshape.default, args: ('bmm_1 <Node>', ['32 <int>', '64 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default_5 [aten.reshape.default] (Inputs: (bmm_1: (2048, 128, 256)@torch.float16, [32, 64, 128, 256]) | Outputs: (reshape_default_5: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('reshape_default_5 <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(32, 64, 128, 256), dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (reshape_default_5: (32, 64, 128, 256)@torch.float16) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.011572
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.472751
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 95868 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 1132 microseconds.
INFO: [Torch-TensorRT] - [MS] Running engine with multi stream info
INFO: [Torch-TensorRT] - [MS] Number of aux streams is 2
INFO: [Torch-TensorRT] - [MS] Number of total worker streams is 3
INFO: [Torch-TensorRT] - [MS] The main stream provided by execute/enqueue calls is the first worker stream
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 80
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 536870912
DEBUG: [Torch-TensorRT] - - Runner scratch: 536870912 bytes
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +512, now: CPU 0, GPU 512 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: key has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Input binding name: query has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Input binding name: value has TensorRT binding index: 2, Torch binding index: 2
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 3, Torch binding index: 3
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
  Name: _run_on_acc_0_engine
  Inputs: [
    id: 0
      name: key
      shape: [32, 64, 128, 256]
      dtype: Half
    id: 1
      name: query
      shape: [32, 64, 128, 256]
      dtype: Half
    id: 2
      name: value
      shape: [32, 64, 128, 256]
      dtype: Half
  ]
  Outputs: [
    id: 0
      name: output0
      shape: [32, 64, 128, 256]
      dtype: Half
  ]
  Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
  Hardware Compatibility: Disabled
  Target Platform: windows_x86_64

DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 16 Total Operators, of which 16 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, enable_weight_streaming=False, enable_cross_compile_for_windows=False)

  Graph Structure:

   Inputs: List[Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16]
     Number of Operators in Engine: 16
     Engine Outputs: List[Tensor: (32, 64, 128, 256)@float16]
    ...
   Outputs: List[Tensor: (32, 64, 128, 256)@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 16.0
   Most Operators in a TRT Engine: 16

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=16 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=16 which generates 1 TRT engine(s)

Timing:
Min=9.578495979309082 ms, Mean=9.71664638519287 ms, Max=10.313728332519531 ms

assert_close passed

@HolyWu requested a review from peri044 on November 24, 2024 09:44