Skip to content

Commit

Permalink
feat: Lazy engine initialization (#2997)
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan authored Aug 5, 2024
1 parent 577c5c4 commit 1d5dd56
Show file tree
Hide file tree
Showing 16 changed files with 523 additions and 107 deletions.
4 changes: 2 additions & 2 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from torch.export import ExportedProgram
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
from torch_tensorrt.dynamo._compiler import (
convert_module_to_trt_engine as dynamo_convert_module_to_trt_engine,
convert_exported_program_to_serialized_trt_engine as dynamo_convert_exported_program_to_serialized_trt_engine,
)
from torch_tensorrt.dynamo._tracer import trace as dynamo_trace

Expand Down Expand Up @@ -351,7 +351,7 @@ def convert_method_to_trt_engine(
torchtrt_inputs = prepare_inputs(inputs)
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)

return dynamo_convert_module_to_trt_engine(
return dynamo_convert_exported_program_to_serialized_trt_engine(
exp_program,
inputs=tuple(inputs),
enabled_precisions=enabled_precisions_set,
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
logger = logging.getLogger(__name__)

if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
from ._compiler import compile, convert_module_to_trt_engine
from ._compiler import compile, convert_exported_program_to_serialized_trt_engine
from ._exporter import export
from ._refit import refit_module_weights
from ._settings import CompilationSettings
Expand Down
16 changes: 8 additions & 8 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def compile(
dryrun: bool = _defaults.DRYRUN,
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
Expand Down Expand Up @@ -141,6 +142,7 @@ def compile(
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -236,6 +238,7 @@ def compile(
"dryrun": dryrun,
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"lazy_engine_init": lazy_engine_init,
}

settings = CompilationSettings(**compilation_options)
Expand Down Expand Up @@ -454,6 +457,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
# Replace all FX Modules with TRT Modules
for name, trt_module in trt_modules.items():
setattr(partitioned_module, name, trt_module)
if settings.lazy_engine_init:
getattr(partitioned_module, name).setup_engine()

# Reset settings object to user specification after fallback to global partitioning mode
if fast_partitioner_failed:
Expand All @@ -464,7 +469,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
return partitioned_module


def convert_module_to_trt_engine(
def convert_exported_program_to_serialized_trt_engine(
exported_program: ExportedProgram,
inputs: Sequence[Any],
*,
Expand Down Expand Up @@ -647,10 +652,5 @@ def convert_module_to_trt_engine(
exc_info=True,
)

import io

with io.BytesIO() as engine_bytes:
engine_bytes.write(interpreter_result.engine)
engine_bytearray: bytes = engine_bytes.getvalue()

return engine_bytearray
serialized_engine: bytes = interpreter_result.serialized_engine
return serialized_engine
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
HARDWARE_COMPATIBLE = False
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
LAZY_ENGINE_INIT = False


def default_device() -> Device:
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ENABLED_PRECISIONS,
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
LAZY_ENGINE_INIT,
MAKE_REFITABLE,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
Expand Down Expand Up @@ -104,3 +105,4 @@ class CompilationSettings:
dryrun: Union[bool, str] = DRYRUN
hardware_compatible: bool = HARDWARE_COMPATIBLE
timing_cache_path: str = TIMING_CACHE_PATH
lazy_engine_init: bool = LAZY_ENGINE_INIT
13 changes: 8 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import io
import logging
import os
import warnings
from datetime import datetime
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
Expand All @@ -29,6 +29,7 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand All @@ -43,7 +44,7 @@ class UnsupportedOperatorException(RuntimeError):


class TRTInterpreterResult(NamedTuple):
engine: Any
serialized_engine: bytes
input_names: Sequence[str]
output_names: Sequence[str]

Expand Down Expand Up @@ -358,9 +359,11 @@ def run(
builder_config, self.compilation_settings.timing_cache_path
)

return TRTInterpreterResult(
serialized_engine, self._input_names, self._output_names
)
with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(engine_str, self._input_names, self._output_names)

def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
self._cur_node_name = get_node_name(n)
Expand Down
42 changes: 19 additions & 23 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import io
import logging
from typing import List, Sequence

Expand Down Expand Up @@ -102,33 +101,30 @@ def convert_module(
settings: Compilation settings
name: TRT engine name
Returns:
_PythonTorchTensorRTModule or TorchTensorRTModule
PythonTorchTensorRTModule or TorchTensorRTModule
"""
interpreter_result = interpret_module_to_result(module, inputs, settings)

if settings.use_python_runtime or not ENABLED_FEATURES.torch_tensorrt_runtime:
if not settings.use_python_runtime:
logger.info(
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
)
return PythonTorchTensorRTModule(
engine=interpreter_result.engine,
input_names=list(interpreter_result.input_names),
output_names=list(interpreter_result.output_names),
settings=settings,
)
rt_cls = PythonTorchTensorRTModule

if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:

else:
from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

with io.BytesIO() as engine_bytes:
engine_bytes.write(interpreter_result.engine)
engine_str = engine_bytes.getvalue()
rt_cls = TorchTensorRTModule

elif (
not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime
):

return TorchTensorRTModule(
serialized_engine=engine_str,
name=name,
input_binding_names=list(interpreter_result.input_names),
output_binding_names=list(interpreter_result.output_names),
settings=settings,
logger.info(
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
)

return rt_cls(
serialized_engine=interpreter_result.serialized_engine,
input_binding_names=list(interpreter_result.input_names),
output_binding_names=list(interpreter_result.output_names),
name=name,
settings=settings,
)
72 changes: 51 additions & 21 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Sequence, Tuple

import tensorrt as trt
import torch
import torch_tensorrt
from torch.nn import Module
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
Expand All @@ -18,7 +18,7 @@
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.logging import TRT_LOGGER

import torch_tensorrt
import tensorrt as trt

logger = logging.getLogger(__name__)

Expand All @@ -32,17 +32,45 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]

def __init__(
self,
engine: bytes,
input_names: Optional[List[str]] = None,
output_names: Optional[List[str]] = None,
serialized_engine: Optional[bytes] = None,
input_binding_names: Optional[List[str]] = None,
output_binding_names: Optional[List[str]] = None,
*,
name: str = "",
settings: CompilationSettings = CompilationSettings(),
):
"""Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
Arguments:
serialized_engine (bytes): Serialized TensorRT engine in the form of a bytearray
input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules
output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned
Keyword Arguments:
name (str): Name for module
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
Example:
.. code-block:: py
trt_module = PythonTorchTensorRTModule(
engine_str,
input_binding_names=["x"],
output_binding_names=["output"],
name="my_module",
settings=CompilationSettings(device=torch.cuda.current_device)
)
"""
super(PythonTorchTensorRTModule, self).__init__()
self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)

# Run multi-gpu device check to validate engine instantiation
multi_gpu_device_check()

self.name = name
self.input_buffers: List[torch.Tensor] = []
self.output_buffers: List[torch.Tensor] = []
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
Expand All @@ -55,9 +83,13 @@ def __init__(
# Unused currently - to be used by Dynamic Shape support implementation
self.memory_pool = None

self.engine = engine
self.input_names = input_names if input_names is not None else []
self.output_names = output_names if output_names is not None else []
self.serialized_engine = serialized_engine
self.input_names = (
input_binding_names if input_binding_names is not None else []
)
self.output_names = (
output_binding_names if output_binding_names is not None else []
)
self.initialized = False
self.target_device_id = (
settings.device.gpu_id
Expand All @@ -69,12 +101,15 @@ def __init__(
)
self.profiling_enabled = settings.debug if settings.debug is not None else False
self.settings = settings
self._initialize()
self.engine = None

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()

def _initialize(self) -> None:
def setup_engine(self) -> None:
self.initialized = True
runtime = trt.Runtime(TRT_LOGGER)
self.engine = runtime.deserialize_cuda_engine(self.engine)
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
self.context = self.engine.create_execution_context()

assert self.engine.num_io_tensors == (
Expand Down Expand Up @@ -114,8 +149,7 @@ def _check_initialized(self) -> None:
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")

def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
self._check_initialized()
state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
state_dict[prefix + "engine"] = self.serialized_engine
state_dict[prefix + "input_names"] = self.input_names
state_dict[prefix + "output_names"] = self.output_names

Expand All @@ -129,17 +163,13 @@ def _load_from_state_dict(
unexpected_keys: Any,
error_msgs: Any,
) -> None:
engine_bytes = state_dict[prefix + "engine"]
self.serialized_engine = state_dict[prefix + "engine"]
self.input_names = state_dict[prefix + "input_names"]
self.output_names = state_dict[prefix + "output_names"]

# Run multi-gpu device check to validate engine instantiation
multi_gpu_device_check()

runtime = trt.Runtime(TRT_LOGGER)
self.engine = runtime.deserialize_cuda_engine(engine_bytes)

self.input_names = state_dict[prefix + "input_names"]
self.output_names = state_dict[prefix + "output_names"]
self._initialize()
self.setup_engine()

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
Expand Down
Loading

0 comments on commit 1d5dd56

Please sign in to comment.