Replace scaled_dot_product_attention lowering pass with decomposition #3296
Open
HolyWu wants to merge 8 commits into pytorch:main from HolyWu:sdpa_decomposition
Conversation
github-actions bot added the component: tests, component: lowering, component: conversion, component: converters, component: api [Python], and component: dynamo labels on Nov 17, 2024.
peri044 reviewed on Nov 18, 2024:
Thanks for the contribution. It seems like you're replacing the converter with a decomposition. Did you observe any perf gains/regressions with this change?
The performance is very close in my test. Benchmarking with:

from __future__ import annotations

import os

import numpy as np
import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"

times = 20


@torch.inference_mode()
def benchmark(
    model: torch.nn.Module,
    inputs: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
) -> np.ndarray:
    # Warm up
    for i in range(3):
        model(*inputs[i])

    torch.cuda.synchronize()

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]

    for i in range(times):
        torch.cuda._sleep(1_000_000)
        start_events[i].record()
        model(*inputs[i])
        end_events[i].record()

    torch.cuda.synchronize()

    timings = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    return np.array(timings)


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)


torch.manual_seed(12345)

model = MyModule().eval().cuda().half()

inputs = [
    torch_tensorrt.Input((32, 64, 128, 256), dtype=torch.half),
    torch_tensorrt.Input((32, 64, 128, 256), dtype=torch.half),
    torch_tensorrt.Input((32, 64, 128, 256), dtype=torch.half),
]

trt_model = torch_tensorrt.compile(
    model, "dynamo", inputs, enabled_precisions={torch.half}, debug=True, min_block_size=1
)

inputs = [
    (
        torch.rand(32, 64, 128, 256, dtype=torch.half, device="cuda"),
        torch.rand(32, 64, 128, 256, dtype=torch.half, device="cuda"),
        torch.rand(32, 64, 128, 256, dtype=torch.half, device="cuda"),
    )
    for _ in range(times)
]

timing = benchmark(trt_model, inputs)
print("")
print("Timing:")
print(f"Min={timing.min()} ms, Mean={timing.mean()} ms, Max={timing.max()} ms")
print("")

with torch.inference_mode():
    for i in range(times):
        torch.testing.assert_close(trt_model(*inputs[i]), model(*inputs[i]), rtol=5e-3, atol=5e-3)

print("assert_close passed")

torch._dynamo.reset()

Converter

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%scaled_dot_product_attention : [num_users=1] = call_function[target=torch.ops.aten.scaled_dot_product_attention.default](args = (%query, %key, %value), kwargs = {})
return (scaled_dot_product_attention,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%_scaled_dot_product_efficient_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_efficient_attention.default](args = (%query, %key, %value, None, False), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_scaled_dot_product_efficient_attention, 0), kwargs = {})
return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%_scaled_dot_product_efficient_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_efficient_attention.default](args = (%query, %key, %value, None, False), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_scaled_dot_product_efficient_attention, 0), kwargs = {})
return (getitem,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.lower_scaled_dot_product_attention:Graph after lowering scaled dot product attention:
graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%scaled_dot_product_attention : [num_users=1] = call_function[target=torch._C._nn.scaled_dot_product_attention](args = (%query, %key, %value, None, 0.0, False), kwargs = {})
return (scaled_dot_product_attention,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%scaled_dot_product_attention : [num_users=1] = call_function[target=torch._C._nn.scaled_dot_product_attention](args = (%query, %key, %value, None, 0.0, False), kwargs = {})
return (scaled_dot_product_attention,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%scaled_dot_product_attention : [num_users=1] = call_function[target=torch._C._nn.scaled_dot_product_attention](args = (%query, %key, %value, None, 0.0, False), kwargs = {})
return (scaled_dot_product_attention,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch._C._nn.scaled_dot_product_attention + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch._C._nn.scaled_dot_product_attention + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(32, 64, 128, 256), (32, 64, 128, 256), (32, 64, 128, 256)]
graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%scaled_dot_product_attention : [num_users=1] = call_function[target=torch._C._nn.scaled_dot_product_attention](args = (%query, %key, %value, None, 0.0, False), kwargs = {})
return scaled_dot_product_attention
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node query (kind: query, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: query [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node query [query] (Inputs: () | Outputs: (query: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node key (kind: key, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: key [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node key [key] (Inputs: () | Outputs: (key: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node value (kind: value, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: value [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node value [value] (Inputs: () | Outputs: (value: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node scaled_dot_product_attention (kind: <built-in function scaled_dot_product_attention>, args: ('query <Node>', 'key <Node>', 'value <Node>', 'None <NoneType>', '0.0 <float>', 'False <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node scaled_dot_product_attention [<built-in function scaled_dot_product_attention>] (Inputs: (query: (32, 64, 128, 256)@torch.float16, key: (32, 64, 128, 256)@torch.float16, value: (32, 64, 128, 256)@torch.float16, None, 0.0, False) | Outputs: (scaled_dot_product_attention: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('scaled_dot_product_attention <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(32, 64, 128, 256), dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (scaled_dot_product_attention: (32, 64, 128, 256)@torch.float16) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.004000
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.370738
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 62148 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 106 microseconds.
INFO: [Torch-TensorRT] - [MS] Running engine with multi stream info
INFO: [Torch-TensorRT] - [MS] Number of aux streams is 2
INFO: [Torch-TensorRT] - [MS] Number of total worker streams is 3
INFO: [Torch-TensorRT] - [MS] The main stream provided by execute/enqueue calls is the first worker stream
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 80
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 536870912
DEBUG: [Torch-TensorRT] - - Runner scratch: 536870912 bytes
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +512, now: CPU 0, GPU 512 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: query has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Input binding name: key has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Input binding name: value has TensorRT binding index: 2, Torch binding index: 2
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 3, Torch binding index: 3
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
Name: _run_on_acc_0_engine
Inputs: [
id: 0
name: query
shape: [32, 64, 128, 256]
dtype: Half
id: 1
name: key
shape: [32, 64, 128, 256]
dtype: Half
id: 2
name: value
shape: [32, 64, 128, 256]
dtype: Half
]
Outputs: [
id: 0
name: output0
shape: [32, 64, 128, 256]
dtype: Half
]
Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
Hardware Compatibility: Disabled
Target Platform: windows_x86_64
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++
The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage
Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, enable_weight_streaming=False, enable_cross_compile_for_windows=False)
Graph Structure:
Inputs: List[Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16]
...
TRT Engine #1 - Submodule name: _run_on_acc_0
Engine Inputs: List[Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16]
Number of Operators in Engine: 1
Engine Outputs: List[Tensor: (32, 64, 128, 256)@float16]
...
Outputs: List[Tensor: (32, 64, 128, 256)@float16]
------------------------- Aggregate Stats -------------------------
Average Number of Operators per TRT Engine: 1.0
Most Operators in a TRT Engine: 1
********** Recommendations **********
- For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
Timing:
Min=9.552767753601074 ms, Mean=9.709788846969605 ms, Max=10.579968452453613 ms
assert_close passed

Decomposition

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%scaled_dot_product_attention : [num_users=1] = call_function[target=torch.ops.aten.scaled_dot_product_attention.default](args = (%query, %key, %value), kwargs = {})
return (scaled_dot_product_attention,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([128, 128], 0), kwargs = {dtype: torch.float16, layout: torch.strided, device: cuda:0, pin_memory: False})
%permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
%expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
%view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand, [2048, 128, 256]), kwargs = {})
%expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
%view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
%bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view, %view_1), kwargs = {})
%view_2 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
%scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (256,), kwargs = {dtype: torch.int32, device: cpu, pin_memory: False})
%sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%scalar_tensor,), kwargs = {})
%div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%view_2, %sqrt), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %full), kwargs = {})
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
%expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
%view_3 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
%expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
%view_4 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
%bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_3, %view_4), kwargs = {})
%view_5 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
return (view_5,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
%permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
%expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
%view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand, [2048, 128, 256]), kwargs = {})
%expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
%view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
%bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view, %view_1), kwargs = {})
%view_2 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
%_frozen_param1 : [num_users=1] = get_attr[target=_frozen_param1]
%div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%view_2, %_frozen_param1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %_frozen_param0), kwargs = {})
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
%expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
%view_3 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
%expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
%view_4 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
%bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%view_3, %view_4), kwargs = {})
%view_5 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
return (view_5,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.view_to_reshape:Graph after replacing view with reshape:
graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
%permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
%expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
%expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
%reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2048, 128, 256]), kwargs = {})
%reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
%bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default, %reshape_default_1), kwargs = {})
%_frozen_param1 : [num_users=1] = get_attr[target=_frozen_param1]
%reshape_default_2 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
%div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%reshape_default_2, %_frozen_param1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %_frozen_param0), kwargs = {})
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
%expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
%expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
%reshape_default_3 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
%reshape_default_4 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
%bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default_3, %reshape_default_4), kwargs = {})
%reshape_default_5 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
return (reshape_default_5,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
%permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
%expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
%expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
%reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2048, 128, 256]), kwargs = {})
%reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
%bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default, %reshape_default_1), kwargs = {})
%_frozen_param1 : [num_users=1] = get_attr[target=_frozen_param1]
%reshape_default_2 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
%div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%reshape_default_2, %_frozen_param1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %_frozen_param0), kwargs = {})
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
%expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
%expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
%reshape_default_3 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
%reshape_default_4 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
%bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default_3, %reshape_default_4), kwargs = {})
%reshape_default_5 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
return (reshape_default_5,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
%query : [num_users=1] = placeholder[target=query]
%key : [num_users=1] = placeholder[target=key]
%value : [num_users=1] = placeholder[target=value]
%_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
%permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
%expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
%expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
%reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2048, 128, 256]), kwargs = {})
%reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
%bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default, %reshape_default_1), kwargs = {})
%_frozen_param1 : [num_users=1] = get_attr[target=_frozen_param1]
%reshape_default_2 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
%div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%reshape_default_2, %_frozen_param1), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %_frozen_param0), kwargs = {})
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
%expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
%expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
%reshape_default_3 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
%reshape_default_4 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
%bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default_3, %reshape_default_4), kwargs = {})
%reshape_default_5 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
return (reshape_default_5,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.permute.default + Operator Count: 1
- torch.ops.aten.expand.default + Operator Count: 4
- torch.ops.aten.reshape.default + Operator Count: 6
- torch.ops.aten.bmm.default + Operator Count: 2
- torch.ops.aten.div.Tensor + Operator Count: 1
- torch.ops.aten.add.Tensor + Operator Count: 1
- torch.ops.aten._softmax.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 16 operators out of 16 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.permute.default + Operator Count: 1
- torch.ops.aten.expand.default + Operator Count: 4
- torch.ops.aten.reshape.default + Operator Count: 6
- torch.ops.aten.bmm.default + Operator Count: 2
- torch.ops.aten.div.Tensor + Operator Count: 1
- torch.ops.aten.add.Tensor + Operator Count: 1
- torch.ops.aten._softmax.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(32, 64, 128, 256), (32, 64, 128, 256), (32, 64, 128, 256)]
graph():
%key : [num_users=1] = placeholder[target=key]
%permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%key, [0, 1, 3, 2]), kwargs = {})
%query : [num_users=1] = placeholder[target=query]
%expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%query, [32, 64, 128, 256]), kwargs = {})
%expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute, [32, 64, 256, 128]), kwargs = {})
%reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2048, 128, 256]), kwargs = {})
%reshape_default_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_1, [2048, 256, 128]), kwargs = {})
%bmm : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default, %reshape_default_1), kwargs = {})
%reshape_default_2 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [32, 64, 128, 128]), kwargs = {})
%_frozen_param1 : [num_users=1] = get_attr[target=_frozen_param1]
%div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%reshape_default_2, %_frozen_param1), kwargs = {})
%_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%div, %_frozen_param0), kwargs = {})
%_softmax : [num_users=1] = call_function[target=torch.ops.aten._softmax.default](args = (%add, -1, False), kwargs = {})
%expand_2 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%_softmax, [32, 64, 128, 128]), kwargs = {})
%value : [num_users=1] = placeholder[target=value]
%expand_3 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%value, [32, 64, 128, 256]), kwargs = {})
%reshape_default_3 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_2, [2048, 128, 128]), kwargs = {})
%reshape_default_4 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand_3, [2048, 128, 256]), kwargs = {})
%bmm_1 : [num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%reshape_default_3, %reshape_default_4), kwargs = {})
%reshape_default_5 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm_1, [32, 64, 128, 256]), kwargs = {})
return reshape_default_5
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node key (kind: key, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: key [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node key [key] (Inputs: () | Outputs: (key: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /permute (kind: aten.permute.default, args: ('key <Node>', ['0 <int>', '1 <int>', '3 <int>', '2 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /permute [aten.permute.default] (Inputs: (key: (32, 64, 128, 256)@torch.float16, [0, 1, 3, 2]) | Outputs: (permute: (32, 64, 256, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node query (kind: query, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: query [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node query [query] (Inputs: () | Outputs: (query: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /expand (kind: aten.expand.default, args: ('query <Node>', ['32 <int>', '64 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /expand [aten.expand.default] (Inputs: (query: (32, 64, 128, 256)@torch.float16, [32, 64, 128, 256]) | Outputs: (expand: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /expand_1 (kind: aten.expand.default, args: ('permute <Node>', ['32 <int>', '64 <int>', '256 <int>', '128 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /expand_1 [aten.expand.default] (Inputs: (permute: (32, 64, 256, 128)@torch.float16, [32, 64, 256, 128]) | Outputs: (expand_1: (32, 64, 256, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default (kind: aten.reshape.default, args: ('expand <Node>', ['2048 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default [aten.reshape.default] (Inputs: (expand: (32, 64, 128, 256)@torch.float16, [2048, 128, 256]) | Outputs: (reshape_default: (2048, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default_1 (kind: aten.reshape.default, args: ('expand_1 <Node>', ['2048 <int>', '256 <int>', '128 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default_1 [aten.reshape.default] (Inputs: (expand_1: (32, 64, 256, 128)@torch.float16, [2048, 256, 128]) | Outputs: (reshape_default_1: (2048, 256, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /bmm (kind: aten.bmm.default, args: ('reshape_default <Node>', 'reshape_default_1 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /bmm [aten.bmm.default] (Inputs: (reshape_default: (2048, 128, 256)@torch.float16, reshape_default_1: (2048, 256, 128)@torch.float16) | Outputs: (bmm: (2048, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default_2 (kind: aten.reshape.default, args: ('bmm <Node>', ['32 <int>', '64 <int>', '128 <int>', '128 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default_2 [aten.reshape.default] (Inputs: (bmm: (2048, 128, 128)@torch.float16, [32, 64, 128, 128]) | Outputs: (reshape_default_2: (32, 64, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node _frozen_param1 (kind: _frozen_param1, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _frozen_param1 [_frozen_param1] (Inputs: () | Outputs: (_frozen_param1: ()@float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /div (kind: aten.div.Tensor, args: ('reshape_default_2 <Node>', '_frozen_param1 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /div [aten.div.Tensor] (Inputs: (reshape_default_2: (32, 64, 128, 128)@torch.float16, _frozen_param1: ()@float32) | Outputs: (div: (32, 64, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node _frozen_param0 (kind: _frozen_param0, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _frozen_param0 [_frozen_param0] (Inputs: () | Outputs: (_frozen_param0: (128, 128)@float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /add (kind: aten.add.Tensor, args: ('div <Node>', '_frozen_param0 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /add [aten.add.Tensor] (Inputs: (div: (32, 64, 128, 128)@torch.float16, _frozen_param0: (128, 128)@float16) | Outputs: (add: (32, 64, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /_softmax (kind: aten._softmax.default, args: ('add <Node>', '-1 <int>', 'False <bool>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /_softmax [aten._softmax.default] (Inputs: (add: (32, 64, 128, 128)@torch.float16, -1, False) | Outputs: (_softmax: (32, 64, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /expand_2 (kind: aten.expand.default, args: ('_softmax <Node>', ['32 <int>', '64 <int>', '128 <int>', '128 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /expand_2 [aten.expand.default] (Inputs: (_softmax: (32, 64, 128, 128)@torch.float16, [32, 64, 128, 128]) | Outputs: (expand_2: (32, 64, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node value (kind: value, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: value [shape=[32, 64, 128, 256], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node value [value] (Inputs: () | Outputs: (value: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /expand_3 (kind: aten.expand.default, args: ('value <Node>', ['32 <int>', '64 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /expand_3 [aten.expand.default] (Inputs: (value: (32, 64, 128, 256)@torch.float16, [32, 64, 128, 256]) | Outputs: (expand_3: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default_3 (kind: aten.reshape.default, args: ('expand_2 <Node>', ['2048 <int>', '128 <int>', '128 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default_3 [aten.reshape.default] (Inputs: (expand_2: (32, 64, 128, 128)@torch.float16, [2048, 128, 128]) | Outputs: (reshape_default_3: (2048, 128, 128)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default_4 (kind: aten.reshape.default, args: ('expand_3 <Node>', ['2048 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default_4 [aten.reshape.default] (Inputs: (expand_3: (32, 64, 128, 256)@torch.float16, [2048, 128, 256]) | Outputs: (reshape_default_4: (2048, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /bmm_1 (kind: aten.bmm.default, args: ('reshape_default_3 <Node>', 'reshape_default_4 <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /bmm_1 [aten.bmm.default] (Inputs: (reshape_default_3: (2048, 128, 128)@torch.float16, reshape_default_4: (2048, 128, 256)@torch.float16) | Outputs: (bmm_1: (2048, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /reshape_default_5 (kind: aten.reshape.default, args: ('bmm_1 <Node>', ['32 <int>', '64 <int>', '128 <int>', '256 <int>']))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /reshape_default_5 [aten.reshape.default] (Inputs: (bmm_1: (2048, 128, 256)@torch.float16, [32, 64, 128, 256]) | Outputs: (reshape_default_5: (32, 64, 128, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('reshape_default_5 <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(32, 64, 128, 256), dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (reshape_default_5: (32, 64, 128, 256)@torch.float16) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.011572
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.472751
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 95868 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 1132 microseconds.
INFO: [Torch-TensorRT] - [MS] Running engine with multi stream info
INFO: [Torch-TensorRT] - [MS] Number of aux streams is 2
INFO: [Torch-TensorRT] - [MS] Number of total worker streams is 3
INFO: [Torch-TensorRT] - [MS] The main stream provided by execute/enqueue calls is the first worker stream
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 80
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 536870912
DEBUG: [Torch-TensorRT] - - Runner scratch: 536870912 bytes
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +512, now: CPU 0, GPU 512 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: key has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Input binding name: query has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Input binding name: value has TensorRT binding index: 2, Torch binding index: 2
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 3, Torch binding index: 3
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
Name: _run_on_acc_0_engine
Inputs: [
id: 0
name: key
shape: [32, 64, 128, 256]
dtype: Half
id: 1
name: query
shape: [32, 64, 128, 256]
dtype: Half
id: 2
name: value
shape: [32, 64, 128, 256]
dtype: Half
]
Outputs: [
id: 0
name: output0
shape: [32, 64, 128, 256]
dtype: Half
]
Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
Hardware Compatibility: Disabled
Target Platform: windows_x86_64
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++
The graph consists of 16 Total Operators, of which 16 operators are supported, 100.0% coverage
Compiled with: CompilationSettings(enabled_precisions={<dtype.f16: 6>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, enable_weight_streaming=False, enable_cross_compile_for_windows=False)
Graph Structure:
Inputs: List[Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16]
...
TRT Engine #1 - Submodule name: _run_on_acc_0
Engine Inputs: List[Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16, Tensor: (32, 64, 128, 256)@float16]
Number of Operators in Engine: 16
Engine Outputs: List[Tensor: (32, 64, 128, 256)@float16]
...
Outputs: List[Tensor: (32, 64, 128, 256)@float16]
------------------------- Aggregate Stats -------------------------
Average Number of Operators per TRT Engine: 16.0
Most Operators in a TRT Engine: 16
********** Recommendations **********
- For minimal graph segmentation, select min_block_size=16 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=16 which generates 1 TRT engine(s)
Timing:
Min=9.578495979309082 ms, Mean=9.71664638519287 ms, Max=10.313728332519531 ms
assert_close passed
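For readers skimming the Decomposition logs above: the decomposed graph is just the textbook attention computation. Below is a minimal eager-mode sketch of what those permute/bmm/div/softmax nodes compute — an illustration for reading the logs, not the exact decomposition PyTorch emits. The expand/reshape nodes in the graph merely flatten the batch and head dimensions for bmm, which the @ operator handles directly here.

import math

import torch
import torch.nn.functional as F


def sdpa_decomposed(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # query/key/value: (batch, heads, seq_len, head_dim)
    scale = math.sqrt(query.size(-1))                           # the scalar_tensor + sqrt nodes
    attn_bias = query.new_zeros(query.size(-2), key.size(-2))   # the full([L, S], 0) node (no mask given)
    scores = query @ key.transpose(-2, -1) / scale              # permute + bmm + div
    scores = scores + attn_bias                                 # add (attn_bias becomes _frozen_param0 after folding)
    attn = torch.softmax(scores, dim=-1)                        # _softmax
    return attn @ value                                         # bmm_1


q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(sdpa_decomposed(q, k, v), ref, atol=1e-5))  # True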
Labels
cla signed
component: api [Python] (Issues re: Python API)
component: conversion (Issues re: Conversion stage)
component: converters (Issues re: Specific op converters)
component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths)
component: lowering (Issues re: The lowering / preprocessing passes)
component: tests (Issues re: Tests)
Description
The current scaled_dot_product_attention lowering pass doesn't properly handle attn_mask and hence causes large output differences in some Transformer models. Besides, a new enable_gqa argument was added in PyTorch 2.5, which is not handled in the lowering pass (see the sketch after the checklist below). Using code modified from #3252:
Before Patch
After Patch
The difference is still large if running the above code with dtype = torch.half, due to LayerNorm's compute_precision being set. That can be resolved after merging #3272.

Type of change
Checklist:
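As referenced in the description, here is a minimal sketch of the two scaled_dot_product_attention arguments at issue: attn_mask, and the enable_gqa flag added in PyTorch 2.5. This only illustrates the PyTorch-side API; it is not code from this PR or from #3252.

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)   # (batch, query heads, seq, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# attn_mask: an additive float mask (or boolean keep-mask) applied to the
# attention scores before softmax. Dropping or mishandling it changes the
# output, which is the failure mode the description calls out.
mask = torch.zeros(16, 16)
mask[:, 8:] = float("-inf")     # block attention to the last 8 key positions
masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
unmasked = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(masked, unmasked))  # False: the mask matters

# enable_gqa (PyTorch >= 2.5): grouped-query attention, where key/value
# carry fewer heads than query (here 8 query heads share 2 kv heads).
k_gqa = torch.randn(2, 2, 16, 64)
v_gqa = torch.randn(2, 2, 16, 64)
out = F.scaled_dot_product_attention(q, k_gqa, v_gqa, enable_gqa=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])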