diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 89a965fee7..8a168e9001 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -27,7 +27,7 @@ from torch.export import ExportedProgram from torch_tensorrt.dynamo._compiler import compile as dynamo_compile from torch_tensorrt.dynamo._compiler import ( - convert_module_to_trt_engine as dynamo_convert_module_to_trt_engine, + convert_exported_program_to_serialized_trt_engine as dynamo_convert_exported_program_to_serialized_trt_engine, ) from torch_tensorrt.dynamo._tracer import trace as dynamo_trace @@ -351,7 +351,7 @@ def convert_method_to_trt_engine( torchtrt_inputs = prepare_inputs(inputs) exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs) - return dynamo_convert_module_to_trt_engine( + return dynamo_convert_exported_program_to_serialized_trt_engine( exp_program, inputs=tuple(inputs), enabled_precisions=enabled_precisions_set, diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 83597db0b6..79bd113ab8 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -7,7 +7,7 @@ logger = logging.getLogger(__name__) if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): - from ._compiler import compile, convert_module_to_trt_engine + from ._compiler import compile, convert_exported_program_to_serialized_trt_engine from ._exporter import export from ._refit import refit_module_weights from ._settings import CompilationSettings diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 4dcf5a22e4..0362a010df 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -79,6 +79,7 @@ def compile( dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, + lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -141,6 +142,7 @@ def compile( dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation + lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -236,6 +238,7 @@ def compile( "dryrun": dryrun, "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, + "lazy_engine_init": lazy_engine_init, } settings = CompilationSettings(**compilation_options) @@ -454,6 +457,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: # Replace all FX Modules with TRT Modules for name, trt_module in trt_modules.items(): setattr(partitioned_module, name, trt_module) + if settings.lazy_engine_init: + getattr(partitioned_module, name).setup_engine() # Reset settings object to user specification after fallback to global partitioning mode if fast_partitioner_failed: @@ -464,7 +469,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: return partitioned_module -def convert_module_to_trt_engine( +def convert_exported_program_to_serialized_trt_engine( exported_program: ExportedProgram, inputs: Sequence[Any], *, @@ -647,10 +652,5 @@ def convert_module_to_trt_engine( exc_info=True, ) - import io - - with io.BytesIO() as engine_bytes: - engine_bytes.write(interpreter_result.engine) - engine_bytearray: bytes = engine_bytes.getvalue() - - return engine_bytearray + serialized_engine: bytes = interpreter_result.serialized_engine + return serialized_engine diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index dbf0265496..2696e26936 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -32,6 +32,7 @@ HARDWARE_COMPATIBLE = False SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8} TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin") +LAZY_ENGINE_INIT = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 13c786b858..4a9792d3dc 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -16,6 +16,7 @@ ENABLED_PRECISIONS, ENGINE_CAPABILITY, HARDWARE_COMPATIBLE, + LAZY_ENGINE_INIT, MAKE_REFITABLE, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, @@ -104,3 +105,4 @@ class CompilationSettings: dryrun: Union[bool, str] = DRYRUN hardware_compatible: bool = HARDWARE_COMPATIBLE timing_cache_path: str = TIMING_CACHE_PATH + lazy_engine_init: bool = LAZY_ENGINE_INIT diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 09fcccf5d8..703a650c99 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -1,3 +1,4 @@ +import io import logging import os import warnings @@ -5,7 +6,6 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple import numpy as np -import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name @@ -29,6 +29,7 @@ from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER +import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -43,7 +44,7 @@ class UnsupportedOperatorException(RuntimeError): class TRTInterpreterResult(NamedTuple): - engine: Any + serialized_engine: bytes input_names: Sequence[str] output_names: Sequence[str] @@ -358,9 +359,11 @@ def run( builder_config, self.compilation_settings.timing_cache_path ) - return TRTInterpreterResult( - serialized_engine, self._input_names, self._output_names - ) + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + engine_str = engine_bytes.getvalue() + + return TRTInterpreterResult(engine_str, self._input_names, self._output_names) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: self._cur_node_name = get_node_name(n) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index ea3034cb8c..8f22a6c993 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -1,6 +1,5 @@ from __future__ import annotations -import io import logging from typing import List, Sequence @@ -102,33 +101,30 @@ def convert_module( settings: Compilation settings name: TRT engine name Returns: - _PythonTorchTensorRTModule or TorchTensorRTModule + PythonTorchTensorRTModule or TorchTensorRTModule """ interpreter_result = interpret_module_to_result(module, inputs, settings) - if settings.use_python_runtime or not ENABLED_FEATURES.torch_tensorrt_runtime: - if not settings.use_python_runtime: - logger.info( - "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available" - ) - return PythonTorchTensorRTModule( - engine=interpreter_result.engine, - input_names=list(interpreter_result.input_names), - output_names=list(interpreter_result.output_names), - settings=settings, - ) + rt_cls = PythonTorchTensorRTModule + + if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime: - else: from torch_tensorrt.dynamo.runtime import TorchTensorRTModule - with io.BytesIO() as engine_bytes: - engine_bytes.write(interpreter_result.engine) - engine_str = engine_bytes.getvalue() + rt_cls = TorchTensorRTModule + + elif ( + not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime + ): - return TorchTensorRTModule( - serialized_engine=engine_str, - name=name, - input_binding_names=list(interpreter_result.input_names), - output_binding_names=list(interpreter_result.output_names), - settings=settings, + logger.info( + "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available" ) + + return rt_cls( + serialized_engine=interpreter_result.serialized_engine, + input_binding_names=list(interpreter_result.input_names), + output_binding_names=list(interpreter_result.output_names), + name=name, + settings=settings, + ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 6c94b112a7..659f18af52 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -4,8 +4,8 @@ from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sequence, Tuple -import tensorrt as trt import torch +import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype @@ -18,7 +18,7 @@ from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER -import torch_tensorrt +import tensorrt as trt logger = logging.getLogger(__name__) @@ -32,17 +32,45 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc] def __init__( self, - engine: bytes, - input_names: Optional[List[str]] = None, - output_names: Optional[List[str]] = None, + serialized_engine: Optional[bytes] = None, + input_binding_names: Optional[List[str]] = None, + output_binding_names: Optional[List[str]] = None, + *, + name: str = "", settings: CompilationSettings = CompilationSettings(), ): + """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs + a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine + + Arguments: + serialized_engine (bytes): Serialized TensorRT engine in the form of a bytearray + input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules + output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned + + Keyword Arguments: + name (str): Name for module + settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed + + Example: + + .. code-block:: py + + trt_module = PythonTorchTensorRTModule( + engine_str, + input_binding_names=["x"], + output_binding_names=["output"], + name="my_module", + settings=CompilationSettings(device=torch.cuda.current_device) + ) + + """ super(PythonTorchTensorRTModule, self).__init__() self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict) # Run multi-gpu device check to validate engine instantiation multi_gpu_device_check() + self.name = name self.input_buffers: List[torch.Tensor] = [] self.output_buffers: List[torch.Tensor] = [] self.cudagraph: Optional[torch.cuda.CUDAGraph] = None @@ -55,9 +83,13 @@ def __init__( # Unused currently - to be used by Dynamic Shape support implementation self.memory_pool = None - self.engine = engine - self.input_names = input_names if input_names is not None else [] - self.output_names = output_names if output_names is not None else [] + self.serialized_engine = serialized_engine + self.input_names = ( + input_binding_names if input_binding_names is not None else [] + ) + self.output_names = ( + output_binding_names if output_binding_names is not None else [] + ) self.initialized = False self.target_device_id = ( settings.device.gpu_id @@ -69,12 +101,15 @@ def __init__( ) self.profiling_enabled = settings.debug if settings.debug is not None else False self.settings = settings - self._initialize() + self.engine = None + + if self.serialized_engine is not None and not self.settings.lazy_engine_init: + self.setup_engine() - def _initialize(self) -> None: + def setup_engine(self) -> None: self.initialized = True runtime = trt.Runtime(TRT_LOGGER) - self.engine = runtime.deserialize_cuda_engine(self.engine) + self.engine = runtime.deserialize_cuda_engine(self.serialized_engine) self.context = self.engine.create_execution_context() assert self.engine.num_io_tensors == ( @@ -114,8 +149,7 @@ def _check_initialized(self) -> None: raise RuntimeError("PythonTorchTensorRTModule is not initialized.") def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None: - self._check_initialized() - state_dict[prefix + "engine"] = bytearray(self.engine.serialize()) + state_dict[prefix + "engine"] = self.serialized_engine state_dict[prefix + "input_names"] = self.input_names state_dict[prefix + "output_names"] = self.output_names @@ -129,17 +163,13 @@ def _load_from_state_dict( unexpected_keys: Any, error_msgs: Any, ) -> None: - engine_bytes = state_dict[prefix + "engine"] + self.serialized_engine = state_dict[prefix + "engine"] + self.input_names = state_dict[prefix + "input_names"] + self.output_names = state_dict[prefix + "output_names"] # Run multi-gpu device check to validate engine instantiation multi_gpu_device_check() - - runtime = trt.Runtime(TRT_LOGGER) - self.engine = runtime.deserialize_cuda_engine(engine_bytes) - - self.input_names = state_dict[prefix + "input_names"] - self.output_names = state_dict[prefix + "output_names"] - self._initialize() + self.setup_engine() def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 601147279a..0ab0dd49ca 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -53,13 +53,14 @@ class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc] def __init__( self, serialized_engine: Optional[bytes] = None, - name: str = "", input_binding_names: Optional[List[str]] = None, output_binding_names: Optional[List[str]] = None, - settings: CompilationSettings = CompilationSettings(), + *, + name: str = "", + settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs - a PyTorch ``torch.nn.Module`` around it. + a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines If binding names are not provided, it is assumed that the engine binding names follow the following convention: @@ -67,12 +68,13 @@ def __init__( - ex. [x.0, x.1, x.2] -> [y.0] Arguments: - name (str): Name for module - serialized_engine (bytearray): Serialized TensorRT engine in the form of a bytearray + serialized_engine (bytes): Serialized TensorRT engine in the form of a bytearray input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned - target_device (torch_tensorrt.Device): Device to instantiate TensorRT engine on. Must be a compatible device i.e. same GPU model / compute capability as was used to build the engine - hardware_compatible (bool): If the engine has be built with the hardware compatibility feature enabled + + Keyword Arguments: + name (str): Name for module + settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed Example: @@ -84,9 +86,10 @@ def __init__( trt_module = TorchTensorRTModule( engine_str, - name="my_module", input_binding_names=["x"], output_binding_names=["output"], + name="my_module", + settings=CompilationSettings(device=torch.cuda.current_device) ) """ @@ -102,26 +105,43 @@ def __init__( output_binding_names if output_binding_names is not None else [] ) self.name = name - target_device = ( - settings.device if settings.device is not None else Device._current_device() - ) self.hardware_compatible = settings.hardware_compatible self.settings = copy.deepcopy(settings) - if serialized_engine is not None: - self.engine = torch.classes.tensorrt.Engine( - [ - torch.ops.tensorrt.ABI_VERSION(), - self.name + "_engine" if self.name != "" else "tensorrt_engine", - target_device._to_serialized_rt_device(), - serialized_engine, - TorchTensorRTModule._pack_binding_names(self.input_binding_names), - TorchTensorRTModule._pack_binding_names(self.output_binding_names), - str(int(self.hardware_compatible)), - self.encode_metadata(settings), - ] - ) - else: - self.engine = None + self.serialized_engine = serialized_engine + self.engine = None + + if serialized_engine and not self.settings.lazy_engine_init: + self.setup_engine() + + def setup_engine(self) -> None: + """ + Setup engine for a module which has deferred engine setup. + + Will setup the TensorRT engine for this module in the case that setup has been + deferred. In the case that the engine has already been setup, will return without + changing anything. Assumes that serialized engine and settings have already been passed + to the module. + """ + if self.engine is not None: + return + + target_device = ( + self.settings.device + if self.settings.device is not None + else Device._current_device() + ) + self.engine = torch.classes.tensorrt.Engine( + [ + torch.ops.tensorrt.ABI_VERSION(), + self.name + "_engine" if self.name != "" else "tensorrt_engine", + target_device._to_serialized_rt_device(), + self.serialized_engine, + TorchTensorRTModule._pack_binding_names(self.input_binding_names), + TorchTensorRTModule._pack_binding_names(self.output_binding_names), + str(int(self.hardware_compatible)), + self.encode_metadata(self.settings), + ] + ) def encode_metadata(self, settings: Any) -> str: settings = copy.deepcopy(settings) @@ -140,9 +160,12 @@ def decode_metadata(encoded_settings: bytes) -> Any: return settings def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt: + if self.engine is None and self.serialized_engine is not None: + self.setup_engine() + return ( self.name, - self.engine.__getstate__() if self.engine is not None else None, + self.engine.__getstate__() if self.engine else None, self.input_binding_names, self.output_binding_names, ) @@ -152,13 +175,13 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: if state[1] is not None: serialized_engine_info: SerializedTensorRTEngineFmt = state[1] - serialized_engine = base64.b64decode(serialized_engine_info[3]) + self.serialized_engine = base64.b64decode(serialized_engine_info[3]) self.engine = torch.classes.tensorrt.Engine( [ serialized_engine_info[ABI_TARGET_IDX], serialized_engine_info[NAME_IDX], serialized_engine_info[DEVICE_IDX], - serialized_engine, + self.serialized_engine, serialized_engine_info[INPUT_BINDING_NAMES_IDX], serialized_engine_info[OUTPUT_BINDING_NAMES_IDX], serialized_engine_info[HW_COMPATIBLE_IDX], @@ -185,7 +208,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: torch.Tensor or Tuple(torch.Tensor): Result of the engine computation """ if self.engine is None: - raise RuntimeError("Engine has not been initialized yet.") + raise RuntimeError("Engine has not been setup yet.") assert len(inputs) == len( self.input_binding_names diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index b4c1fad2dc..26f54d4d7b 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -3,7 +3,7 @@ import logging import time import unittest -from typing import Callable, List, Optional, Set, Tuple +from typing import Callable, List, Optional, Tuple import torch import torch_tensorrt @@ -64,6 +64,7 @@ def run_test( atol, check_dtype=True, pyt_inputs=None, + rt_cls=PythonTorchTensorRTModule, ): with torch.no_grad(): cuda_inputs = [] @@ -74,10 +75,11 @@ def run_test( interpreter_result = interpreter.run() sec = time.perf_counter() - start _LOGGER.info(f"Interpreter run time(s): {sec}") - trt_mod = PythonTorchTensorRTModule( - interpreter_result.engine, - interpreter_result.input_names, - interpreter_result.output_names, + trt_mod = rt_cls( + serialized_engine=interpreter_result.serialized_engine, + input_binding_names=list(interpreter_result.input_names), + output_binding_names=list(interpreter_result.output_names), + name="test_engine", ) mod = mod.cuda() if pyt_inputs is not None: @@ -132,6 +134,7 @@ def run_test_custom_compare_results( interpreter, comparators: List[Tuple[Callable, List]], fp16_mode=False, + rt_cls=PythonTorchTensorRTModule, ): """ Runs the test and compares the result using the provided comparators. @@ -154,10 +157,11 @@ def run_test_custom_compare_results( self.assert_has_op(mod, expected_ops) interpreter_result = interpreter.run() - trt_mod = PythonTorchTensorRTModule( - interpreter_result.engine, - interpreter_result.input_names, - interpreter_result.output_names, + trt_mod = rt_cls( + serialized_engine=interpreter_result.serialized_engine, + input_binding_names=list(interpreter_result.input_names), + output_binding_names=list(interpreter_result.output_names), + name="test_engine", ) res_trt = trt_mod(*cuda_inputs).cpu() res_cpu = mod(*cuda_inputs).cpu() diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py index 1ab0848828..bb24c7284e 100644 --- a/tests/py/dynamo/models/test_dtype_support.py +++ b/tests/py/dynamo/models/test_dtype_support.py @@ -178,6 +178,14 @@ def forward(self, x): ) +@unittest.skipIf( + torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8 + or ( + torch.cuda.get_device_properties(torch.cuda.current_device()).major == 8 + and torch.cuda.get_device_properties(torch.cuda.current_device()).major == 7 + ), + "Platform does not have BF16 support", +) class TestBF16Support(TestCase): @unittest.skipIf( not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 9fdab1a9d0..c18d49954e 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -30,7 +30,6 @@ def test_resnet18(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, - "ir": "dynamo", } trt_mod = torchtrt.compile(model, **compile_spec) @@ -61,7 +60,6 @@ def test_mobilenet_v2(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, - "ir": "dynamo", } trt_mod = torchtrt.compile(model, **compile_spec) @@ -92,7 +90,6 @@ def test_efficientnet_b0(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, - "ir": "dynamo", } trt_mod = torchtrt.compile(model, **compile_spec) @@ -170,7 +167,6 @@ def test_resnet18_half(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, - "ir": "dynamo", } trt_mod = torchtrt.compile(model, **compile_spec) diff --git a/tests/py/dynamo/runtime/conftest.py b/tests/py/dynamo/runtime/conftest.py new file mode 100644 index 0000000000..0dedfa3d2f --- /dev/null +++ b/tests/py/dynamo/runtime/conftest.py @@ -0,0 +1,21 @@ +# type: ignore + +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--ir", + metavar="Internal Representation", + nargs=1, + type=str, + required=False, + help="IR to compile with", + choices=["dynamo", "torch_compile"], + ) + + +@pytest.fixture +def ir(request): + ir_opt = request.config.getoption("--ir") + return ir_opt[0] if ir_opt else "dynamo" diff --git a/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py b/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py index c23684646a..b513ff46c8 100644 --- a/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py +++ b/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py @@ -1,11 +1,12 @@ import unittest -import tensorrt as trt import torch import torch_tensorrt from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +import tensorrt as trt + class TestConvertModuleToTrtEngine(unittest.TestCase): def test_convert_module(self): @@ -21,8 +22,10 @@ def forward(self, a, b): exp_program = torch.export.export(model, (input_data_0, input_data_1)) # Convert to TensorRT engine - trt_engine_str = torch_tensorrt.dynamo.convert_module_to_trt_engine( - exp_program, inputs=(input_data_0, input_data_1) + trt_engine_str = ( + torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine( + exp_program, inputs=(input_data_0, input_data_1) + ) ) # Inference on TRT Engine diff --git a/tests/py/dynamo/runtime/test_lazy_engine_init.py b/tests/py/dynamo/runtime/test_lazy_engine_init.py new file mode 100644 index 0000000000..1f3de69eb3 --- /dev/null +++ b/tests/py/dynamo/runtime/test_lazy_engine_init.py @@ -0,0 +1,329 @@ +# type: ignore +import os +import tempfile +import unittest + +import torch +import torch_tensorrt +import torch_tensorrt as torchtrt +import torchvision.models as models +from torch.testing._internal.common_utils import TestCase +from torch_tensorrt.dynamo import CompilationSettings +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +from torch_tensorrt.runtime import PythonTorchTensorRTModule, TorchTensorRTModule + +assertions = unittest.TestCase() + + +def assert_close(outputs, ref_outputs): + if type(outputs) not in (list, tuple): + outputs = [outputs] + + if type(ref_outputs) not in ( + list, + tuple, + torch.return_types.max, + torch.return_types.min, + ): + ref_outputs = [ref_outputs] + + for out, ref in zip(outputs, ref_outputs): + if not isinstance(ref, torch.Tensor): + if len(out.shape) == 0: + ref = torch.tensor(ref) + else: + ref = torch.tensor([ref]) + ref = ref.cpu() # to_dtype test has cases with gpu output + torch.testing.assert_close( + out.cpu(), + ref.cpu(), + rtol=1e-03, + atol=1e-03, + equal_nan=True, + check_dtype=True, + ) + + +class TestLazyEngineInit(TestCase): + + def test_lazy_engine_init_py(self): + class Test(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + # Prepare the input data + input_data_0, input_data_1 = torch.randn((2, 4)), torch.randn((2, 4)) + + # Create a model + model = Test() + exp_program = torch.export.export(model, (input_data_0, input_data_1)) + + # Convert to TensorRT engine + trt_engine_str = ( + torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine( + exp_program, inputs=(input_data_0, input_data_1) + ) + ) + + # Inference on TRT Engine + trt_module = PythonTorchTensorRTModule( + trt_engine_str, + ["a", "b"], + ["output0"], + settings=CompilationSettings(lazy_engine_init=True), + ) + + assertions.assertTrue( + trt_module.engine is None, + msg="Engine was proactively instantiated even though lazy engine loading was enabled", + ) + + with assertions.assertRaises(Exception): + trt_output = trt_module(input_data_0, input_data_1).cpu() + + trt_module.setup_engine() + assertions.assertTrue(trt_module.engine, msg="Engine was not setup") + + trt_output = trt_module(input_data_0, input_data_1).cpu() + + # Inference on PyTorch model + model_output = model(input_data_0, input_data_1) + + assert_close(trt_output, model_output) + + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT Runtime is not available", + ) + def test_lazy_engine_init_cpp(self): + class Test(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + # Prepare the input data + input_data_0, input_data_1 = torch.randn((2, 4)), torch.randn((2, 4)) + + # Create a model + model = Test() + exp_program = torch.export.export(model, (input_data_0, input_data_1)) + + # Convert to TensorRT engine + trt_engine_str = ( + torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine( + exp_program, inputs=(input_data_0, input_data_1) + ) + ) + + # Inference on TRT Engine + trt_module = TorchTensorRTModule( + trt_engine_str, + ["a", "b"], + ["output0"], + settings=CompilationSettings(lazy_engine_init=True), + ) + + assertions.assertTrue( + trt_module.engine is None, + msg="Engine was proactively instantiated even though lazy engine loading was enabled", + ) + + with assertions.assertRaises(Exception): + trt_output = trt_module( + input_data_0.to("cuda"), input_data_1.to("cuda") + ).cpu() + + trt_module.setup_engine() + assertions.assertTrue(trt_module.engine is not None, msg="Engine was not setup") + + trt_output = trt_module(input_data_0.to("cuda"), input_data_1.to("cuda")).cpu() + + # Inference on PyTorch model + model_output = model(input_data_0, input_data_1) + + assert_close(trt_output, model_output) + + def test_lazy_engine_init_py_e2e(self): + model = models.resnet18(pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 1, + "ir": "dynamo", + "lazy_engine_init": True, + "use_python_runtime": True, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT Runtime is not available", + ) + def test_lazy_engine_init_cpp_e2e(self): + model = models.resnet18(pretrained=False).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 1, + "ir": "dynamo", + "lazy_engine_init": True, + "use_python_runtime": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT Runtime is not available", + ) + def test_lazy_engine_init_cpp_serialization(self): + model = models.resnet18(pretrained=False).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 1, + "ir": "dynamo", + "lazy_engine_init": True, + "use_python_runtime": False, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + + with tempfile.TemporaryDirectory() as tmpdir: + torch_tensorrt.save( + trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep"), inputs=[input] + ) + new_trt_mod = torch.export.load(os.path.join(tmpdir, "tmp_trt_mod.ep")) + + loaded_trt_mod = new_trt_mod.module() + cos_sim = cosine_similarity(model(input), trt_mod(input)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + # Clean up model env + torch._dynamo.reset() + + def test_lazy_engine_init_py_hybrid_graph(self): + class Test(torch.nn.Module): + def forward(self, a, b): + w = torch.add(a, b) + x = 2 * b + y = torch.sub(w, a) + z = torch.add(y, x) + return w, x, y, z + + # Prepare the input data + input_data_0, input_data_1 = torch.randn((2, 4)).to("cuda"), torch.randn( + (2, 4) + ).to("cuda") + + # Create a model + model = Test() + exp_program = torch.export.export(model, (input_data_0, input_data_1)) + + compile_spec = { + "inputs": (input_data_0, input_data_1), + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 1, + "ir": "dynamo", + "lazy_engine_init": True, + "use_python_runtime": True, + "torch_executed_ops": [torch.ops.aten.sub.Tensor], + } + + trt_mod = torchtrt.dynamo.compile(exp_program, **compile_spec) + assert_close( + trt_mod(input_data_0, input_data_1), model(input_data_0, input_data_1) + ) + + # Clean up model env + torch._dynamo.reset() + + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT Runtime is not available", + ) + def test_lazy_engine_init_cpp_hybrid_graph(self): + class Test(torch.nn.Module): + def forward(self, a, b): + x = torch.add(a, b) + y = torch.sub(x, 2 * b) + z = torch.add(y, b) + return z + + # Prepare the input data + input_data_0, input_data_1 = torch.randn((2, 4)).to("cuda"), torch.randn( + (2, 4) + ).to("cuda") + + # Create a model + model = Test() + exp_program = torch.export.export(model, (input_data_0, input_data_1)) + + compile_spec = { + "inputs": (input_data_0, input_data_1), + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 1, + "ir": "dynamo", + "lazy_engine_init": True, + "use_python_runtime": False, + "torch_executed_ops": [torch.ops.aten.sub.Tensor], + } + + trt_mod = torchtrt.dynamo.compile(exp_program, **compile_spec) + assert_close( + trt_mod(input_data_0, input_data_1), model(input_data_0, input_data_1) + ) + + # Clean up model env + torch._dynamo.reset() diff --git a/tests/py/ts/api/test_collections.py b/tests/py/ts/api/test_collections.py index 7dc79b09b4..c7532064aa 100644 --- a/tests/py/ts/api/test_collections.py +++ b/tests/py/ts/api/test_collections.py @@ -13,7 +13,7 @@ def find_repo_root(max_depth=10): dir_path = os.path.dirname(os.path.realpath(__file__)) for i in range(max_depth): files = os.listdir(dir_path) - if "WORKSPACE" in files: + if "MODULE.bazel" in files: return dir_path else: dir_path = os.path.dirname(dir_path)