chore: wrapped module runtime api draft

keehyuna committed Nov 26, 2024
1 parent 7e22f61 commit 711930f

Showing 8 changed files with 120 additions and 74 deletions.
10 changes: 4 additions & 6 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -376,6 +376,7 @@ def compile(
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
enable_wrapper_module: bool = _defaults.ENABLE_WRAPPER_MODULE,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -592,6 +593,7 @@ def compile(
"use_fp32_acc": use_fp32_acc,
"enable_cross_compile_for_windows": False,
"enable_weight_streaming": enable_weight_streaming,
"enable_wrapper_module": enable_wrapper_module,
}

settings = CompilationSettings(**compilation_options)
@@ -835,13 +837,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

dryrun_stats_display(dryrun_tracker, settings.dryrun)

if len(dryrun_tracker.to_run_in_torch) > 0:
if settings.enable_wrapper_module:
# Capture/replay the whole series of CUDA operations across subgraphs in a wrapped runtime module.
partitioned_module = WrapperTorchTensorRTModule(
partitioned_module,
dryrun_tracker.output_shapes,
dryrun_tracker.output_dtypes,
)
partitioned_module = WrapperTorchTensorRTModule(partitioned_module)

return partitioned_module

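For orientation, a minimal usage sketch of the draft flag wired up above (not part of the diff; the model and inputs are illustrative, and enable_wrapper_module is the option introduced by this commit):

import torch
import torch_tensorrt as torchtrt

# Illustrative model and inputs (assumptions, not from this commit).
model = torch.nn.Sequential(torch.nn.Conv1d(64, 6, 3), torch.nn.ReLU()).cuda().eval()
inputs = [torch.randn(8, 64, 32).cuda()]

# With the draft flag set, the partitioned program comes back wrapped in
# WrapperTorchTensorRTModule so CUDA graph capture can span all submodules.
trt_module = torchtrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    min_block_size=1,
    enable_wrapper_module=True,  # draft option added in this commit
)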
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -44,6 +44,7 @@
USE_FP32_ACC = False
ENABLE_WEIGHT_STREAMING = False
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
ENABLE_WRAPPER_MODULE = False


def default_device() -> Device:
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -16,6 +16,7 @@
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLE_WEIGHT_STREAMING,
ENABLE_WRAPPER_MODULE,
ENABLED_PRECISIONS,
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
@@ -125,6 +126,7 @@ class CompilationSettings:
use_fp32_acc: bool = USE_FP32_ACC
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
enable_wrapper_module: bool = ENABLE_WRAPPER_MODULE


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
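As a quick sketch of how the new field surfaces (not part of the diff), CompilationSettings is a dataclass, so the flag rides along like any other setting and defaults to False, leaving existing callers unaffected:

from torch_tensorrt.dynamo._settings import CompilationSettings

# enable_wrapper_module defaults to ENABLE_WRAPPER_MODULE (False) when omitted.
settings = CompilationSettings(enable_wrapper_module=True)
assert settings.enable_wrapper_module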
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -250,8 +250,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]
# TODO: calculate output shape under fakeTensorMode
# fake_mode = detect_fake_mode(*inputs)
with (
torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
if self.profiling_enabled
52 changes: 1 addition & 51 deletions py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py
@@ -7,11 +7,8 @@

import torch
import torch_tensorrt
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt.dynamo import partitioning
from torch_tensorrt.dynamo.conversion import DYNAMIC_DIM
from torch_tensorrt.dynamo.utils import input_is_dynamic
from torch_tensorrt.runtime._utils import _is_switch_required, _select_rt_device

logger = logging.getLogger(__name__)
@@ -21,25 +18,18 @@ class WrapperTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
"""This Wrapper runtime module is to record/replay whole cuda graph in sub modules
Args:
original_module: Unmodified FX GraphModule
compiled_module: Compiled FX GraphModule that will be wrapped
output_shapes: Shapes of output Tensors of the graph
output_dtypes: Output data types of the graph
Returns:
Output tensor or tensor list
"""

def __init__(
self,
compiled_module: torch.nn.Module,
output_shapes: List[torch.Size],
output_dtypes: List[torch.dtype],
):
super(WrapperTorchTensorRTModule, self).__init__()
self.compiled_module = compiled_module
self.inputs = partitioning.construct_submodule_inputs(compiled_module)
self.output_shapes = output_shapes
self.output_dtypes = output_dtypes

self._input_buffers: List[torch.Tensor] = []
self._output_buffers: List[torch.Tensor] = []
@@ -49,7 +39,6 @@ def __init__(
self.prev_cudagraphs_enabled = False
self._caller_stream: Optional[torch.cuda.Stream] = None
self._engine_stream: Optional[torch.cuda.Stream] = None
self.input_is_dynamic = input_is_dynamic(self.inputs)

# Disable cudagraphs in submodules as it will be enabled in the wrapper
for name, rt_mod in self.compiled_module.named_children():
@@ -82,18 +71,9 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs)

# If the new shape key differs from the existing one, infer new output shape
if new_shape_key != self.shape_key:
logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
self.shape_key = new_shape_key

if self.input_is_dynamic:
with FakeTensorMode(allow_non_fake_inputs=True):
tmp_outputs = self.compiled_module(*inputs)
if not isinstance(tmp_outputs, (list, tuple)):
tmp_outputs = [tmp_outputs]
self.output_shapes = [tuple(output.shape) for output in tmp_outputs]
print("self.output_shapes ", self.output_shapes)
return True

return False
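As a standalone sketch of the shape-key scheme above (mirroring the expression in validate_input_shapes):

import torch

inputs = [torch.randn(3, 4), torch.randn(4, 5)]
key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs)
assert key == "(3,4)(4,5)"  # per-input shape tuples, concatenated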
@@ -128,7 +108,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
self.cudagraph.reset()

self._input_buffers = [None] * len(self.inputs)
self._output_buffers = [None] * len(self.output_shapes)

if not cudagraphs_enabled and self.cudagraph:
self.cudagraph.reset()
@@ -202,32 +181,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
elif cudagraphs_enabled:
self._input_buffers[i].copy_(contiguous_inputs[i])

with (
torch.autograd.profiler.record_function(
"WrapperTorchTensorRTModule:ProcessOutputs"
)
if self.profiling_enabled
else nullcontext()
):
# create output tensors
outputs: List[torch.Tensor] = []

for o, shape in enumerate(self.output_shapes):
if DYNAMIC_DIM in shape:
raise ValueError(
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
)

output = torch.empty(
size=shape,
dtype=self.output_dtypes[o],
device=torch.cuda.current_device(),
)

outputs.append(output)

if need_cudagraphs_record:
self._output_buffers[o] = outputs[o].clone()
with (
torch.autograd.profiler.record_function(
"WrapperTorchTensorRTModule:TensorRTRuntime"
@@ -277,13 +230,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
output_buffers = self._output_buffers
else:
output_buffers = [self._output_buffers]
for idx, o in enumerate(outputs):
o.copy_(output_buffers[idx])

outputs = [output.clone() for output in output_buffers]
if len(outputs) == 1:
return outputs[0]

return outputs
else:

return outputs
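For context on what the wrapper automates, a minimal sketch of the underlying capture/replay pattern using plain torch.cuda.CUDAGraph (the wrapper applies the same idea across the whole partitioned program, managing the static buffers shown here by hand):

import torch

model = torch.nn.Linear(16, 16).cuda().eval()
static_input = torch.randn(8, 16, device="cuda")

# Warm up on a side stream before capture, as the CUDA graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Record: kernels launched inside the context are captured into the graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Replay: copy fresh data into the static input buffer, then replay.
static_input.copy_(torch.randn(8, 16, device="cuda"))
g.replay()
result = static_output.clone()  # clone, since the buffer is rewritten on the next replay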
11 changes: 1 addition & 10 deletions py/torch_tensorrt/dynamo/runtime/register_fake_class.py
@@ -3,7 +3,6 @@
from typing import Any, List

import torch
from torch._library.fake_class_registry import FakeScriptObject
from torch_tensorrt.dynamo.utils import input_is_dynamic, unwrap_tensor_shape


@@ -27,12 +26,7 @@ def fake_tensorrt_execute_engine(
modes = ["opt"]

# Get the TRTEngine class and infer output shapes based on input shapes
# If fake_trt_engine is not FakeScriptObject, assumes that it is the real object
if isinstance(fake_trt_engine, FakeScriptObject):
trt_engine = fake_trt_engine.wrapped_obj.engine
else:
trt_engine = fake_trt_engine

trt_engine = fake_trt_engine.wrapped_obj.engine
outputs_mode_dict = defaultdict(list)
for mode in modes:
input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]
@@ -131,8 +125,5 @@ def automatic_device_memory_budget_getter(self) -> Any:
def infer_outputs(self, input_shapes: List[Any]) -> Any:
pass

def set_whole_cudagraphs(self) -> Any:
pass

def __setstate__(self, serialized_state: List[str]) -> Any:
pass
19 changes: 14 additions & 5 deletions py/torch_tensorrt/runtime/_cudagraphs.py
@@ -1,8 +1,11 @@
import logging
from typing import Any
from typing import Any, Optional

import torch
import torch_tensorrt
from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
WrapperTorchTensorRTModule,
)

if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
_PY_RT_CUDAGRAPHS = torch.ops.tensorrt.get_cudagraphs_mode()
@@ -37,19 +40,25 @@ class _CudagraphsContextManager(object):
Used to enable cudagraphs as a context manager
"""

def __init__(self) -> None:
def __init__(self, module_to_wrap: Optional[torch.nn.Module]) -> None:
global _PY_RT_CUDAGRAPHS
self.old_mode = _PY_RT_CUDAGRAPHS
self.module_to_wrap = module_to_wrap

def __enter__(self) -> "_CudagraphsContextManager":
# Enable cudagraphs
set_cudagraphs_mode(True)
return self
if self.module_to_wrap:
return WrapperTorchTensorRTModule(self.module_to_wrap)
else:
return self

def __exit__(self, *args: Any) -> None:
# Set cudagraphs back to old mode
set_cudagraphs_mode(self.old_mode)


def enable_cudagraphs() -> _CudagraphsContextManager:
return _CudagraphsContextManager()
def enable_cudagraphs(
module_to_wrap: Optional[torch.nn.Module] = None,
) -> _CudagraphsContextManager:
return _CudagraphsContextManager(module_to_wrap)
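A usage sketch of the extended context manager (mirroring draft 2 in the test below; trt_module stands in for any module returned by torch_tensorrt.compile, as in the compile sketch earlier):

import torch
import torch_tensorrt

inputs = [torch.randn(8, 64, 32).cuda()]
# Passing a module makes __enter__ return it wrapped in WrapperTorchTensorRTModule,
# so the whole program is captured in one CUDA graph; calling enable_cudagraphs()
# with no argument preserves the existing behavior.
with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as wrapped:
    for _ in range(3):
        out = wrapped(*inputs)  # first call records, subsequent calls replay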
97 changes: 97 additions & 0 deletions tests/py/dynamo/runtime/test_005_wrapper_cudagraphs.py
@@ -2,6 +2,9 @@
import torch_tensorrt as torchtrt
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
WrapperTorchTensorRTModule,
)

INPUT_SIZE = (3, 16, 16)
TRIALS = 5
@@ -197,6 +200,100 @@ def forward(self, x):
)
torch._dynamo.reset()

@parameterized.expand(
[
("python_runtime", True),
("cpp_runtime", False),
]
)
def test_wrapper_cudagraphs_api(self, _, use_python_runtime):
"""
Draft of three candidate APIs for wrapping a compiled module for whole-graph CUDA graph capture.
"""

class SampleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv1d(64, 6, 3)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = 1 + self.conv(x)
out = self.relu(out)
return out

model = SampleModel().eval().cuda()
input_list = []
trt_out_list = []
ref_out_list = []

for _ in range(TRIALS):
input = [torch.randn((64, 32), dtype=torch.float32).cuda()]
input_list.append(input)
fx_graph = torch.fx.symbolic_trace(model)

# 1. Compiler option: enable_wrapper_module=True
optimized_model = torchtrt.compile(
fx_graph,
inputs=input_list[0],
ir="dynamo",
min_block_size=1,
cache_built_engines=False,
reuse_cached_engines=False,
torch_executed_ops={"torch.ops.aten.convolution.default"},
use_python_runtime=use_python_runtime,
enable_wrapper_module=True,
)

with torchtrt.runtime.enable_cudagraphs():
for i in range(TRIALS):
trt_out_list.append(optimized_model(*input_list[i]))
ref_out_list.append(fx_graph(*input_list[i]))

# Compile again to generate a normal (unwrapped) module
optimized_model = torchtrt.compile(
fx_graph,
inputs=input_list[0],
ir="dynamo",
min_block_size=1,
cache_built_engines=False,
reuse_cached_engines=False,
torch_executed_ops={"torch.ops.aten.convolution.default"},
use_python_runtime=use_python_runtime,
)
# This is the current CUDA graphs runtime API
with torchtrt.runtime.enable_cudagraphs():
for i in range(TRIALS):
trt_out_list.append(optimized_model(*input_list[i]))
ref_out_list.append(fx_graph(*input_list[i]))

# 2. Optional parameter in the existing CUDA graphs runtime API
# WrapperTorchTensorRTModule can be simplified to have only the CUDA graph path
with torchtrt.runtime.enable_cudagraphs(optimized_model) as wrapped_module:
for i in range(TRIALS):
trt_out_list.append(wrapped_module(*input_list[i]))
ref_out_list.append(fx_graph(*input_list[i]))

# 3. Use the wrapper module directly
wrapped_module = WrapperTorchTensorRTModule(optimized_model)
with torchtrt.runtime.enable_cudagraphs():
for i in range(TRIALS):
trt_out_list.append(wrapped_module(*input_list[i]))
ref_out_list.append(fx_graph(*input_list[i]))

for optimized_model_results, torch_model_results in zip(
trt_out_list, ref_out_list
):
torch.testing.assert_close(
torch_model_results,
optimized_model_results,
rtol=5e-03,
atol=5e-03,
equal_nan=True,
check_dtype=True,
)
torch._dynamo.reset()


if __name__ == "__main__":
run_tests()
