From 40349a8be3f1694586fe2ff0e5a9c98f4db80fb9 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 19 Sep 2024 08:38:28 -0700 Subject: [PATCH 01/52] support weight-stripped engine and REFIT_IDENTICAL flag --- py/torch_tensorrt/dynamo/_compiler.py | 6 + py/torch_tensorrt/dynamo/_defaults.py | 2 + py/torch_tensorrt/dynamo/_settings.py | 8 + .../dynamo/conversion/_TRTInterpreter.py | 60 ++++---- .../dynamo/conversion/_conversion.py | 1 + .../runtime/_PythonTorchTensorRTModule.py | 54 ++++++- .../dynamo/runtime/_TorchTensorRTModule.py | 4 + tests/py/dynamo/models/test_engine_cache.py | 11 +- .../models/test_weight_stripped_engine.py | 143 ++++++++++++++++++ 9 files changed, 249 insertions(+), 40 deletions(-) create mode 100644 tests/py/dynamo/models/test_weight_stripped_engine.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index b28dc53a14..4ef0fdb472 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -90,6 +90,8 @@ def compile( custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE, use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, use_fp32_acc: bool = _defaults.USE_FP32_ACC, + refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS, + strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -162,6 +164,8 @@ def compile( custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored. use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. + refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. + strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -301,6 +305,8 @@ def compile( "reuse_cached_engines": reuse_cached_engines, "use_explicit_typing": use_explicit_typing, "use_fp32_acc": use_fp32_acc, + "refit_identical_engine_weights": refit_identical_engine_weights, + "strip_engine_weights": strip_engine_weights, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index de99df71e0..000f6e80ec 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -42,6 +42,8 @@ CUSTOM_ENGINE_CACHE = None USE_EXPLICIT_TYPING = False USE_FP32_ACC = False +REFIT_IDENTICAL_ENGINE_WEIGHTS = False +STRIP_ENGINE_WEIGHTS = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 98865c683e..5ec22cbc8d 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -24,9 +24,11 @@ NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, + REFIT_IDENTICAL_ENGINE_WEIGHTS, REQUIRE_FULL_COMPILATION, REUSE_CACHED_ENGINES, SPARSE_WEIGHTS, + STRIP_ENGINE_WEIGHTS, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, USE_EXPLICIT_TYPING, @@ -82,6 +84,8 @@ class CompilationSettings: reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. + refit_identical_engine_weights (bool): Whether to refit the engine with identical weights + strip_engine_weights (bool): Whether to strip the engine weights """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -118,6 +122,8 @@ class CompilationSettings: reuse_cached_engines: bool = REUSE_CACHED_ENGINES use_explicit_typing: bool = USE_EXPLICIT_TYPING use_fp32_acc: bool = USE_FP32_ACC + refit_identical_engine_weights: bool = REFIT_IDENTICAL_ENGINE_WEIGHTS + strip_engine_weights: bool = STRIP_ENGINE_WEIGHTS _SETTINGS_TO_BE_ENGINE_INVARIANT = ( @@ -130,6 +136,8 @@ class CompilationSettings: "make_refittable", "engine_capability", "hardware_compatible", + "refit_identical_engine_weights", + "strip_engine_weights", ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 1358e034c7..9d47ca668f 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -288,7 +288,16 @@ def _populate_trt_builder_config( builder_config.clear_flag(trt.BuilderFlag.TF32) if self.compilation_settings.make_refittable: - builder_config.set_flag(trt.BuilderFlag.REFIT) + if version.parse(trt.__version__) >= version.parse("10.0"): + if self.compilation_settings.refit_identical_engine_weights: + builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) + else: + builder_config.set_flag(trt.BuilderFlag.REFIT) + else: + builder_config.set_flag(trt.BuilderFlag.REFIT) + + if self.compilation_settings.strip_engine_weights: + builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN) if strict_type_constraints: builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) @@ -553,7 +562,7 @@ def run( cached_data = self.engine_cache.check(hash_val) if cached_data is not None: # hit the cache ( - serialized_engine, + unrefitted_serialized_engine, self._input_names, self._output_names, cached_engine_input_specs, @@ -584,31 +593,12 @@ def run( "Found the cached engine that corresponds to this graph. It is directly loaded." ) - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) - - from torch_tensorrt.dynamo._refit import ( - _refit_single_trt_engine_with_gm, - ) - - # TODO: Fast refit is problematic for now. It will fail if the engine has batch_norm layers. - # We set weight_name_map=None to use slow refit anyway for now. Will fix it in the future. - _refit_single_trt_engine_with_gm( - new_gm=self.module, - old_engine=engine, - input_list=self.input_specs, - settings=self.compilation_settings, - weight_name_map=None, - ) - - serialized_engine = engine.serialize() - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() + engine_bytes.write(unrefitted_serialized_engine) + unrefitted_engine_str = engine_bytes.getvalue() return TRTInterpreterResult( - engine_str, + unrefitted_engine_str, self._input_names, self._output_names, self.weight_name_map, @@ -630,19 +620,24 @@ def run( builder_config, self.compilation_settings.timing_cache_path ) - serialized_engine = self.builder.build_serialized_network( + # if strip_engine_weights is true, the serialized engine need to be refitted before using + maybe_unrefitted_serialized_engine = self.builder.build_serialized_network( self.ctx.net, builder_config ) - assert serialized_engine + assert maybe_unrefitted_serialized_engine _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") + _LOGGER.info( + f"TRT Engine uses: {maybe_unrefitted_serialized_engine.nbytes} bytes of Memory" + ) self._save_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) + + # if strip_engine_weights is true, the weight-stripped engine will be saved in engine cache if ( self.engine_cache is not None and self.compilation_settings.cache_built_engines @@ -650,7 +645,7 @@ def run( self.engine_cache.insert( hash_val, ( - serialized_engine, + maybe_unrefitted_serialized_engine, self._input_names, self._output_names, self.input_specs, @@ -660,11 +655,14 @@ def run( ) with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() + engine_bytes.write(maybe_unrefitted_serialized_engine) + maybe_unrefitted_engine_str = engine_bytes.getvalue() return TRTInterpreterResult( - engine_str, self._input_names, self._output_names, self.weight_name_map + maybe_unrefitted_engine_str, + self._input_names, + self._output_names, + self.weight_name_map, ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 06fade9674..aa7ff05cc8 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -165,4 +165,5 @@ def convert_module( name=name, settings=settings, weight_name_map=interpreter_result.weight_name_map, + graph_module=module, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 1f84b7c400..7b18415167 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -12,7 +12,7 @@ from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform, dtype from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( _is_switch_required, @@ -38,7 +38,8 @@ def __init__( *, name: str = "", settings: CompilationSettings = CompilationSettings(), - weight_name_map: Any = None, + weight_name_map: Optional[dict[Any, Any]] = None, + graph_module: torch.fx.GraphModule = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine @@ -51,6 +52,8 @@ def __init__( Keyword Arguments: name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed + weight_name_map (dict): Mapping of engine weight name to state_dict weight name + graph_module (torch.fx.GraphModule): GraphModule used to refit the weights Example: @@ -105,6 +108,7 @@ def __init__( self.settings = settings self.engine = None self.weight_name_map = weight_name_map + self.graph_module = graph_module # may be used to refit the weights self.target_platform = Platform.current_platform() if self.serialized_engine is not None and not self.settings.lazy_engine_init: @@ -120,6 +124,52 @@ def setup_engine(self) -> None: self.engine = runtime.deserialize_cuda_engine(self.serialized_engine) self.context = self.engine.create_execution_context() + if self.settings.strip_engine_weights: + assert ( + self.settings.make_refittable + ), "weight-stripped engines must be refittable, please set make_refittable=True" + + # Refit the weights + refitter = trt.Refitter(self.engine, TRT_LOGGER) + refittable_weights = refitter.get_all_weights() + torch_device = get_model_device(self.graph_module) + + for layer_name in refittable_weights: + trt_wt_location = ( + trt.TensorLocation.DEVICE + if torch_device.type == "cuda" + else trt.TensorLocation.HOST + ) + from torch_tensorrt.dynamo._refit import ( + construct_refit_mapping_from_weight_name_map, + ) + + mapping = construct_refit_mapping_from_weight_name_map( + self.weight_name_map, self.graph_module.state_dict() + ) + + for layer_name in refittable_weights: + if layer_name not in mapping: + logger.warning(f"{layer_name} is not found in weight mapping.") + continue + # Use Numpy to create weights + weight, weight_dtype = mapping[layer_name] + trt_wt_tensor = trt.Weights( + weight_dtype, weight.data_ptr(), torch.numel(weight) + ) + refitter.set_named_weights( + layer_name, trt_wt_tensor, trt_wt_location + ) + assert ( + len(refitter.get_missing_weights()) == 0 + ), "Fast refitting failed due to incomplete mapping" + + # Refit the engine + if refitter.refit_cuda_engine(): + logger.info("Engine refitted successfully!") + else: + logger.info("Engine refit failed!") + assert self.engine.num_io_tensors == ( len(self.input_names) + len(self.output_names) ) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 7bf42da7f0..ccfbad352e 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -79,6 +79,7 @@ def __init__( name: str = "", settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed weight_name_map: Optional[dict[Any, Any]] = None, + graph_module: torch.fx.GraphModule = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines @@ -96,6 +97,8 @@ def __init__( Keyword Arguments: name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed + weight_name_map (dict): Mapping of engine weight name to state_dict weight name + graph_module (torch.fx.GraphModule): GraphModule used to refit the weights Example: @@ -129,6 +132,7 @@ def __init__( self.hardware_compatible = settings.hardware_compatible self.settings = copy.deepcopy(settings) self.weight_name_map = weight_name_map + self.graph_module = graph_module self.serialized_engine = serialized_engine self.engine = None diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index 367f68c1f6..cd720bc030 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -206,6 +206,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) for i in range(3): + # remove timing cache and reset dynamo for engine caching messurement remove_timing_cache() torch._dynamo.reset() if i == 0: @@ -220,7 +221,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): trt_gm = torch_trt.dynamo.compile( exp_program, tuple(inputs), - use_python_runtime=False, + use_python_runtime=True, enabled_precisions={torch.float}, debug=False, min_block_size=1, @@ -231,7 +232,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): ) end.record() torch.cuda.synchronize() - torch._dynamo.reset() times.append(start.elapsed_time(end)) results.append(trt_gm(*inputs)) @@ -285,7 +285,7 @@ def test_dynamo_compile_with_custom_engine_cache(self): trt_gm = torch_trt.dynamo.compile( exp_program, tuple(inputs), - use_python_runtime=False, + use_python_runtime=True, enabled_precisions={torch.float}, debug=False, min_block_size=1, @@ -332,7 +332,7 @@ def test_dynamo_compile_change_input_shape(self): trt_gm = torch_trt.dynamo.compile( torch.export.export(model, args=inputs), inputs=inputs, - use_python_runtime=False, + use_python_runtime=True, enabled_precisions={torch.float}, debug=False, min_block_size=1, @@ -402,7 +402,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): results.append(compiled_model(*inputs)) # trigger the compilation end.record() torch.cuda.synchronize() - torch._dynamo.reset() times.append(start.elapsed_time(end)) cos_sim = cosine_similarity(results[0], results[1]) @@ -441,7 +440,6 @@ def test_torch_compile_with_custom_engine_cache(self): start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) for i in range(3): - # remove timing cache and reset dynamo for engine caching messurement if i == 0: cache_built_engines = False reuse_cached_engines = False @@ -501,7 +499,6 @@ def test_torch_compile_change_input_shape(self): custom_engine_cache = MyEngineCache(engine_cache_dir) for i in range(3): - # remove timing cache and reset dynamo for engine caching messurement inputs = [torch.rand((4 * (i + 1), 3, 224, 224)).to("cuda")] compiled_model = torch.compile( model, diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py new file mode 100644 index 0000000000..4cef33e082 --- /dev/null +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -0,0 +1,143 @@ +import os +import shutil +import unittest + +import torch +import torch_tensorrt as torch_trt +import torchvision.models as models +from torch.testing._internal.common_utils import TestCase +from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = unittest.TestCase() + + +class TestEngineCache(TestCase): + def test_weight_stripped_engine(self): + model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + # Mark the dim0 of inputs as dynamic + batch = torch.export.Dim("batch", min=1, max=200) + exp_program = torch.export.export( + model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + ) + + engine_cache_dir = "/tmp/test_weight_stripped_engine" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + def remove_timing_cache(path=TIMING_CACHE_PATH): + if os.path.exists(path): + os.remove(path) + + inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + results = [] + + # run pytorch model + results.append(model(*inputs)) + + remove_timing_cache() + torch._dynamo.reset() + + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + make_refittable=True, + refit_identical_engine_weights=False, + cache_built_engines=False, + reuse_cached_engines=False, + engine_cache_dir=engine_cache_dir, + strip_engine_weights=True, + ) + results.append(trt_gm(*inputs)) + + cos_sim = cosine_similarity(results[0], results[1]) + + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + def test_dynamo_compile_with_refittable_weight_stripped_engine(self): + model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + # Mark the dim0 of inputs as dynamic + batch = torch.export.Dim("batch", min=1, max=200) + exp_program = torch.export.export( + model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + ) + + engine_cache_dir = ( + "/tmp/test_dynamo_compile_with_refittable_weight_stripped_engine" + ) + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + def remove_timing_cache(path=TIMING_CACHE_PATH): + if os.path.exists(path): + os.remove(path) + + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. + inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + results = [] + times = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + for i in range(3): + remove_timing_cache() + torch._dynamo.reset() + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + torch.cuda.synchronize() + start.record() + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + make_refittable=True, + refit_identical_engine_weights=True, + strip_engine_weights=True, + cache_built_engines=cache_built_engines, + reuse_cached_engines=reuse_cached_engines, + engine_cache_dir=engine_cache_dir, + ) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + results.append(trt_gm(*inputs)) + + assertions.assertNotEqual(results[0].sum(), 0, msg="results[0] are all zeros") + assertions.assertNotEqual(results[1].sum(), 0, msg="results[1] are all zeros") + assertions.assertNotEqual(results[2].sum(), 0, msg="results[2] are all zeros") + + cos_sim = cosine_similarity(results[0], results[1]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(results[1], results[2]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + times[0] > times[2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", + ) From 5d7c677b97df8a5290984521b658b4e7d2ec87b0 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 20 Sep 2024 10:33:22 -0700 Subject: [PATCH 02/52] refactor with new design --- py/torch_tensorrt/dynamo/_compiler.py | 6 + .../dynamo/conversion/_TRTInterpreter.py | 68 +++-- .../dynamo/conversion/_conversion.py | 1 - .../runtime/_PythonTorchTensorRTModule.py | 51 +--- .../dynamo/runtime/_TorchTensorRTModule.py | 3 - tests/py/dynamo/models/test_engine_cache.py | 6 +- .../models/test_weight_stripped_engine.py | 235 +++++++++++++++--- 7 files changed, 266 insertions(+), 104 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 4ef0fdb472..a3ec343a6b 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -550,6 +550,8 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, use_fp32_acc: bool = _defaults.USE_FP32_ACC, + refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS, + strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -610,6 +612,8 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. + refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. + strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required. Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ @@ -685,6 +689,8 @@ def convert_exported_program_to_serialized_trt_engine( "timing_cache_path": timing_cache_path, "use_explicit_typing": use_explicit_typing, "use_fp32_acc": use_fp32_acc, + "refit_identical_engine_weights": refit_identical_engine_weights, + "strip_engine_weights": strip_engine_weights, } exported_program = pre_export_lowering(exported_program) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 9d47ca668f..cb49ca7f5f 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -562,7 +562,7 @@ def run( cached_data = self.engine_cache.check(hash_val) if cached_data is not None: # hit the cache ( - unrefitted_serialized_engine, + serialized_engine, self._input_names, self._output_names, cached_engine_input_specs, @@ -593,12 +593,38 @@ def run( "Found the cached engine that corresponds to this graph. It is directly loaded." ) + # refit the cached engine with the new graph module + if not self.compilation_settings.strip_engine_weights: + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + from torch_tensorrt.dynamo._refit import ( + _refit_single_trt_engine_with_gm, + ) + + _refit_single_trt_engine_with_gm( + new_gm=self.module, + old_engine=engine, + input_list=self.input_specs, + settings=self.compilation_settings, + weight_name_map=self.weight_name_map, + ) + + # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared + serialization_config = engine.create_serialization_config() + serialization_config.clear_flag( + trt.SerializationFlag.EXCLUDE_WEIGHTS + ) + serialized_engine = engine.serialize_with_config( + serialization_config + ) + with io.BytesIO() as engine_bytes: - engine_bytes.write(unrefitted_serialized_engine) - unrefitted_engine_str = engine_bytes.getvalue() + engine_bytes.write(serialized_engine) + engine_str = engine_bytes.getvalue() return TRTInterpreterResult( - unrefitted_engine_str, + engine_str, self._input_names, self._output_names, self.weight_name_map, @@ -620,32 +646,44 @@ def run( builder_config, self.compilation_settings.timing_cache_path ) - # if strip_engine_weights is true, the serialized engine need to be refitted before using - maybe_unrefitted_serialized_engine = self.builder.build_serialized_network( + serialized_engine = self.builder.build_serialized_network( self.ctx.net, builder_config ) - assert maybe_unrefitted_serialized_engine + assert serialized_engine _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info( - f"TRT Engine uses: {maybe_unrefitted_serialized_engine.nbytes} bytes of Memory" - ) + _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") self._save_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - # if strip_engine_weights is true, the weight-stripped engine will be saved in engine cache if ( self.engine_cache is not None and self.compilation_settings.cache_built_engines ): + assert ( + self.compilation_settings.make_refittable + ), "weight-stripped engines must be refittable, please set make_refittable=True" + + # no matter what compilation_settings is, we cache the weight-stripped engine + if self.compilation_settings.strip_engine_weights: + weight_stripped_serialized_engine = serialized_engine + else: + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + serialization_config = engine.create_serialization_config() + serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + weight_stripped_serialized_engine = engine.serialize_with_config( + serialization_config + ) + self.engine_cache.insert( hash_val, ( - maybe_unrefitted_serialized_engine, + weight_stripped_serialized_engine, self._input_names, self._output_names, self.input_specs, @@ -655,11 +693,11 @@ def run( ) with io.BytesIO() as engine_bytes: - engine_bytes.write(maybe_unrefitted_serialized_engine) - maybe_unrefitted_engine_str = engine_bytes.getvalue() + engine_bytes.write(serialized_engine) + engine_str = engine_bytes.getvalue() return TRTInterpreterResult( - maybe_unrefitted_engine_str, + engine_str, self._input_names, self._output_names, self.weight_name_map, diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index aa7ff05cc8..06fade9674 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -165,5 +165,4 @@ def convert_module( name=name, settings=settings, weight_name_map=interpreter_result.weight_name_map, - graph_module=module, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 7b18415167..91cd909423 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -12,7 +12,7 @@ from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform, dtype from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( _is_switch_required, @@ -39,7 +39,6 @@ def __init__( name: str = "", settings: CompilationSettings = CompilationSettings(), weight_name_map: Optional[dict[Any, Any]] = None, - graph_module: torch.fx.GraphModule = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine @@ -53,7 +52,6 @@ def __init__( name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed weight_name_map (dict): Mapping of engine weight name to state_dict weight name - graph_module (torch.fx.GraphModule): GraphModule used to refit the weights Example: @@ -108,7 +106,6 @@ def __init__( self.settings = settings self.engine = None self.weight_name_map = weight_name_map - self.graph_module = graph_module # may be used to refit the weights self.target_platform = Platform.current_platform() if self.serialized_engine is not None and not self.settings.lazy_engine_init: @@ -124,52 +121,6 @@ def setup_engine(self) -> None: self.engine = runtime.deserialize_cuda_engine(self.serialized_engine) self.context = self.engine.create_execution_context() - if self.settings.strip_engine_weights: - assert ( - self.settings.make_refittable - ), "weight-stripped engines must be refittable, please set make_refittable=True" - - # Refit the weights - refitter = trt.Refitter(self.engine, TRT_LOGGER) - refittable_weights = refitter.get_all_weights() - torch_device = get_model_device(self.graph_module) - - for layer_name in refittable_weights: - trt_wt_location = ( - trt.TensorLocation.DEVICE - if torch_device.type == "cuda" - else trt.TensorLocation.HOST - ) - from torch_tensorrt.dynamo._refit import ( - construct_refit_mapping_from_weight_name_map, - ) - - mapping = construct_refit_mapping_from_weight_name_map( - self.weight_name_map, self.graph_module.state_dict() - ) - - for layer_name in refittable_weights: - if layer_name not in mapping: - logger.warning(f"{layer_name} is not found in weight mapping.") - continue - # Use Numpy to create weights - weight, weight_dtype = mapping[layer_name] - trt_wt_tensor = trt.Weights( - weight_dtype, weight.data_ptr(), torch.numel(weight) - ) - refitter.set_named_weights( - layer_name, trt_wt_tensor, trt_wt_location - ) - assert ( - len(refitter.get_missing_weights()) == 0 - ), "Fast refitting failed due to incomplete mapping" - - # Refit the engine - if refitter.refit_cuda_engine(): - logger.info("Engine refitted successfully!") - else: - logger.info("Engine refit failed!") - assert self.engine.num_io_tensors == ( len(self.input_names) + len(self.output_names) ) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index ccfbad352e..03daeada5f 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -79,7 +79,6 @@ def __init__( name: str = "", settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed weight_name_map: Optional[dict[Any, Any]] = None, - graph_module: torch.fx.GraphModule = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines @@ -98,7 +97,6 @@ def __init__( name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed weight_name_map (dict): Mapping of engine weight name to state_dict weight name - graph_module (torch.fx.GraphModule): GraphModule used to refit the weights Example: @@ -132,7 +130,6 @@ def __init__( self.hardware_compatible = settings.hardware_compatible self.settings = copy.deepcopy(settings) self.weight_name_map = weight_name_map - self.graph_module = graph_module self.serialized_engine = serialized_engine self.engine = None diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index cd720bc030..5dcdfe4ae9 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -332,7 +332,7 @@ def test_dynamo_compile_change_input_shape(self): trt_gm = torch_trt.dynamo.compile( torch.export.export(model, args=inputs), inputs=inputs, - use_python_runtime=True, + use_python_runtime=False, enabled_precisions={torch.float}, debug=False, min_block_size=1, @@ -387,7 +387,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): model, backend="tensorrt", options={ - "use_python_runtime": True, + "use_python_runtime": False, "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, @@ -452,7 +452,7 @@ def test_torch_compile_with_custom_engine_cache(self): model, backend="tensorrt", options={ - "use_python_runtime": True, + "use_python_runtime": False, "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 4cef33e082..196800e758 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -1,4 +1,5 @@ import os +import pickle import shutil import unittest @@ -6,38 +7,61 @@ import torch_tensorrt as torch_trt import torchvision.models as models from torch.testing._internal.common_utils import TestCase +from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity assertions = unittest.TestCase() -class TestEngineCache(TestCase): - def test_weight_stripped_engine(self): - model = models.resnet18(pretrained=True).eval().to("cuda") +class TestWeightStrippedEngine(TestCase): + def test_weight_stripped_engine_sizes(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(pyt_model, example_inputs) + weight_included_engine = convert_exported_program_to_serialized_trt_engine( + exp_program, + example_inputs, + make_refittable=False, + strip_engine_weights=False, + refit_identical_engine_weights=False, + ) + weight_stripped_engine = convert_exported_program_to_serialized_trt_engine( + exp_program, + example_inputs, + make_refittable=True, + strip_engine_weights=True, + refit_identical_engine_weights=False, + ) + weight_stripped_refit_identical_engine = ( + convert_exported_program_to_serialized_trt_engine( + exp_program, + example_inputs, + make_refittable=True, + strip_engine_weights=True, + refit_identical_engine_weights=True, + ) + ) + assertions.assertTrue( + len(bytes(weight_included_engine)) > len(bytes(weight_stripped_engine)), + msg=f"Weight-stripped engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, weight stripped engine size: {len(bytes(weight_stripped_engine))}", + ) + assertions.assertTrue( + len(bytes(weight_stripped_engine)) + > len(bytes(weight_stripped_refit_identical_engine)), + msg=f"Weight-stripped refit-identical engine size is not smaller than the weight-stripped engine size. Weight-stripped engine size: {len(bytes(weight_stripped_engine))}, weight-stripped refit-identical engine size: {len(bytes(weight_stripped_refit_identical_engine))}", + ) + + def test_weight_stripped_engine_results(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) # Mark the dim0 of inputs as dynamic batch = torch.export.Dim("batch", min=1, max=200) exp_program = torch.export.export( - model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}} ) - engine_cache_dir = "/tmp/test_weight_stripped_engine" - if os.path.exists(engine_cache_dir): - shutil.rmtree(engine_cache_dir) - - def remove_timing_cache(path=TIMING_CACHE_PATH): - if os.path.exists(path): - os.remove(path) - inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] - results = [] - - # run pytorch model - results.append(model(*inputs)) - - remove_timing_cache() - torch._dynamo.reset() trt_gm = torch_trt.dynamo.compile( exp_program, @@ -47,29 +71,99 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): debug=False, min_block_size=1, make_refittable=True, - refit_identical_engine_weights=False, cache_built_engines=False, reuse_cached_engines=False, - engine_cache_dir=engine_cache_dir, + refit_identical_engine_weights=False, strip_engine_weights=True, ) - results.append(trt_gm(*inputs)) + output = trt_gm(*inputs) + assertions.assertEqual( + output.sum(), 0, msg="weight-stripped engine results should be all zeros" + ) - cos_sim = cosine_similarity(results[0], results[1]) + from torch_tensorrt.dynamo._refit import refit_module_weights + + refitted_trt_gm = refit_module_weights(trt_gm, exp_program) + refitted_output = refitted_trt_gm(*inputs) + assertions.assertNotEqual( + refitted_output.sum(), + 0, + msg="refitted engine results should not be all zeros", + ) + compiled_model = torch.compile( + pyt_model, + backend="tensorrt", + options={ + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "make_refittable": True, + "cache_built_engines": False, + "reuse_cached_engines": False, + "refit_identical_engine_weights": False, + "strip_engine_weights": False, + }, + ) + compiled_model_output = compiled_model(*inputs) + cos_sim = cosine_similarity(refitted_output, compiled_model_output) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"refitted_output doesn't match with compiled_model_output. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - def test_dynamo_compile_with_refittable_weight_stripped_engine(self): - model = models.resnet18(pretrained=True).eval().to("cuda") + def test_weight_stripped_engine_with_engine_cache(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) - # Mark the dim0 of inputs as dynamic - batch = torch.export.Dim("batch", min=1, max=200) - exp_program = torch.export.export( - model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + exp_program = torch.export.export(pyt_model, example_inputs) + + engine_cache_dir = "/tmp/test_weight_stripped_engine_with_engine_cache" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + weight_included_engine = convert_exported_program_to_serialized_trt_engine( + exp_program, + example_inputs, + make_refittable=False, + strip_engine_weights=False, + refit_identical_engine_weights=False, + ) + + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(example_inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + make_refittable=True, + refit_identical_engine_weights=True, + strip_engine_weights=False, # engine cache will save the stripped engine even if this is False + cache_built_engines=True, + reuse_cached_engines=True, + engine_cache_dir=engine_cache_dir, + ) + output = trt_gm(*example_inputs) + + blob_path = os.path.join( + engine_cache_dir, os.listdir(engine_cache_dir)[0], "blob.bin" + ) + with open(blob_path, "rb") as f: + blob = f.read() + unpacked = pickle.loads(blob) + cached_stripped_engine = unpacked["serialized_engine"] + + assertions.assertTrue( + len(bytes(weight_included_engine)) > len(bytes(cached_stripped_engine)), + msg=f"cached engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, cached stripped engine size: {len(bytes(cached_stripped_engine))}", ) + assertions.assertNotEqual(output.sum(), 0, msg="results are all zeros") + + def test_dynamo_compile_with_refittable_weight_stripped_engine(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(pyt_model, args=example_inputs) engine_cache_dir = ( "/tmp/test_dynamo_compile_with_refittable_weight_stripped_engine" @@ -110,8 +204,8 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): debug=False, min_block_size=1, make_refittable=True, - refit_identical_engine_weights=True, - strip_engine_weights=True, + refit_identical_engine_weights=False, + strip_engine_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_dir=engine_cache_dir, @@ -141,3 +235,80 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): times[0] > times[2], msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", ) + + def test_torch_compile_with_refittable_weight_stripped_engine(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + + engine_cache_dir = ( + "/tmp/test_torch_compile_with_refittable_weight_stripped_engine" + ) + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + def remove_timing_cache(path=TIMING_CACHE_PATH): + if os.path.exists(path): + os.remove(path) + + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. + inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + results = [] + times = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + for i in range(3): + remove_timing_cache() + torch._dynamo.reset() + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + torch.cuda.synchronize() + start.record() + compiled_model = torch.compile( + pyt_model, + backend="tensorrt", + options={ + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "make_refittable": True, + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, + "engine_cache_dir": engine_cache_dir, + "torch_executed_ops": {"torch.ops.aten.relu.default"}, + "refit_identical_engine_weights": True, + "strip_engine_weights": False, + }, + ) + results.append(compiled_model(*inputs)) # trigger the compilation + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + assertions.assertNotEqual(results[0].sum(), 0, msg="results[0] are all zeros") + assertions.assertNotEqual(results[1].sum(), 0, msg="results[1] are all zeros") + assertions.assertNotEqual(results[2].sum(), 0, msg="results[2] are all zeros") + + cos_sim = cosine_similarity(results[0], results[1]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(results[1], results[2]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + times[0] > times[2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", + ) From 82b7ddc8c92f705b065bb3379e4f662ebabcd93b Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 30 Sep 2024 19:21:05 -0700 Subject: [PATCH 03/52] lint --- .../dynamo/engine_caching_bert_example.py | 1 - examples/dynamo/engine_caching_example.py | 5 -- .../dynamo/mutable_torchtrt_module_example.py | 2 - examples/dynamo/refit_engine_example.py | 7 +-- py/torch_tensorrt/dynamo/_compiler.py | 31 +++++----- py/torch_tensorrt/dynamo/_defaults.py | 1 - py/torch_tensorrt/dynamo/_refit.py | 8 +-- py/torch_tensorrt/dynamo/_settings.py | 5 +- .../dynamo/conversion/_TRTInterpreter.py | 58 ++++++++++--------- .../runtime/_MutableTorchTensorRTModule.py | 5 -- py/torch_tensorrt/dynamo/utils.py | 4 -- tests/py/dynamo/models/test_engine_cache.py | 16 ++--- tests/py/dynamo/models/test_model_refit.py | 26 --------- .../models/test_weight_stripped_engine.py | 44 +++++++------- .../runtime/test_mutable_torchtrt_module.py | 8 --- 15 files changed, 78 insertions(+), 143 deletions(-) diff --git a/examples/dynamo/engine_caching_bert_example.py b/examples/dynamo/engine_caching_bert_example.py index 989913bd31..9cddefd509 100644 --- a/examples/dynamo/engine_caching_bert_example.py +++ b/examples/dynamo/engine_caching_bert_example.py @@ -52,7 +52,6 @@ def compile_bert(iterations=3): "truncate_double": True, "debug": False, "min_block_size": 1, - "make_refittable": True, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache", diff --git a/examples/dynamo/engine_caching_example.py b/examples/dynamo/engine_caching_example.py index 28ff73aa72..20388e9372 100644 --- a/examples/dynamo/engine_caching_example.py +++ b/examples/dynamo/engine_caching_example.py @@ -62,8 +62,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): # engines are saved to disk tied to a hash of their corresponding PyTorch subgraph. If # in a subsequent compilation, either as part of this session or a new session, the cache will # pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude. -# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``), -# the engine must be refittable (``make_refittable=True``). See :ref:`refit_engine_example` for more details. def torch_compile(iterations=3): @@ -97,7 +95,6 @@ def torch_compile(iterations=3): "enabled_precisions": enabled_precisions, "debug": debug, "min_block_size": min_block_size, - "make_refittable": True, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, }, @@ -157,7 +154,6 @@ def dynamo_compile(iterations=3): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_size=1 << 30, # 1GB @@ -268,7 +264,6 @@ def torch_compile_my_cache(iterations=3): "enabled_precisions": enabled_precisions, "debug": debug, "min_block_size": min_block_size, - "make_refittable": True, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "custom_engine_cache": engine_cache, diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py index b68c9a11ee..3ea9fab9a5 100644 --- a/examples/dynamo/mutable_torchtrt_module_example.py +++ b/examples/dynamo/mutable_torchtrt_module_example.py @@ -31,7 +31,6 @@ settings = { "use_python": False, "enabled_precisions": {torch.float32}, - "make_refittable": True, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -80,7 +79,6 @@ "use_python_runtime": True, "enabled_precisions": {torch.float16}, "debug": True, - "make_refittable": True, } model_id = "runwayml/stable-diffusion-v1-5" diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py index f93b097385..44f78abbc0 100644 --- a/examples/dynamo/refit_engine_example.py +++ b/examples/dynamo/refit_engine_example.py @@ -46,10 +46,7 @@ # Make a refittable Compilation Program # --------------------------------------- # -# The inital step is to compile a module and save it as with a normal. Note that there is an -# additional parameter `make_refittable` that is set to `True`. This parameter is used to -# indicate that the engine being built should support weight refitting later. Engines built without -# these setttings will not be able to be refit. +# The inital step is to compile a module and save it as with a normal. # # In this case we are going to compile a ResNet18 model with randomly initialized weights and save it. @@ -69,8 +66,6 @@ debug=debug, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, - make_refittable=True, - reuse_cached_engines=False, ) # Output is a torch.fx.GraphModule # Save the graph module as an exported program diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index a3ec343a6b..32253359cc 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -60,7 +60,6 @@ def compile( Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - make_refittable: bool = _defaults.MAKE_REFITTABLE, debug: bool = _defaults.DEBUG, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, @@ -132,7 +131,6 @@ def compile( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - refit (bool): Enable refitting debug (bool): Enable debuggable engine capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels @@ -188,14 +186,17 @@ def compile( if "refit" in kwargs.keys(): warnings.warn( - "Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.", + "`refit` is deprecated. All engines are refittable now. If you want to disable refitting, please open an issue on the Github repo with reasons.", + DeprecationWarning, + stacklevel=2, + ) + + if "make_refittable" in kwargs.keys(): + warnings.warn( + "`make_refittable` is deprecated. All engines are refittable now. If you want to disable refitting, please open an issue on the Github repo with reasons.", DeprecationWarning, stacklevel=2, ) - if make_refittable: - raise ValueError("Use flag make_refittable only. Flag refit is deprecated.") - else: - make_refittable = kwargs["refit"] engine_capability = EngineCapability._from(engine_capability) @@ -259,9 +260,6 @@ def compile( engine_cache = None if cache_built_engines or reuse_cached_engines: - assert ( - make_refittable - ), "Engine caching requires make_refittable to be set to True" engine_cache = ( custom_engine_cache if custom_engine_cache is not None @@ -292,7 +290,6 @@ def compile( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "make_refittable": make_refittable, "engine_capability": engine_capability, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, @@ -539,7 +536,6 @@ def convert_exported_program_to_serialized_trt_engine( require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, disable_tf32: bool = _defaults.DISABLE_TF32, sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - make_refittable: bool = _defaults.MAKE_REFITTABLE, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, dla_sram_size: int = _defaults.DLA_SRAM_SIZE, @@ -601,7 +597,6 @@ def convert_exported_program_to_serialized_trt_engine( Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path disable_tf32 (bool): Whether to disable TF32 computation for TRT layers sparse_weights (bool): Whether to allow the builder to use sparse weights - refit (bool): Whether to build a refittable engine engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. @@ -634,10 +629,17 @@ def convert_exported_program_to_serialized_trt_engine( ) if "refit" in kwargs.keys(): warnings.warn( - "Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.", + "`refit` is deprecated. All engines are refittable now. If you want to disable refitting, please open an issue on the Github repo with reasons.", DeprecationWarning, stacklevel=2, ) + if "make_refittable" in kwargs.keys(): + warnings.warn( + "`make_refittable` is deprecated. All engines are refittable now. If you want to disable refitting, please open an issue on the Github repo with reasons.", + DeprecationWarning, + stacklevel=2, + ) + if arg_inputs is None and inputs is None: raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") @@ -680,7 +682,6 @@ def convert_exported_program_to_serialized_trt_engine( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "make_refittable": make_refittable, "engine_capability": engine_capability, "num_avg_timing_iters": num_avg_timing_iters, "dla_sram_size": dla_sram_size, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 000f6e80ec..afa0a53f81 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -26,7 +26,6 @@ USE_PYTHON_RUNTIME = False USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False -MAKE_REFITTABLE = False REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 359dc0b3ff..20dfb982cb 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -253,7 +253,7 @@ def refit_module_weights( ] assert ( encoded_metadata != "" - ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refittable=True" + ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version" settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"] # Handle torch modules compiled_submodules_map = dict(compiled_submodules) @@ -270,10 +270,6 @@ def refit_module_weights( assert settings is not None - assert ( - settings.make_refittable - ), "Refitting is not enabled. Please recompile the engine with refit=True." - if settings.debug: set_log_level(logger.parent, logging.DEBUG) @@ -397,7 +393,7 @@ def refit_module_weights( if isinstance(compiled_submodule, PythonTorchTensorRTModule): engine = compiled_submodule.engine elif isinstance(compiled_submodule, TorchTensorRTModule): - engine_info = compiled_submodule.engine.__getstate__()[0] # type: ignore[index] + engine_info = compiled_submodule.engine.__getstate__()[0] engine = get_engine_from_encoded_engine( engine_info[ENGINE_IDX], runtime ) diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 5ec22cbc8d..1a6d4ade5b 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -18,7 +18,6 @@ ENGINE_CAPABILITY, HARDWARE_COMPATIBLE, LAZY_ENGINE_INIT, - MAKE_REFITTABLE, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, NUM_AVG_TIMING_ITERS, @@ -106,7 +105,6 @@ class CompilationSettings: disable_tf32: bool = DISABLE_TF32 assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT sparse_weights: bool = SPARSE_WEIGHTS - make_refittable: bool = MAKE_REFITTABLE engine_capability: EngineCapability = field( default_factory=lambda: ENGINE_CAPABILITY ) @@ -133,11 +131,10 @@ class CompilationSettings: "optimization_level", "disable_tf32", "sparse_weights", - "make_refittable", "engine_capability", "hardware_compatible", - "refit_identical_engine_weights", "strip_engine_weights", + "refit_identical_engine_weights", ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index cb49ca7f5f..d0b3455ccc 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -287,17 +287,15 @@ def _populate_trt_builder_config( if self.compilation_settings.disable_tf32: builder_config.clear_flag(trt.BuilderFlag.TF32) - if self.compilation_settings.make_refittable: - if version.parse(trt.__version__) >= version.parse("10.0"): - if self.compilation_settings.refit_identical_engine_weights: - builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) - else: - builder_config.set_flag(trt.BuilderFlag.REFIT) + if version.parse(trt.__version__) >= version.parse("10.0"): + if self.compilation_settings.refit_identical_engine_weights: + builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) else: builder_config.set_flag(trt.BuilderFlag.REFIT) + else: + builder_config.set_flag(trt.BuilderFlag.REFIT) - if self.compilation_settings.strip_engine_weights: - builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN) + builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN) if strict_type_constraints: builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) @@ -632,8 +630,7 @@ def run( self._construct_trt_network_def() - if self.compilation_settings.make_refittable: - self._save_weight_mapping() + self._save_weight_mapping() build_engine_start_time = datetime.now() _LOGGER.info("Not found cached TRT engines. Start building engine.") @@ -652,7 +649,7 @@ def run( assert serialized_engine _LOGGER.info( - f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" + f"Build weight-stripped TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") @@ -664,26 +661,11 @@ def run( self.engine_cache is not None and self.compilation_settings.cache_built_engines ): - assert ( - self.compilation_settings.make_refittable - ), "weight-stripped engines must be refittable, please set make_refittable=True" - - # no matter what compilation_settings is, we cache the weight-stripped engine - if self.compilation_settings.strip_engine_weights: - weight_stripped_serialized_engine = serialized_engine - else: - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) - serialization_config = engine.create_serialization_config() - serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) - weight_stripped_serialized_engine = engine.serialize_with_config( - serialization_config - ) - + # Cache the weight-stripped engine self.engine_cache.insert( hash_val, ( - weight_stripped_serialized_engine, + serialized_engine, self._input_names, self._output_names, self.input_specs, @@ -692,6 +674,26 @@ def run( ), ) + if not self.compilation_settings.strip_engine_weights: + # Refit the engine with the original weights + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm + + _refit_single_trt_engine_with_gm( + new_gm=self.module, + old_engine=engine, + input_list=self.input_specs, + settings=self.compilation_settings, + weight_name_map=self.weight_name_map, + ) + + # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared + serialization_config = engine.create_serialization_config() + serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + serialized_engine = engine.serialize_with_config(serialization_config) + with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) engine_str = engine_bytes.getvalue() diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 9abd896d50..ac2bf1512f 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -65,7 +65,6 @@ def __init__( Union[torch.dtype, dtype] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - make_refittable: bool = _defaults.MAKE_REFITTABLE, debug: bool = _defaults.DEBUG, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, @@ -151,9 +150,6 @@ def __init__( self.kwarg_inputs: dict[str, Any] = {} device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} - assert ( - make_refittable - ), "'make_refittable' has to be True for a MutableTorchTensorRTModule." compilation_options = { "enabled_precisions": ( enabled_precisions @@ -180,7 +176,6 @@ def __init__( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "make_refittable": make_refittable, "engine_capability": engine_capability, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index a85494239e..f40db6ab0e 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -522,10 +522,6 @@ def parse_dynamo_kwargs( engine_cache = None if kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"): - assert kwargs.get( - "make_refittable" - ), "Engine caching requires make_refittable to be set to True" - if kwargs.get("custom_engine_cache") is not None: engine_cache = kwargs.get("custom_engine_cache") else: diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index 5dcdfe4ae9..3502e430f8 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -74,7 +74,7 @@ def test_reexport_is_equal(self): ), ) settings1 = CompilationSettings( - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True + cache_built_engines=True, reuse_cached_engines=True ) hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1) @@ -89,7 +89,7 @@ def test_reexport_is_equal(self): ), ) settings2 = CompilationSettings( - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True + cache_built_engines=True, reuse_cached_engines=True ) hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2) @@ -111,7 +111,7 @@ def test_input_shape_change_is_not_equal(self): ), ) settings1 = CompilationSettings( - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True + cache_built_engines=True, reuse_cached_engines=True ) hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1) @@ -126,7 +126,7 @@ def test_input_shape_change_is_not_equal(self): ), ) settings2 = CompilationSettings( - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True + cache_built_engines=True, reuse_cached_engines=True ) hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2) @@ -148,7 +148,6 @@ def test_engine_settings_is_not_equal(self): ), ) settings1 = CompilationSettings( - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True, enabled_precisions={torch.float32}, @@ -166,7 +165,6 @@ def test_engine_settings_is_not_equal(self): ), ) settings2 = CompilationSettings( - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True, enabled_precisions={torch.float32, torch.float16}, @@ -225,7 +223,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refittable=True, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_dir=engine_cache_dir, @@ -289,7 +286,6 @@ def test_dynamo_compile_with_custom_engine_cache(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refittable=True, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, custom_engine_cache=custom_engine_cache, @@ -336,7 +332,6 @@ def test_dynamo_compile_change_input_shape(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True, ) @@ -391,7 +386,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, - "make_refittable": True, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": engine_cache_dir, @@ -456,7 +450,6 @@ def test_torch_compile_with_custom_engine_cache(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, - "make_refittable": True, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "custom_engine_cache": custom_engine_cache, @@ -508,7 +501,6 @@ def test_torch_compile_change_input_shape(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, - "make_refittable": True, "cache_built_engines": True, "reuse_cached_engines": True, "custom_engine_cache": custom_engine_cache, diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 46ffa7b6d8..ffc38fbacd 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -1,9 +1,7 @@ import os import tempfile -import time import unittest -import numpy as np import pytest import tensorrt as trt import torch @@ -57,8 +55,6 @@ def test_mapping(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) settings = trt_gm._run_on_acc_0.settings runtime = trt.Runtime(TRT_LOGGER) @@ -110,8 +106,6 @@ def test_refit_one_engine_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) new_trt_gm = refit_module_weights( @@ -161,8 +155,6 @@ def test_refit_one_engine_no_map_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) trt_gm._run_on_acc_0.weight_name_map = None @@ -213,8 +205,6 @@ def test_refit_one_engine_with_wrong_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) # Manually Deleted all batch norm layer. This suppose to fail the fast refit trt_gm._run_on_acc_0.weight_name_map = { @@ -271,8 +261,6 @@ def test_refit_one_engine_bert_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) new_trt_gm = refit_module_weights( @@ -325,8 +313,6 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) torchtrt.save(trt_gm, trt_ep_path, inputs=inputs) trt_gm = torch.export.load(trt_ep_path) @@ -372,8 +358,6 @@ def test_refit_one_engine_python_runtime_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) new_trt_gm = refit_module_weights( @@ -443,7 +427,6 @@ def forward(self, x): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, torch_executed_ops=torch_executed_ops, reuse_cached_engines=False, ) @@ -494,8 +477,6 @@ def test_refit_one_engine_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) new_trt_gm = refit_module_weights( @@ -546,8 +527,6 @@ def test_refit_one_engine_bert_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) new_trt_gm = refit_module_weights( @@ -600,8 +579,6 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) torchtrt.save(trt_gm, trt_ep_path, inputs=inputs) trt_gm = torch.export.load(trt_ep_path) @@ -647,8 +624,6 @@ def test_refit_one_engine_python_runtime_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, ) new_trt_gm = refit_module_weights( @@ -718,7 +693,6 @@ def forward(self, x): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, torch_executed_ops=torch_executed_ops, reuse_cached_engines=False, ) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 196800e758..1454eb4542 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -22,14 +22,12 @@ def test_weight_stripped_engine_sizes(self): weight_included_engine = convert_exported_program_to_serialized_trt_engine( exp_program, example_inputs, - make_refittable=False, strip_engine_weights=False, refit_identical_engine_weights=False, ) weight_stripped_engine = convert_exported_program_to_serialized_trt_engine( exp_program, example_inputs, - make_refittable=True, strip_engine_weights=True, refit_identical_engine_weights=False, ) @@ -37,7 +35,6 @@ def test_weight_stripped_engine_sizes(self): convert_exported_program_to_serialized_trt_engine( exp_program, example_inputs, - make_refittable=True, strip_engine_weights=True, refit_identical_engine_weights=True, ) @@ -70,11 +67,10 @@ def test_weight_stripped_engine_results(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refittable=True, cache_built_engines=False, reuse_cached_engines=False, - refit_identical_engine_weights=False, strip_engine_weights=True, + refit_identical_engine_weights=False, ) output = trt_gm(*inputs) assertions.assertEqual( @@ -83,6 +79,7 @@ def test_weight_stripped_engine_results(self): from torch_tensorrt.dynamo._refit import refit_module_weights + # Refit the weight-stripped engine with the same weights refitted_trt_gm = refit_module_weights(trt_gm, exp_program) refitted_output = refitted_trt_gm(*inputs) assertions.assertNotEqual( @@ -99,7 +96,6 @@ def test_weight_stripped_engine_results(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, - "make_refittable": True, "cache_built_engines": False, "reuse_cached_engines": False, "refit_identical_engine_weights": False, @@ -125,7 +121,6 @@ def test_weight_stripped_engine_with_engine_cache(self): weight_included_engine = convert_exported_program_to_serialized_trt_engine( exp_program, example_inputs, - make_refittable=False, strip_engine_weights=False, refit_identical_engine_weights=False, ) @@ -137,9 +132,8 @@ def test_weight_stripped_engine_with_engine_cache(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refittable=True, - refit_identical_engine_weights=True, strip_engine_weights=False, # engine cache will save the stripped engine even if this is False + refit_identical_engine_weights=True, cache_built_engines=True, reuse_cached_engines=True, engine_cache_dir=engine_cache_dir, @@ -203,21 +197,26 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refittable=True, - refit_identical_engine_weights=False, - strip_engine_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_dir=engine_cache_dir, + strip_engine_weights=False, + refit_identical_engine_weights=False, ) end.record() torch.cuda.synchronize() times.append(start.elapsed_time(end)) results.append(trt_gm(*inputs)) - assertions.assertNotEqual(results[0].sum(), 0, msg="results[0] are all zeros") - assertions.assertNotEqual(results[1].sum(), 0, msg="results[1] are all zeros") - assertions.assertNotEqual(results[2].sum(), 0, msg="results[2] are all zeros") + assertions.assertNotEqual( + results[0].sum(), 0, msg="results[0] shouldn't be all zeros" + ) + assertions.assertNotEqual( + results[1].sum(), 0, msg="results[1] shouldn't be all zeros" + ) + assertions.assertNotEqual( + results[2].sum(), 0, msg="results[2] shouldn't be all zeros" + ) cos_sim = cosine_similarity(results[0], results[1]) assertions.assertTrue( @@ -278,13 +277,12 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, - "make_refittable": True, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": engine_cache_dir, "torch_executed_ops": {"torch.ops.aten.relu.default"}, - "refit_identical_engine_weights": True, "strip_engine_weights": False, + "refit_identical_engine_weights": True, }, ) results.append(compiled_model(*inputs)) # trigger the compilation @@ -292,9 +290,15 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): torch.cuda.synchronize() times.append(start.elapsed_time(end)) - assertions.assertNotEqual(results[0].sum(), 0, msg="results[0] are all zeros") - assertions.assertNotEqual(results[1].sum(), 0, msg="results[1] are all zeros") - assertions.assertNotEqual(results[2].sum(), 0, msg="results[2] are all zeros") + assertions.assertNotEqual( + results[0].sum(), 0, msg="results[0] shouldn't be all zeros" + ) + assertions.assertNotEqual( + results[1].sum(), 0, msg="results[1] shouldn't be all zeros" + ) + assertions.assertNotEqual( + results[2].sum(), 0, msg="results[2] shouldn't be all zeros" + ) cos_sim = cosine_similarity(results[0], results[1]) assertions.assertTrue( diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index b52530efd1..fd9fa4e1e0 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -49,7 +49,6 @@ def test_resnet18(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, - "make_refittable": True, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -89,7 +88,6 @@ def test_save(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, - "make_refittable": True, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -123,7 +121,6 @@ def test_resnet18_modify_attribute(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, - "make_refittable": True, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -164,7 +161,6 @@ def test_resnet18_modify_attribute_no_refit(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, - "make_refittable": True, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -243,7 +239,6 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", - "make_refittable": True, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -304,7 +299,6 @@ def set_weights(self): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", - "make_refittable": True, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -367,7 +361,6 @@ def set_layer(self): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", - "make_refittable": True, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -436,7 +429,6 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", - "make_refittable": True, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) From 9f6a771ee219d18e465d4fec549c1963443a89bf Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 1 Oct 2024 00:03:37 -0700 Subject: [PATCH 04/52] samll fix --- py/torch_tensorrt/dynamo/_settings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 1a6d4ade5b..6a8d37cbfc 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -68,7 +68,6 @@ class CompilationSettings: assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False disable_tf32 (bool): Whether to disable TF32 computation for TRT layers sparse_weights (bool): Whether to allow the builder to use sparse weights - refit (bool): Whether to build a refittable engine engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. From 7ea3c0f6dbb664c59eda98af99b526d45460bdd8 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 1 Oct 2024 14:56:03 -0700 Subject: [PATCH 05/52] remove make_refittable --- .../dynamo/conversion/aten_ops_converters.py | 11 ----------- tests/py/dynamo/conversion/harness.py | 7 +------ tests/py/dynamo/conversion/test_cumsum_aten.py | 4 ---- tests/py/dynamo/conversion/test_embedding_bag_aten.py | 4 ---- tests/py/dynamo/models/test_model_refit.py | 1 - 5 files changed, 1 insertion(+), 26 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 60a48d98e3..0ae16731d2 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -273,9 +273,6 @@ def aten_ops_embedding( def embedding_bag_validator(node: Node, settings: CompilationSettings = None) -> bool: # Embedding bag op is not refitable - if settings.make_refittable: - return False - if not one_user_validator(node): return False meta = node.args[1].meta @@ -929,16 +926,8 @@ def aten_ops_slice( ) -def refit_validator(node: Node, settings: CompilationSettings = None) -> bool: - # cumsum op is not refitable - if settings and settings.make_refittable: - return False - return True - - @dynamo_tensorrt_converter( torch.ops.aten.cumsum.default, - capability_validator=refit_validator, supports_dynamic_shapes=True, ) @enforce_tensor_types( diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 632b73e2f3..d29e3182dc 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -263,7 +263,6 @@ def run_test( enable_passes=False, propagate_shapes=False, int32_reqd=False, - make_refittable=False, ): mod = self.generate_graph( mod, @@ -279,7 +278,6 @@ def run_test( enabled_precisions={dtype._from(precision)}, truncate_double=True, debug=True, - make_refittable=make_refittable, ) num_inputs = len(inputs) @@ -348,7 +346,6 @@ def run_test_compare_tensor_attributes_only( output_dtypes=None, use_dynamo_tracer=False, enable_passes=False, - make_refittable=False, ): mod = self.generate_graph( mod, @@ -362,7 +359,6 @@ def run_test_compare_tensor_attributes_only( enabled_precisions={dtype._from(precision)}, truncate_double=True, debug=True, - make_refittable=make_refittable, ) interp = TRTInterpreter( @@ -388,7 +384,6 @@ def run_test_with_dynamic_shape( pyt_inputs=None, propagate_shapes=False, check_dtype=True, - make_refittable=False, ): mod = self.generate_graph( mod, @@ -401,7 +396,7 @@ def run_test_with_dynamic_shape( # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here compilation_settings = CompilationSettings( - truncate_double=True, make_refittable=make_refittable + truncate_double=True, ) if check_dtype: diff --git a/tests/py/dynamo/conversion/test_cumsum_aten.py b/tests/py/dynamo/conversion/test_cumsum_aten.py index 1c32be6dd6..4143401bd4 100644 --- a/tests/py/dynamo/conversion/test_cumsum_aten.py +++ b/tests/py/dynamo/conversion/test_cumsum_aten.py @@ -24,7 +24,6 @@ def forward(self, x): self.run_test( Cumsum(), inputs, - make_refittable=False, ) @parameterized.expand( @@ -44,7 +43,6 @@ def forward(self, x): self.run_test( Cumsum(), inputs, - make_refittable=False, ) @parameterized.expand( @@ -65,7 +63,6 @@ def forward(self, x): self.run_test( Cumsum(), inputs, - make_refittable=False, ) @parameterized.expand( @@ -95,7 +92,6 @@ def forward(self, x): self.run_test_with_dynamic_shape( Cumsum(), inputs, - make_refittable=False, ) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 6543ac2306..3fef3d70cf 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -148,7 +148,6 @@ def forward(self, weight, indices): precision=weight.dtype, enable_passes=True, propagate_shapes=True, - make_refittable=False, ) @parameterized.expand( @@ -346,7 +345,6 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, - make_refittable=False, ) @parameterized.expand( @@ -411,7 +409,6 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, - make_refittable=False, ) @parameterized.expand( @@ -493,7 +490,6 @@ def forward(self, weights, indices, offsets, per_sample_weights=None): min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, - make_refittable=False, ) # use the inputs with different shape to inference: if per_sample_weights is None: diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index ffc38fbacd..b505282733 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -746,7 +746,6 @@ def forward(self, x): enabled_precisions={torch.float}, debug=True, min_block_size=1, - make_refittable=True, ) num_pyt_segments = len( From bf7553b0b51e98c29fa56736fa6c25c2f8dd5b39 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 1 Oct 2024 21:22:05 -0700 Subject: [PATCH 06/52] fast refit -> slow refit --- py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index d0b3455ccc..712eec6ef9 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -605,7 +605,7 @@ def run( old_engine=engine, input_list=self.input_specs, settings=self.compilation_settings, - weight_name_map=self.weight_name_map, + weight_name_map=None, ) # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared @@ -686,7 +686,7 @@ def run( old_engine=engine, input_list=self.input_specs, settings=self.compilation_settings, - weight_name_map=self.weight_name_map, + weight_name_map=None, ) # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared From 46e9bc875370735281ed1f042a9886bab8340489 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 2 Oct 2024 00:33:18 -0700 Subject: [PATCH 07/52] fix np.bool_, group_norm --- py/torch_tensorrt/_enums.py | 4 ++-- tests/py/dynamo/conversion/test_group_norm_aten.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index a580e6efbb..eaefb68ce5 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -220,7 +220,7 @@ def _from( return dtype.f32 elif t == np.float64: return dtype.f64 - elif t == np.bool: + elif t == np.bool_: return dtype.b # TODO: Consider using ml_dtypes when issues like this are resolved: # https://github.com/pytorch/pytorch/issues/109873 @@ -1384,7 +1384,7 @@ def current_platform(cls) -> Platform: def __str__(self) -> str: return str(self.name) - @needs_torch_tensorrt_runtime + @needs_torch_tensorrt_runtime # type: ignore def _to_serialized_rt_platform(self) -> str: val: str = torch.ops.tensorrt._platform_unknown() diff --git a/tests/py/dynamo/conversion/test_group_norm_aten.py b/tests/py/dynamo/conversion/test_group_norm_aten.py index 617166d0c4..b62be920f9 100644 --- a/tests/py/dynamo/conversion/test_group_norm_aten.py +++ b/tests/py/dynamo/conversion/test_group_norm_aten.py @@ -135,10 +135,10 @@ def forward(self, x): @parameterized.expand( [ - (5, 4, 4, 2, (2, 4, 2), (3, 4, 2), (5, 4, 4)), - (5, 4, 2 * 2, 2, (2, 4, 2, 2), (3, 4, 2, 2), (5, 4, 2, 2)), - (5, 9, 6 * 3, 3, (3, 9, 3, 3), (4, 9, 3, 3), (5, 9, 6, 3)), - (8, 9, 6 * 6, 3, (3, 9, 2, 3, 2), (5, 9, 3, 3, 2), (8, 9, 6, 3, 2)), + (5, 4, 4, 2, (2, 4, 2), (5, 4, 4), (5, 4, 4)), + (5, 4, 2 * 2, 2, (2, 4, 2, 2), (5, 4, 2, 2), (5, 4, 2, 2)), + (5, 9, 6 * 3, 3, (3, 9, 3, 3), (5, 9, 6, 3), (5, 9, 6, 3)), + (8, 9, 6 * 6, 3, (3, 9, 2, 3, 2), (8, 9, 6, 3, 2), (8, 9, 6, 3, 2)), ] ) def test_groupnorm_with_dynamic_shape( From d783fdd06e32699c673c87d2239d5c918eb0de32 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 2 Oct 2024 12:59:38 -0700 Subject: [PATCH 08/52] add immutable_weights --- py/torch_tensorrt/dynamo/_compiler.py | 14 ++- py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 4 + .../dynamo/conversion/_TRTInterpreter.py | 99 ++++++++++--------- .../dynamo/conversion/aten_ops_converters.py | 65 ++++++++---- tests/py/dynamo/conversion/harness.py | 6 ++ 6 files changed, 121 insertions(+), 68 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 32253359cc..352e81950e 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -91,6 +91,7 @@ def compile( use_fp32_acc: bool = _defaults.USE_FP32_ACC, refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS, strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS, + immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -164,6 +165,7 @@ def compile( use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required. + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -186,14 +188,14 @@ def compile( if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. All engines are refittable now. If you want to disable refitting, please open an issue on the Github repo with reasons.", + "`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. All engines are refittable now. If you want to disable refitting, please open an issue on the Github repo with reasons.", + "`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) @@ -304,6 +306,7 @@ def compile( "use_fp32_acc": use_fp32_acc, "refit_identical_engine_weights": refit_identical_engine_weights, "strip_engine_weights": strip_engine_weights, + "immutable_weights": immutable_weights, } settings = CompilationSettings(**compilation_options) @@ -548,6 +551,7 @@ def convert_exported_program_to_serialized_trt_engine( use_fp32_acc: bool = _defaults.USE_FP32_ACC, refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS, strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS, + immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -609,6 +613,7 @@ def convert_exported_program_to_serialized_trt_engine( use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required. + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ @@ -629,13 +634,13 @@ def convert_exported_program_to_serialized_trt_engine( ) if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. All engines are refittable now. If you want to disable refitting, please open an issue on the Github repo with reasons.", + "`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. All engines are refittable now. If you want to disable refitting, please open an issue on the Github repo with reasons.", + "`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) @@ -692,6 +697,7 @@ def convert_exported_program_to_serialized_trt_engine( "use_fp32_acc": use_fp32_acc, "refit_identical_engine_weights": refit_identical_engine_weights, "strip_engine_weights": strip_engine_weights, + "immutable_weights": immutable_weights, } exported_program = pre_export_lowering(exported_program) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index afa0a53f81..822194bb30 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -43,6 +43,7 @@ USE_FP32_ACC = False REFIT_IDENTICAL_ENGINE_WEIGHTS = False STRIP_ENGINE_WEIGHTS = False +IMMUTABLE_WEIGHTS = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 6a8d37cbfc..51bdc2c707 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,6 +17,7 @@ ENABLED_PRECISIONS, ENGINE_CAPABILITY, HARDWARE_COMPATIBLE, + IMMUTABLE_WEIGHTS, LAZY_ENGINE_INIT, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, @@ -84,6 +85,7 @@ class CompilationSettings: use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. refit_identical_engine_weights (bool): Whether to refit the engine with identical weights strip_engine_weights (bool): Whether to strip the engine weights + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -121,6 +123,7 @@ class CompilationSettings: use_fp32_acc: bool = USE_FP32_ACC refit_identical_engine_weights: bool = REFIT_IDENTICAL_ENGINE_WEIGHTS strip_engine_weights: bool = STRIP_ENGINE_WEIGHTS + immutable_weights: bool = IMMUTABLE_WEIGHTS _SETTINGS_TO_BE_ENGINE_INVARIANT = ( @@ -134,6 +137,7 @@ class CompilationSettings: "hardware_compatible", "strip_engine_weights", "refit_identical_engine_weights", + "immutable_weights", ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 712eec6ef9..6c876b6625 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -287,15 +287,16 @@ def _populate_trt_builder_config( if self.compilation_settings.disable_tf32: builder_config.clear_flag(trt.BuilderFlag.TF32) - if version.parse(trt.__version__) >= version.parse("10.0"): - if self.compilation_settings.refit_identical_engine_weights: - builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) + if not self.compilation_settings.immutable_weights: + if version.parse(trt.__version__) >= version.parse("10.0"): + if self.compilation_settings.refit_identical_engine_weights: + builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) + else: + builder_config.set_flag(trt.BuilderFlag.REFIT) else: builder_config.set_flag(trt.BuilderFlag.REFIT) - else: - builder_config.set_flag(trt.BuilderFlag.REFIT) - builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN) + builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN) if strict_type_constraints: builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) @@ -591,31 +592,32 @@ def run( "Found the cached engine that corresponds to this graph. It is directly loaded." ) - # refit the cached engine with the new graph module - if not self.compilation_settings.strip_engine_weights: - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) - - from torch_tensorrt.dynamo._refit import ( - _refit_single_trt_engine_with_gm, - ) - - _refit_single_trt_engine_with_gm( - new_gm=self.module, - old_engine=engine, - input_list=self.input_specs, - settings=self.compilation_settings, - weight_name_map=None, - ) - - # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared - serialization_config = engine.create_serialization_config() - serialization_config.clear_flag( - trt.SerializationFlag.EXCLUDE_WEIGHTS - ) - serialized_engine = engine.serialize_with_config( - serialization_config - ) + if not self.compilation_settings.immutable_weights: + # refit the cached engine with the new graph module + if not self.compilation_settings.strip_engine_weights: + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + from torch_tensorrt.dynamo._refit import ( + _refit_single_trt_engine_with_gm, + ) + + _refit_single_trt_engine_with_gm( + new_gm=self.module, + old_engine=engine, + input_list=self.input_specs, + settings=self.compilation_settings, + weight_name_map=None, + ) + + # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared + serialization_config = engine.create_serialization_config() + serialization_config.clear_flag( + trt.SerializationFlag.EXCLUDE_WEIGHTS + ) + serialized_engine = engine.serialize_with_config( + serialization_config + ) with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) @@ -674,25 +676,28 @@ def run( ), ) - if not self.compilation_settings.strip_engine_weights: - # Refit the engine with the original weights - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) + if not self.compilation_settings.immutable_weights: + if not self.compilation_settings.strip_engine_weights: + # Refit the engine with the original weights + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) - from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm + from torch_tensorrt.dynamo._refit import ( + _refit_single_trt_engine_with_gm, + ) - _refit_single_trt_engine_with_gm( - new_gm=self.module, - old_engine=engine, - input_list=self.input_specs, - settings=self.compilation_settings, - weight_name_map=None, - ) + _refit_single_trt_engine_with_gm( + new_gm=self.module, + old_engine=engine, + input_list=self.input_specs, + settings=self.compilation_settings, + weight_name_map=None, + ) - # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared - serialization_config = engine.create_serialization_config() - serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) - serialized_engine = engine.serialize_with_config(serialization_config) + # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared + serialization_config = engine.create_serialization_config() + serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + serialized_engine = engine.serialize_with_config(serialization_config) with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0ae16731d2..b103a14e00 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -49,7 +49,9 @@ def get_ir(target: Target) -> SourceIR: return SourceIR.UNKNOWN -def one_user_validator(node: Node, settings: CompilationSettings = None) -> bool: +def one_user_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: # Validate only one user, which is a getitem node that accesses the first element in the list return ( len(node.users) == 1 @@ -271,8 +273,13 @@ def aten_ops_embedding( ) -def embedding_bag_validator(node: Node, settings: CompilationSettings = None) -> bool: +def embedding_bag_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: # Embedding bag op is not refitable + if not settings or not settings.immutable_weights: + return False + if not one_user_validator(node): return False meta = node.args[1].meta @@ -418,7 +425,9 @@ def aten_ops_symsize_int( return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1]) -def index_dtype_validator(node: Node, settings: CompilationSettings = None) -> bool: +def index_dtype_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: index = node.args[1] for ind in index: if ind is not None: @@ -839,7 +848,9 @@ def aten_ops_select( ) -def index_put_validator(node: Node, settings: CompilationSettings = None) -> bool: +def index_put_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: if args_bounds_check(node.args, 3, False): # Check if accumulate is valid _LOGGER.debug("We do not support accumulate=True for aten.index_put operation") accumulate_valid = False @@ -926,8 +937,16 @@ def aten_ops_slice( ) +def refit_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: + # cumsum op is not refitable + if not settings or not settings.immutable_weights: + return False + return True + + @dynamo_tensorrt_converter( torch.ops.aten.cumsum.default, + capability_validator=refit_validator, supports_dynamic_shapes=True, ) @enforce_tensor_types( @@ -975,7 +994,9 @@ def aten_ops_tile( ) -def zero_output_validator(node: Node, settings: CompilationSettings = None) -> bool: +def zero_output_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: if 0 in node.args[1]: _LOGGER.debug( f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation." @@ -1033,7 +1054,7 @@ def aten_ops_permute( def to_copy_dtype_validator( - placeholder_only: bool, settings: CompilationSettings = None + placeholder_only: bool, settings: Optional[CompilationSettings] = None ) -> Callable[[Node, CompilationSettings], bool]: """Return validator for to_copy node with placeholder restrictions""" @@ -1066,7 +1087,9 @@ def validate_dtype(to_copy_node: Node) -> bool: ) return False - def validator(to_copy_node: Node, settings: CompilationSettings = None) -> bool: + def validator( + to_copy_node: Node, settings: Optional[CompilationSettings] = None + ) -> bool: """Returns true if the to_copy node can be converted to TRT and the placeholder restriction is satisfied """ @@ -2137,7 +2160,9 @@ def aten_ops_logical_xor( ) -def bitwise_type_validator(node: Node, settings: CompilationSettings = None) -> bool: +def bitwise_type_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: supported_type = [torch.bool, bool] tensor_targets = [ @@ -2281,7 +2306,7 @@ def aten_ops_bitwise_xor( def bitwise_not_type_validator( - node: Node, settings: CompilationSettings = None + node: Node, settings: Optional[CompilationSettings] = None ) -> bool: val = node.args[0] val_meta = val.meta.get("tensor_meta") @@ -2464,7 +2489,9 @@ def aten_ops_le( ) -def conv_param_validator(conv_node: Node, settings: CompilationSettings = None) -> bool: +def conv_param_validator( + conv_node: Node, settings: Optional[CompilationSettings] = None +) -> bool: return conv_node.args[7] in ([0], [0, 0], [0, 0, 0]) @@ -2561,7 +2588,7 @@ def aten_ops_cdist_forward( def avg_pool_param_validator( - pool_node: Node, settings: CompilationSettings = None + pool_node: Node, settings: Optional[CompilationSettings] = None ) -> bool: ceil_mode = args_bounds_check(pool_node.args, 4, False) divisor_override = args_bounds_check(pool_node.args, 6) @@ -2678,12 +2705,12 @@ def aten_ops_adaptive_avg_poolNd( ) -def topk_validator(node: Node, settings: CompilationSettings = None) -> bool: +def topk_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: k = node.args[1] return topk_sort_validator(k) -def sort_validator(node: Node, settings: CompilationSettings = None) -> bool: +def sort_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: meta_data = node.args[0].meta.get("tensor_meta") if meta_data is None: return False @@ -2706,7 +2733,7 @@ def topk_sort_validator(k: int) -> bool: def max_pool_param_validator( - pool_node: Node, settings: CompilationSettings = None + pool_node: Node, settings: Optional[CompilationSettings] = None ) -> bool: dilation = args_bounds_check(pool_node.args, 4, 1) ceil_mode = args_bounds_check(pool_node.args, 5, False) @@ -2761,7 +2788,9 @@ def aten_ops_max_pool( ) -def attention_validator(node: Node, settings: CompilationSettings = None) -> bool: +def attention_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: # Currently, `attn_mask` is not supported return args_bounds_check(node.args, 3) is None @@ -3652,7 +3681,9 @@ def aten_ops_flip( ) -def zero_diag_size_validator(node: Node, settings: CompilationSettings = None) -> bool: +def zero_diag_size_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: meta = node.args[0].meta.get("tensor_meta") if meta: input_shape = meta.shape @@ -3781,7 +3812,7 @@ def aten_ops_index_select( def dropout_inference_validator( - node: Node, settings: CompilationSettings = None + node: Node, settings: Optional[CompilationSettings] = None ) -> bool: train_mode = args_bounds_check(node.args, 2, None) if train_mode is False: diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index d29e3182dc..07e70c5b9b 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -263,6 +263,7 @@ def run_test( enable_passes=False, propagate_shapes=False, int32_reqd=False, + immutable_weights=True, ): mod = self.generate_graph( mod, @@ -278,6 +279,7 @@ def run_test( enabled_precisions={dtype._from(precision)}, truncate_double=True, debug=True, + immutable_weights=immutable_weights, ) num_inputs = len(inputs) @@ -346,6 +348,7 @@ def run_test_compare_tensor_attributes_only( output_dtypes=None, use_dynamo_tracer=False, enable_passes=False, + immutable_weights=True, ): mod = self.generate_graph( mod, @@ -359,6 +362,7 @@ def run_test_compare_tensor_attributes_only( enabled_precisions={dtype._from(precision)}, truncate_double=True, debug=True, + immutable_weights=immutable_weights, ) interp = TRTInterpreter( @@ -384,6 +388,7 @@ def run_test_with_dynamic_shape( pyt_inputs=None, propagate_shapes=False, check_dtype=True, + immutable_weights=True, ): mod = self.generate_graph( mod, @@ -397,6 +402,7 @@ def run_test_with_dynamic_shape( # We replicate this behavior here compilation_settings = CompilationSettings( truncate_double=True, + immutable_weights=immutable_weights, ) if check_dtype: From 160588e120c07452d7e778f52e735d7be69f2ddf Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 2 Oct 2024 13:28:24 -0700 Subject: [PATCH 09/52] skip engine caching for non-refittable engines, slow refit -> fast refit --- .../dynamo/conversion/_TRTInterpreter.py | 93 ++++++++++--------- 1 file changed, 48 insertions(+), 45 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 6c876b6625..c1734498c5 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -547,7 +547,10 @@ def run( # self.engine_cache could be None if: # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or # 2) both cache_built_engines and reuse_cached_engines are False - if self.engine_cache is not None: + if ( + self.engine_cache is not None + and not self.compilation_settings.immutable_weights + ): if ( self.compilation_settings.cache_built_engines or self.compilation_settings.reuse_cached_engines @@ -592,32 +595,31 @@ def run( "Found the cached engine that corresponds to this graph. It is directly loaded." ) - if not self.compilation_settings.immutable_weights: - # refit the cached engine with the new graph module - if not self.compilation_settings.strip_engine_weights: - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) - - from torch_tensorrt.dynamo._refit import ( - _refit_single_trt_engine_with_gm, - ) - - _refit_single_trt_engine_with_gm( - new_gm=self.module, - old_engine=engine, - input_list=self.input_specs, - settings=self.compilation_settings, - weight_name_map=None, - ) - - # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared - serialization_config = engine.create_serialization_config() - serialization_config.clear_flag( - trt.SerializationFlag.EXCLUDE_WEIGHTS - ) - serialized_engine = engine.serialize_with_config( - serialization_config - ) + # refit the cached engine with the new graph module + if not self.compilation_settings.strip_engine_weights: + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + from torch_tensorrt.dynamo._refit import ( + _refit_single_trt_engine_with_gm, + ) + + _refit_single_trt_engine_with_gm( + new_gm=self.module, + old_engine=engine, + input_list=self.input_specs, + settings=self.compilation_settings, + weight_name_map=self.weight_name_map, + ) + + # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared + serialization_config = engine.create_serialization_config() + serialization_config.clear_flag( + trt.SerializationFlag.EXCLUDE_WEIGHTS + ) + serialized_engine = engine.serialize_with_config( + serialization_config + ) with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) @@ -659,24 +661,25 @@ def run( builder_config, self.compilation_settings.timing_cache_path ) - if ( - self.engine_cache is not None - and self.compilation_settings.cache_built_engines - ): - # Cache the weight-stripped engine - self.engine_cache.insert( - hash_val, - ( - serialized_engine, - self._input_names, - self._output_names, - self.input_specs, - self.compilation_settings, - self.weight_name_map, - ), - ) - if not self.compilation_settings.immutable_weights: + # Disable engine caching for non-refittable engines + if ( + self.engine_cache is not None + and self.compilation_settings.cache_built_engines + ): + # Cache the weight-stripped engine + self.engine_cache.insert( + hash_val, + ( + serialized_engine, + self._input_names, + self._output_names, + self.input_specs, + self.compilation_settings, + self.weight_name_map, + ), + ) + if not self.compilation_settings.strip_engine_weights: # Refit the engine with the original weights runtime = trt.Runtime(TRT_LOGGER) @@ -691,7 +694,7 @@ def run( old_engine=engine, input_list=self.input_specs, settings=self.compilation_settings, - weight_name_map=None, + weight_name_map=self.weight_name_map, ) # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared From 493f9810a43c2ea2b05dbd187db9cfef4eb56f1c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 4 Oct 2024 18:11:16 -0700 Subject: [PATCH 10/52] refactored, there are 3 types of engines --- py/torch_tensorrt/dynamo/_compiler.py | 4 +- py/torch_tensorrt/dynamo/_settings.py | 2 +- .../dynamo/conversion/_TRTInterpreter.py | 71 +++++++++------ .../models/test_weight_stripped_engine.py | 89 ++++++++++++++++++- 4 files changed, 133 insertions(+), 33 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 352e81950e..d35beacd91 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -165,7 +165,7 @@ def compile( use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required. - immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -613,7 +613,7 @@ def convert_exported_program_to_serialized_trt_engine( use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required. - immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs """ diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 51bdc2c707..5ad36e1077 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -85,7 +85,7 @@ class CompilationSettings: use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. refit_identical_engine_weights (bool): Whether to refit the engine with identical weights strip_engine_weights (bool): Whether to strip the engine weights - immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index c1734498c5..54c0489932 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -287,7 +287,14 @@ def _populate_trt_builder_config( if self.compilation_settings.disable_tf32: builder_config.clear_flag(trt.BuilderFlag.TF32) - if not self.compilation_settings.immutable_weights: + if self.compilation_settings.immutable_weights: + # non-refittable engine + if self.compilation_settings.strip_engine_weights: + _LOGGER.warning( + "You cannot get a non-refittable engine with weights stripped. `strip_engine_weights` will be set to false and engine caching will be disabled." + ) + else: + # refittable engine if version.parse(trt.__version__) >= version.parse("10.0"): if self.compilation_settings.refit_identical_engine_weights: builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) @@ -296,7 +303,8 @@ def _populate_trt_builder_config( else: builder_config.set_flag(trt.BuilderFlag.REFIT) - builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN) + if self.compilation_settings.strip_engine_weights: + builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN) if strict_type_constraints: builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) @@ -564,7 +572,7 @@ def run( cached_data = self.engine_cache.check(hash_val) if cached_data is not None: # hit the cache ( - serialized_engine, + weight_stripped_serialized_engine, self._input_names, self._output_names, cached_engine_input_specs, @@ -598,7 +606,9 @@ def run( # refit the cached engine with the new graph module if not self.compilation_settings.strip_engine_weights: runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) + engine = runtime.deserialize_cuda_engine( + weight_stripped_serialized_engine + ) from torch_tensorrt.dynamo._refit import ( _refit_single_trt_engine_with_gm, @@ -620,6 +630,7 @@ def run( serialized_engine = engine.serialize_with_config( serialization_config ) + # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) @@ -661,17 +672,43 @@ def run( builder_config, self.compilation_settings.timing_cache_path ) + # refittable engine if not self.compilation_settings.immutable_weights: - # Disable engine caching for non-refittable engines + # Engine caching only for refittable engine if ( self.engine_cache is not None and self.compilation_settings.cache_built_engines ): # Cache the weight-stripped engine + if self.compilation_settings.strip_engine_weights: + weight_stripped_serialized_engine = serialized_engine + else: + # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + from torch_tensorrt.dynamo._refit import ( + _refit_single_trt_engine_with_gm, + ) + + _refit_single_trt_engine_with_gm( + new_gm=self.module, + old_engine=engine, + input_list=self.input_specs, + settings=self.compilation_settings, + weight_name_map=self.weight_name_map, + ) + + serialization_config = engine.create_serialization_config() + serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + weight_stripped_serialized_engine = engine.serialize_with_config( + serialization_config + ) + self.engine_cache.insert( hash_val, ( - serialized_engine, + weight_stripped_serialized_engine, self._input_names, self._output_names, self.input_specs, @@ -680,28 +717,6 @@ def run( ), ) - if not self.compilation_settings.strip_engine_weights: - # Refit the engine with the original weights - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) - - from torch_tensorrt.dynamo._refit import ( - _refit_single_trt_engine_with_gm, - ) - - _refit_single_trt_engine_with_gm( - new_gm=self.module, - old_engine=engine, - input_list=self.input_specs, - settings=self.compilation_settings, - weight_name_map=self.weight_name_map, - ) - - # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared - serialization_config = engine.create_serialization_config() - serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) - serialized_engine = engine.serialize_with_config(serialization_config) - with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) engine_str = engine_bytes.getvalue() diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 1454eb4542..67df54a09b 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -15,6 +15,56 @@ class TestWeightStrippedEngine(TestCase): + def test_three_ways_to_compile(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(pyt_model, example_inputs) + + settings = { + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "strip_engine_weights": False, + "refit_identical_engine_weights": False, + } + + # 1. Compile with torch_trt.dynamo.compile + gm1 = torch_trt.dynamo.compile( + exp_program, + example_inputs, + **settings, + ) + gm1_output = gm1(*example_inputs) + + # 2. Compile with torch_trt.compile using dynamo backend + gm2 = torch_trt.compile( + pyt_model, ir="dynamo", inputs=example_inputs, **settings + ) + gm2_output = gm2(*example_inputs) + + # 3. Compile with torch.compile using tensorrt backend + gm3 = torch.compile( + pyt_model, + backend="tensorrt", + options=settings, + ) + gm3_output = gm3(*example_inputs) + + pyt_model_output = pyt_model(*example_inputs) + + assert torch.allclose( + pyt_model_output, gm1_output, 1e-2, 1e-2 + ), "gm1_output is not correct" + + assert torch.allclose( + gm1_output, gm2_output, 1e-2, 1e-2 + ), "gm2_output is not correct" + + assert torch.allclose( + gm2_output, gm3_output, 1e-2, 1e-2 + ), "gm3_output is not correct" + def test_weight_stripped_engine_sizes(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -67,8 +117,6 @@ def test_weight_stripped_engine_results(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, - cache_built_engines=False, - reuse_cached_engines=False, strip_engine_weights=True, refit_identical_engine_weights=False, ) @@ -316,3 +364,40 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): times[0] > times[2], msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", ) + + def test_different_args_dont_share_engine_caching(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + + engine_cache_dir = "/tmp/test_different_args_dont_share_engine_caching" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + + for i in range(2): + if i == 0: + strip_engine_weights = False + else: + strip_engine_weights = True + + compiled_model = torch.compile( + pyt_model, + backend="tensorrt", + options={ + "use_python_runtime": True, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "cache_built_engines": True, + "reuse_cached_engines": True, + "engine_cache_dir": engine_cache_dir, + "strip_engine_weights": strip_engine_weights, + }, + ) + compiled_model(*inputs) + + assertions.assertEqual( + len(os.listdir(engine_cache_dir)), + 2, + msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 2 engines", + ) From f204104fce90c81a5c2769d5209e976e8547f69e Mon Sep 17 00:00:00 2001 From: Evan Li Date: Sat, 5 Oct 2024 02:03:14 -0700 Subject: [PATCH 11/52] fix and add tests --- tests/py/dynamo/models/test_engine_cache.py | 3 +- .../models/test_weight_stripped_engine.py | 176 +++++++++++++++++- 2 files changed, 169 insertions(+), 10 deletions(-) diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index 3502e430f8..383009e48e 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -485,8 +485,7 @@ def test_torch_compile_with_custom_engine_cache(self): def test_torch_compile_change_input_shape(self): # Custom Engine Cache model = models.resnet18(pretrained=True).eval().to("cuda") - - engine_cache_dir = "/tmp/test_torch_compile_with_default_disk_engine_cache" + engine_cache_dir = "/tmp/test_torch_compile_change_input_shape" if os.path.exists(engine_cache_dir): shutil.rmtree(engine_cache_dir) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 67df54a09b..8ca22a132c 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -9,6 +9,7 @@ from torch.testing._internal.common_utils import TestCase from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH +from torch_tensorrt.dynamo._refit import refit_module_weights from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity assertions = unittest.TestCase() @@ -65,6 +66,54 @@ def test_three_ways_to_compile(self): gm2_output, gm3_output, 1e-2, 1e-2 ), "gm3_output is not correct" + def test_three_ways_to_compile_weight_stripped_engine(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(pyt_model, example_inputs) + + settings = { + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "strip_engine_weights": True, + "refit_identical_engine_weights": False, + } + + # 1. Compile with torch_trt.dynamo.compile + gm1 = torch_trt.dynamo.compile( + exp_program, + example_inputs, + **settings, + ) + gm1_output = gm1(*example_inputs) + + # 2. Compile with torch_trt.compile using dynamo backend + gm2 = torch_trt.compile( + pyt_model, ir="dynamo", inputs=example_inputs, **settings + ) + gm2_output = gm2(*example_inputs) + + # 3. Compile with torch.compile using tensorrt backend + gm3 = torch.compile( + pyt_model, + backend="tensorrt", + options=settings, + ) + gm3_output = gm3(*example_inputs) + + assertions.assertEqual( + gm1_output.sum(), 0, msg="gm1_output should be all zeros" + ) + + assertions.assertEqual( + gm2_output.sum(), 0, msg="gm2_output should be all zeros" + ) + + assertions.assertEqual( + gm3_output.sum(), 0, msg="gm3_output should be all zeros" + ) + def test_weight_stripped_engine_sizes(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -125,8 +174,6 @@ def test_weight_stripped_engine_results(self): output.sum(), 0, msg="weight-stripped engine results should be all zeros" ) - from torch_tensorrt.dynamo._refit import refit_module_weights - # Refit the weight-stripped engine with the same weights refitted_trt_gm = refit_module_weights(trt_gm, exp_program) refitted_output = refitted_trt_gm(*inputs) @@ -157,12 +204,12 @@ def test_weight_stripped_engine_results(self): msg=f"refitted_output doesn't match with compiled_model_output. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - def test_weight_stripped_engine_with_engine_cache(self): + def test_engine_caching_saves_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) exp_program = torch.export.export(pyt_model, example_inputs) - engine_cache_dir = "/tmp/test_weight_stripped_engine_with_engine_cache" + engine_cache_dir = "/tmp/test_engine_caching_saves_weight_stripped_engine" if os.path.exists(engine_cache_dir): shutil.rmtree(engine_cache_dir) @@ -187,6 +234,7 @@ def test_weight_stripped_engine_with_engine_cache(self): engine_cache_dir=engine_cache_dir, ) output = trt_gm(*example_inputs) + assertions.assertNotEqual(output.sum(), 0, msg="results shouldn't be all zeros") blob_path = os.path.join( engine_cache_dir, os.listdir(engine_cache_dir)[0], "blob.bin" @@ -200,7 +248,6 @@ def test_weight_stripped_engine_with_engine_cache(self): len(bytes(weight_included_engine)) > len(bytes(cached_stripped_engine)), msg=f"cached engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, cached stripped engine size: {len(bytes(cached_stripped_engine))}", ) - assertions.assertNotEqual(output.sum(), 0, msg="results are all zeros") def test_dynamo_compile_with_refittable_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") @@ -328,7 +375,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": engine_cache_dir, - "torch_executed_ops": {"torch.ops.aten.relu.default"}, "strip_engine_weights": False, "refit_identical_engine_weights": True, }, @@ -365,10 +411,10 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", ) - def test_different_args_dont_share_engine_caching(self): + def test_different_args_dont_share_cached_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") - engine_cache_dir = "/tmp/test_different_args_dont_share_engine_caching" + engine_cache_dir = "/tmp/test_different_args_dont_share_cached_engine" if os.path.exists(engine_cache_dir): shutil.rmtree(engine_cache_dir) @@ -401,3 +447,117 @@ def test_different_args_dont_share_engine_caching(self): 2, msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 2 engines", ) + + def test_constant_mul_in_refitting(self): + class MyModel(torch.nn.Module): + def forward(self, x): + out = x * 0.5 + return out + + # TODO: investigate why this doesn't work + pyt_model = MyModel().eval().cuda() + inputs = [torch.randn((1, 3, 4, 4)).to("cuda")] + # TODO: investigate why this works + # pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + # inputs = [torch.rand((2, 3, 224, 224)).to("cuda")] + + exp_program = torch.export.export(pyt_model, args=tuple(inputs)) + + trt_module = torch_trt.compile( + pyt_model, + ir="dynamo", + inputs=tuple(inputs), + min_block_size=1, + use_python_runtime=True, + strip_engine_weights=True, + refit_identical_engine_weights=False, + ) + + refitted_trt_gm = refit_module_weights(trt_module, exp_program) + + outputs_pyt = pyt_model(*inputs) + outputs_trt = refitted_trt_gm(*inputs) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + def test_two_TRTRuntime_in_refitting(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + batch = torch.export.Dim("batch", min=1, max=200) + exp_program = torch.export.export( + pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + ) + inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + + pyt_results = pyt_model(*inputs) + + for i in range(2): + if i == 0: + use_python_runtime = True + else: + use_python_runtime = False + + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + debug=False, + min_block_size=1, + strip_engine_weights=True, + refit_identical_engine_weights=False, + ) + + output = trt_gm(*inputs) + assertions.assertEqual(output.sum(), 0, msg="results should be all zeros") + + refitted_trt_gm = refit_module_weights(trt_gm, exp_program) + refitted_output = refitted_trt_gm(*inputs) + cos_sim = cosine_similarity(pyt_results, refitted_output) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"{'PythonTorchTensorRTModule' if use_python_runtime else 'TorchTensorRTModule'} outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + @unittest.skip("Waiting for implementation") + def test_refit_identical_engine_weights(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(pyt_model, example_inputs) + + engine_cache_dir = "/tmp/test_refit_identical_engine_weights" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(example_inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + strip_engine_weights=True, + refit_identical_engine_weights=True, + ) + output = trt_gm(*example_inputs) + + pyt_model2 = models.resnet18(pretrained=False).eval().to("cuda") + exp_program2 = torch.export.export(pyt_model2, example_inputs) + + try: + refit_module_weights(trt_gm, exp_program) + except Exception as e: + assertions.fail( + f"Refitting the engine with the same weights failed with the following error: {e}" + ) + + try: + refit_module_weights(trt_gm, exp_program2) + assertions.fail( + "Refitting the engine with different weights should have failed but it didn't" + ) + except Exception as e: + pass From 4663c834c3a8e81b6a1b91774a20f50481bf9523 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 7 Oct 2024 22:29:23 -0700 Subject: [PATCH 12/52] fix issues #3206 #3217 --- py/torch_tensorrt/dynamo/_refit.py | 19 +++++++++++++++---- .../models/test_weight_stripped_engine.py | 10 +++++----- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 20dfb982cb..03f2ba58c7 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -410,6 +410,10 @@ def refit_module_weights( "The type of graph module is not supported for refitting or two compiled modules do not match." ) + assert ( + engine.refittable + ), "The engine is not refittable. The reason may be that the engine was built with an older version of Torch-TensorRT, or you are refitting a refitted weight-stripped engine. Note that weight-stripped engine can be refitted only once." + # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.construct_submodule_inputs(new_submodule) logger.debug( @@ -447,17 +451,24 @@ def refit_module_weights( weight_name_map=None, ) + # clear EXCLUDE_WEIGHTS flag + serialization_config = engine.create_serialization_config() + serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + serialized_engine = engine.serialize_with_config(serialization_config) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + if isinstance(compiled_submodule, PythonTorchTensorRTModule): + compiled_submodule.engine = engine + if isinstance(compiled_submodule, TorchTensorRTModule): - serialized_engine = bytes(engine.serialize()) new_engine_info = list(engine_info) - new_engine_info[ENGINE_IDX] = serialized_engine + new_engine_info[ENGINE_IDX] = bytes(serialized_engine) refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) compiled_submodule.engine = refitted_engine elif inline_module: - serialized_engine = bytes(engine.serialize()) new_engine_info = list(engine_info) - new_engine_info[ENGINE_IDX] = serialized_engine + new_engine_info[ENGINE_IDX] = bytes(serialized_engine) refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) setattr(compiled_module, f"{name}_engine", refitted_engine) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 8ca22a132c..630bf10f4e 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -450,16 +450,16 @@ def test_different_args_dont_share_cached_engine(self): def test_constant_mul_in_refitting(self): class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.tensor(0.5, requires_grad=False) + def forward(self, x): - out = x * 0.5 + out = x * self.weight return out - # TODO: investigate why this doesn't work pyt_model = MyModel().eval().cuda() inputs = [torch.randn((1, 3, 4, 4)).to("cuda")] - # TODO: investigate why this works - # pyt_model = models.resnet18(pretrained=True).eval().to("cuda") - # inputs = [torch.rand((2, 3, 224, 224)).to("cuda")] exp_program = torch.export.export(pyt_model, args=tuple(inputs)) From c57ab061fba0b4e55b62f505f34db637cc56bef2 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 14 Oct 2024 21:02:04 -0700 Subject: [PATCH 13/52] small fix --- py/torch_tensorrt/dynamo/_refit.py | 4 ---- .../dynamo/conversion/_TRTInterpreter.py | 11 +++-------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 03f2ba58c7..951325f2d0 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -410,10 +410,6 @@ def refit_module_weights( "The type of graph module is not supported for refitting or two compiled modules do not match." ) - assert ( - engine.refittable - ), "The engine is not refittable. The reason may be that the engine was built with an older version of Torch-TensorRT, or you are refitting a refitted weight-stripped engine. Note that weight-stripped engine can be refitted only once." - # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.construct_submodule_inputs(new_submodule) logger.debug( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 54c0489932..a793f4cc60 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -290,16 +290,11 @@ def _populate_trt_builder_config( if self.compilation_settings.immutable_weights: # non-refittable engine if self.compilation_settings.strip_engine_weights: - _LOGGER.warning( - "You cannot get a non-refittable engine with weights stripped. `strip_engine_weights` will be set to false and engine caching will be disabled." - ) + _LOGGER.warning("strip_engine_weights will be ignored.") else: # refittable engine - if version.parse(trt.__version__) >= version.parse("10.0"): - if self.compilation_settings.refit_identical_engine_weights: - builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) - else: - builder_config.set_flag(trt.BuilderFlag.REFIT) + if self.compilation_settings.refit_identical_engine_weights: + builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) else: builder_config.set_flag(trt.BuilderFlag.REFIT) From 402c9b0518a57f930aa0361302491eb589fb0e5a Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 14 Oct 2024 21:10:50 -0700 Subject: [PATCH 14/52] resolve comments --- py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index a793f4cc60..256b5111c8 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -640,7 +640,8 @@ def run( self._construct_trt_network_def() - self._save_weight_mapping() + if not self.compilation_settings.immutable_weights: + self._save_weight_mapping() build_engine_start_time = datetime.now() _LOGGER.info("Not found cached TRT engines. Start building engine.") @@ -659,7 +660,7 @@ def run( assert serialized_engine _LOGGER.info( - f"Build weight-stripped TRT engine elapsed time: {datetime.now() - build_engine_start_time}" + f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") From d8e59da9f4c5cf4199e1af378d390ed7d5616562 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 22 Oct 2024 16:54:31 -0700 Subject: [PATCH 15/52] WIP: cache weight-stripped engine --- py/torch_tensorrt/dynamo/_engine_cache.py | 13 +++++++++---- py/torch_tensorrt/dynamo/_settings.py | 1 - .../dynamo/conversion/_TRTInterpreter.py | 16 ++-------------- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_engine_cache.py b/py/torch_tensorrt/dynamo/_engine_cache.py index f166b489cb..f2452fadd5 100644 --- a/py/torch_tensorrt/dynamo/_engine_cache.py +++ b/py/torch_tensorrt/dynamo/_engine_cache.py @@ -1,4 +1,5 @@ import copy +import hashlib import io import logging import os @@ -6,10 +7,10 @@ import pickletools import shutil from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Sequence, Tuple, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple import torch -from torch._inductor.codecache import FxGraphCachePickler, sha256_hash +from torch._inductor.codecache import sha256_hash from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._settings import ( @@ -59,7 +60,11 @@ def get_hash( for name, param in new_gm.named_parameters(): param.data.zero_() - graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm)) + # TODO: This hash function is slow, reported in https://github.com/pytorch/TensorRT/issues/3249 + # Waiting for a fix from PyTorch team + # graph_hash = FxGraphCachePickler.get_hash(new_gm) + graph_str = str(new_gm.graph) + graph_hash = hashlib.sha256(graph_str.encode()).hexdigest() input_spec_strs = [str(i) for i in input_specs] with io.BytesIO() as stream: @@ -75,7 +80,7 @@ def get_hash( engine_specs_data = pickletools.optimize(engine_specs_data) engine_specs_hash = sha256_hash(engine_specs_data) - hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash + hash_val: str = graph_hash + input_specs_hash + engine_specs_hash return hash_val diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 5ad36e1077..d99702a589 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -135,7 +135,6 @@ class CompilationSettings: "sparse_weights", "engine_capability", "hardware_compatible", - "strip_engine_weights", "refit_identical_engine_weights", "immutable_weights", ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 256b5111c8..d5076cddaf 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -617,7 +617,7 @@ def run( weight_name_map=self.weight_name_map, ) - # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared + # EXCLUDE_WEIGHTS flag must be cleared serialization_config = engine.create_serialization_config() serialization_config.clear_flag( trt.SerializationFlag.EXCLUDE_WEIGHTS @@ -679,22 +679,10 @@ def run( if self.compilation_settings.strip_engine_weights: weight_stripped_serialized_engine = serialized_engine else: - # Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared + # set EXCLUDE_WEIGHTS flag to strip weights runtime = trt.Runtime(TRT_LOGGER) engine = runtime.deserialize_cuda_engine(serialized_engine) - from torch_tensorrt.dynamo._refit import ( - _refit_single_trt_engine_with_gm, - ) - - _refit_single_trt_engine_with_gm( - new_gm=self.module, - old_engine=engine, - input_list=self.input_specs, - settings=self.compilation_settings, - weight_name_map=self.weight_name_map, - ) - serialization_config = engine.create_serialization_config() serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) weight_stripped_serialized_engine = engine.serialize_with_config( From f2e3f00091d1e85efb8fefdd269a4e4cbfe08972 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 4 Nov 2024 14:44:18 -0800 Subject: [PATCH 16/52] redesigned hash func and add constant mapping to fast refit --- py/torch_tensorrt/dynamo/_engine_cache.py | 41 +++++++---- py/torch_tensorrt/dynamo/_refit.py | 17 +++++ .../dynamo/conversion/_TRTInterpreter.py | 72 +++++++++++-------- tests/py/dynamo/models/test_engine_cache.py | 52 ++++++++++++-- 4 files changed, 133 insertions(+), 49 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_engine_cache.py b/py/torch_tensorrt/dynamo/_engine_cache.py index f2452fadd5..7835c419d0 100644 --- a/py/torch_tensorrt/dynamo/_engine_cache.py +++ b/py/torch_tensorrt/dynamo/_engine_cache.py @@ -1,5 +1,4 @@ import copy -import hashlib import io import logging import os @@ -11,7 +10,6 @@ import torch from torch._inductor.codecache import sha256_hash -from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._settings import ( _SETTINGS_TO_BE_ENGINE_INVARIANT, @@ -50,21 +48,38 @@ def get_hash( Args: gm (torch.fx.GraphModule): GraphModule to hash + input_specs (Sequence[Input]): input specs for the GraphModule + settings (CompilationSettings): compilation settings for the GraphModule Returns: str: hash value of the GraphModule """ - # parameters are set to 0 - with unset_fake_temporarily(): - new_gm = copy.deepcopy(gm) - for name, param in new_gm.named_parameters(): - param.data.zero_() - - # TODO: This hash function is slow, reported in https://github.com/pytorch/TensorRT/issues/3249 - # Waiting for a fix from PyTorch team - # graph_hash = FxGraphCachePickler.get_hash(new_gm) - graph_str = str(new_gm.graph) - graph_hash = hashlib.sha256(graph_str.encode()).hexdigest() + + def canonicalize_graph(graph: torch.fx.Graph) -> str: + """Canonicalize the graph to a string for isomorphic graph comparison + + Args: + graph (torch.fx.Graph): graph to canonicalize + + Returns: + str: canonicalized graph string + """ + canonical_nodes = [] + input_counter = 0 + + for node in graph.nodes: + if node.op == "placeholder": + canonical_nodes.append(f"placeholder_input_{input_counter}") + input_counter += 1 + else: + canonical_nodes.append(f"{node.op}_{node.target}") + + return " ".join(canonical_nodes) + + graph_str = canonicalize_graph(gm.graph) + _LOGGER.debug(f"graph_str:\n {graph_str}") + + graph_hash = sha256_hash(graph_str.encode()) input_spec_strs = [str(i) for i in input_specs] with io.BytesIO() as stream: diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index b61262e9fe..f996af809c 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -156,9 +156,26 @@ def _refit_single_trt_engine_with_gm( if torch_device.type == "cuda" else trt.TensorLocation.HOST ) + + constant_mapping: dict[str, Any] = weight_name_map.pop( + "constant_mapping", {} + ) # type: ignore mapping = construct_refit_mapping_from_weight_name_map( weight_name_map, new_gm.state_dict() ) + constant_mapping_with_type = {} + + for constant_name, val in constant_mapping.items(): + np_weight_type = val.dtype + val_tensor = torch.from_numpy(val).cuda() + trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) + torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) + constant_mapping_with_type[constant_name] = ( + val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), + trt_dtype, + ) + + mapping.update(constant_mapping_with_type) # Debug Use # correct = construct_refit_mapping(new_gm, input_list, settings) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 287b8ef49d..698db646b2 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -376,7 +376,6 @@ def find_weight( np_map: the map from weight name to np values in INetworkDefinition state_dict: state of the graph module """ - network_weight = np_map[weight_name] network_weight = torch.from_numpy(np_map[weight_name]).cuda() for sd_w_name, sd_weight in state_dict.items(): if TRTInterpreter.check_weight_equal(sd_weight, network_weight): @@ -465,6 +464,7 @@ def _save_weight_mapping(self) -> None: sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()} weight_name_map: dict[str, Any] = {} np_map = {} + constant_mapping = {} net = self.ctx.net for i in range(net.num_layers): layer = net[i] @@ -501,8 +501,12 @@ def _save_weight_mapping(self) -> None: elif "running_var" in suffix: # Linear layer weight sd_weight_name = f"{sd_weight_name}.running_var" - else: + elif "bias" in suffix: sd_weight_name = f"{sd_weight_name}.bias" + else: + # Save the constant weights for future fast refit + sd_weight_name = f"{sd_weight_name}.unknown" + constant_mapping[engine_weight_name] = weight elif layer_type == "SCALE": # Batch norm needs all weights to calculate scale and shift sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr] @@ -523,12 +527,19 @@ def _save_weight_mapping(self) -> None: weight_name_map[engine_weight_name] = TRTInterpreter.find_weight( engine_weight_name, np_map, sd ) + if ( + weight_name_map[engine_weight_name] != "" + and engine_weight_name in constant_mapping + ): + # If the weight is found in state_dict, remove it from constant_mapping + del constant_mapping[engine_weight_name] weight_name_map[engine_weight_name] = [ weight_name_map[engine_weight_name], np_map[engine_weight_name].dtype, ] + weight_name_map["constant_mapping"] = constant_mapping self.weight_name_map = weight_name_map del np_map, sd @@ -570,7 +581,7 @@ def run( cached_data = self.engine_cache.check(hash_val) if cached_data is not None: # hit the cache ( - weight_stripped_serialized_engine, + serialized_engine, self._input_names, self._output_names, cached_engine_input_specs, @@ -604,9 +615,7 @@ def run( # refit the cached engine with the new graph module if not self.compilation_settings.strip_engine_weights: runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine( - weight_stripped_serialized_engine - ) + engine = runtime.deserialize_cuda_engine(serialized_engine) from torch_tensorrt.dynamo._refit import ( _refit_single_trt_engine_with_gm, @@ -619,16 +628,18 @@ def run( settings=self.compilation_settings, weight_name_map=self.weight_name_map, ) - - # EXCLUDE_WEIGHTS flag must be cleared - serialization_config = engine.create_serialization_config() - serialization_config.clear_flag( - trt.SerializationFlag.EXCLUDE_WEIGHTS - ) - serialized_engine = engine.serialize_with_config( - serialization_config - ) - # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller + serialized_engine = engine.serialize() + + # TODO: Waiting for TRT's feature to load the weight-stripped engine + # # EXCLUDE_WEIGHTS flag must be cleared + # serialization_config = engine.create_serialization_config() + # serialization_config.clear_flag( + # trt.SerializationFlag.EXCLUDE_WEIGHTS + # ) + # serialized_engine = engine.serialize_with_config( + # serialization_config + # ) + # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) @@ -678,24 +689,23 @@ def run( self.engine_cache is not None and self.compilation_settings.cache_built_engines ): - # Cache the weight-stripped engine - if self.compilation_settings.strip_engine_weights: - weight_stripped_serialized_engine = serialized_engine - else: - # set EXCLUDE_WEIGHTS flag to strip weights - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) - - serialization_config = engine.create_serialization_config() - serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) - weight_stripped_serialized_engine = engine.serialize_with_config( - serialization_config - ) - + # TODO: Waiting for TRT's feature to cache the weight-stripped engine + # if not self.compilation_settings.strip_engine_weights: + # # set EXCLUDE_WEIGHTS flag to strip weights + # runtime = trt.Runtime(TRT_LOGGER) + # engine = runtime.deserialize_cuda_engine(serialized_engine) + + # serialization_config = engine.create_serialization_config() + # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + # serialized_engine = engine.serialize_with_config( + # serialization_config + # ) + + # Cache weighted engine for now self.engine_cache.insert( hash_val, ( - weight_stripped_serialized_engine, + serialized_engine, self._input_names, self._output_names, self.input_specs, diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index 383009e48e..c7b7a32e89 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -390,7 +390,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": engine_cache_dir, "engine_cache_size": 1 << 30, # 1GB - "torch_executed_ops": {"torch.ops.aten.relu.default"}, }, ) results.append(compiled_model(*inputs)) # trigger the compilation @@ -453,7 +452,6 @@ def test_torch_compile_with_custom_engine_cache(self): "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "custom_engine_cache": custom_engine_cache, - "torch_executed_ops": {"torch.ops.aten.relu.default"}, }, ) results.append(compiled_model(*inputs)) # trigger the compilation @@ -482,16 +480,58 @@ def test_torch_compile_with_custom_engine_cache(self): for h, count in custom_engine_cache.hashes.items() ] - def test_torch_compile_change_input_shape(self): + def test_torch_trt_compile_change_input_shape(self): # Custom Engine Cache model = models.resnet18(pretrained=True).eval().to("cuda") - engine_cache_dir = "/tmp/test_torch_compile_change_input_shape" + engine_cache_dir = "/tmp/test_torch_trt_compile_change_input_shape" if os.path.exists(engine_cache_dir): shutil.rmtree(engine_cache_dir) custom_engine_cache = MyEngineCache(engine_cache_dir) for i in range(3): inputs = [torch.rand((4 * (i + 1), 3, 224, 224)).to("cuda")] + compiled_model = torch_trt.compile( + model, + inputs=inputs, + **{ + "use_python_runtime": True, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "cache_built_engines": True, + "reuse_cached_engines": True, + "custom_engine_cache": custom_engine_cache, + }, + ) + compiled_model(*inputs) + [ + assertions.assertTrue( + count == 0, f"Unintended cache hit for entry ({h}, hit: {count})" + ) + for h, count in custom_engine_cache.hashes.items() + ] + + def test_torch_compile_graph_break(self): + class MyModel(torch.nn.Module): + def forward(self, x): + x = x + x + x = x + x + x = torch.ops.aten.relu.default(x) + x = x + x + x = x + x + x = torch.ops.aten.relu.default(x) + x = x + x + x = x + x + return x + + model = MyModel().eval().cuda() + engine_cache_dir = "/tmp/test_torch_compile_graph_break" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + custom_engine_cache = MyEngineCache(engine_cache_dir) + inputs = [torch.rand((3, 3, 224, 224)).to("cuda")] + for i in range(3): compiled_model = torch.compile( model, backend="tensorrt", @@ -506,10 +546,12 @@ def test_torch_compile_change_input_shape(self): "torch_executed_ops": {"torch.ops.aten.relu.default"}, }, ) + compiled_model(*inputs) [ assertions.assertTrue( - count == 0, f"Unintended cache hit for entry ({h}, hit: {count})" + count == 2, + f"cache was not hit exactly twice for entry ({h}, hit: {count})", ) for h, count in custom_engine_cache.hashes.items() ] From 31af308accde90c493865e6373db585dec6e03f9 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 5 Nov 2024 16:17:41 -0800 Subject: [PATCH 17/52] refactor and add tests --- .../dynamo/conversion/_TRTInterpreter.py | 218 ++++----- tests/py/dynamo/models/test_engine_cache.py | 439 ++++++++++++++++++ 2 files changed, 551 insertions(+), 106 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 698db646b2..1a8ffae7fe 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -546,6 +546,107 @@ def _save_weight_mapping(self) -> None: gc.collect() torch.cuda.empty_cache() + def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None: + # TODO: Waiting for TRT's feature to cache the weight-stripped engine + # if not self.compilation_settings.strip_engine_weights: + # # set EXCLUDE_WEIGHTS flag to strip weights + # runtime = trt.Runtime(TRT_LOGGER) + # engine = runtime.deserialize_cuda_engine(serialized_engine) + + # serialization_config = engine.create_serialization_config() + # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + # serialized_engine = engine.serialize_with_config( + # serialization_config + # ) + + # Cache weighted engine for now + self.engine_cache.insert( # type: ignore[union-attr] + hash_val, + ( + serialized_engine, + self._input_names, + self._output_names, + self.input_specs, + self.compilation_settings, + self.weight_name_map, + ), + ) + + def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: + # query the cached TRT engine + cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr] + if cached_data is not None: # hit the cache + ( + serialized_engine, + self._input_names, + self._output_names, + cached_engine_input_specs, + engine_compilation_settings, + self.weight_name_map, + ) = cached_data + + setting_compatiblity, incompattible_settings = settings_are_compatible( + self.compilation_settings, engine_compilation_settings + ) + assert ( + setting_compatiblity + ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})" + + for i, e in enumerate( + [ + Input.equivalent_spec(c, i) + for c, i in zip(cached_engine_input_specs, self.input_specs) + ] + ): + assert ( + e + ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}" + + _LOGGER.info( + "Found the cached engine that corresponds to this graph. It is directly loaded." + ) + + # refit the cached engine with the new graph module + if not self.compilation_settings.strip_engine_weights: + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + from torch_tensorrt.dynamo._refit import ( + _refit_single_trt_engine_with_gm, + ) + + _refit_single_trt_engine_with_gm( + new_gm=self.module, + old_engine=engine, + input_list=self.input_specs, + settings=self.compilation_settings, + weight_name_map=self.weight_name_map, + ) + serialized_engine = engine.serialize() + + # TODO: Waiting for TRT's feature to load the weight-stripped engine + # # EXCLUDE_WEIGHTS flag must be cleared + # serialization_config = engine.create_serialization_config() + # serialization_config.clear_flag( + # trt.SerializationFlag.EXCLUDE_WEIGHTS + # ) + # serialized_engine = engine.serialize_with_config( + # serialization_config + # ) + # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller + + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + engine_str = engine_bytes.getvalue() + + return TRTInterpreterResult( + engine_str, + self._input_names, + self._output_names, + self.weight_name_map, + ) + return None + def run( self, strict_type_constraints: bool = False, @@ -576,81 +677,10 @@ def run( self.module, self.input_specs, self.compilation_settings ) - if self.compilation_settings.reuse_cached_engines: - # query the cached TRT engine - cached_data = self.engine_cache.check(hash_val) - if cached_data is not None: # hit the cache - ( - serialized_engine, - self._input_names, - self._output_names, - cached_engine_input_specs, - engine_compilation_settings, - self.weight_name_map, - ) = cached_data - - setting_compatiblity, incompattible_settings = ( - settings_are_compatible( - self.compilation_settings, engine_compilation_settings - ) - ) - assert ( - setting_compatiblity - ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})" - - for i, e in enumerate( - [ - Input.equivalent_spec(c, i) - for c, i in zip(cached_engine_input_specs, self.input_specs) - ] - ): - assert ( - e - ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}" - - _LOGGER.info( - "Found the cached engine that corresponds to this graph. It is directly loaded." - ) - - # refit the cached engine with the new graph module - if not self.compilation_settings.strip_engine_weights: - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) - - from torch_tensorrt.dynamo._refit import ( - _refit_single_trt_engine_with_gm, - ) - - _refit_single_trt_engine_with_gm( - new_gm=self.module, - old_engine=engine, - input_list=self.input_specs, - settings=self.compilation_settings, - weight_name_map=self.weight_name_map, - ) - serialized_engine = engine.serialize() - - # TODO: Waiting for TRT's feature to load the weight-stripped engine - # # EXCLUDE_WEIGHTS flag must be cleared - # serialization_config = engine.create_serialization_config() - # serialization_config.clear_flag( - # trt.SerializationFlag.EXCLUDE_WEIGHTS - # ) - # serialized_engine = engine.serialize_with_config( - # serialization_config - # ) - # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller - - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() - - return TRTInterpreterResult( - engine_str, - self._input_names, - self._output_names, - self.weight_name_map, - ) + if self.compilation_settings.reuse_cached_engines: + interpreter_result = self._pull_cached_engine(hash_val) + if interpreter_result is not None: # hit the cache + return interpreter_result self._construct_trt_network_def() @@ -682,37 +712,13 @@ def run( builder_config, self.compilation_settings.timing_cache_path ) - # refittable engine - if not self.compilation_settings.immutable_weights: - # Engine caching only for refittable engine - if ( - self.engine_cache is not None - and self.compilation_settings.cache_built_engines - ): - # TODO: Waiting for TRT's feature to cache the weight-stripped engine - # if not self.compilation_settings.strip_engine_weights: - # # set EXCLUDE_WEIGHTS flag to strip weights - # runtime = trt.Runtime(TRT_LOGGER) - # engine = runtime.deserialize_cuda_engine(serialized_engine) - - # serialization_config = engine.create_serialization_config() - # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) - # serialized_engine = engine.serialize_with_config( - # serialization_config - # ) - - # Cache weighted engine for now - self.engine_cache.insert( - hash_val, - ( - serialized_engine, - self._input_names, - self._output_names, - self.input_specs, - self.compilation_settings, - self.weight_name_map, - ), - ) + # Engine caching only for refittable engines + if ( + not self.compilation_settings.immutable_weights + and self.compilation_settings.cache_built_engines + and self.engine_cache is not None + ): + self._insert_engine_to_cache(hash_val, serialized_engine) with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index c7b7a32e89..5044654d81 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -555,3 +555,442 @@ def forward(self, x): ) for h, count in custom_engine_cache.hashes.items() ] + + def test_isomorphic_graphs(self): + class MyModel1(torch.nn.Module): + def forward(self, a, b): + return a + b + + class MyModel2(torch.nn.Module): + def forward(self, c, d): + return c + d + + model1 = MyModel1().eval().cuda() + model2 = MyModel2().eval().cuda() + + inputs1 = (torch.randn((2, 3)).to("cuda"), torch.randn((2, 3)).to("cuda")) + inputs2 = (torch.randn((2, 3)).to("cuda"), torch.randn((2, 3)).to("cuda")) + + exp_program1 = torch.export.export(model1, args=inputs1) + exp_program2 = torch.export.export(model2, args=inputs2) + + input_specs1 = ( + torch_trt.Input( + min_shape=(1, 3), + opt_shape=(2, 3), + max_shape=(10, 3), + ), + ) + + input_specs2 = ( + torch_trt.Input( + min_shape=(1, 3), + opt_shape=(2, 3), + max_shape=(10, 3), + ), + ) + + settings1 = CompilationSettings( + cache_built_engines=True, reuse_cached_engines=True + ) + + settings2 = CompilationSettings( + cache_built_engines=True, reuse_cached_engines=True + ) + + hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1) + hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2) + + assertions.assertEqual(hash1, hash2) + + # @unittest.skip("benchmark on small models") + def test_caching_small_model(self): + from torch_tensorrt.dynamo._refit import refit_module_weights + + model = models.resnet18(pretrained=True).eval().to("cuda") + + engine_cache_dir = "/tmp/test_caching_small_model" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + def remove_timing_cache(path=TIMING_CACHE_PATH): + if os.path.exists(path): + os.remove(path) + + inputs = (torch.rand((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(model, args=inputs) + + # warm up + trt_gm = torch_trt.dynamo.compile( + exp_program, + inputs, + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + strip_engine_weights=False, + refit_identical_engine_weights=False, + ) + torch.cuda.empty_cache() + + compile_times = [[] for _ in range(3)] + inference_times = [[] for _ in range(3)] + results = [[] for _ in range(3)] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + interval = 3 + for i in range(interval * 3): + if i < interval: + # non-refittable + immutable_weights = True + strip_engine_weights = False + refit_identical_engine_weights = False + cache_built_engines = reuse_cached_engines = False + # continue + elif i < interval * 2: + # REFIT w/ engine caching + immutable_weights = False + strip_engine_weights = False + refit_identical_engine_weights = False + cache_built_engines = reuse_cached_engines = True + # continue + else: + # REFIT_IDENTICAL w/ engine caching + immutable_weights = False + strip_engine_weights = False + refit_identical_engine_weights = True + cache_built_engines = reuse_cached_engines = True + # continue + + if i % interval == 0: + remove_timing_cache() + + torch._dynamo.reset() + + torch.cuda.synchronize() + start.record() + + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + cache_built_engines=cache_built_engines, + reuse_cached_engines=reuse_cached_engines, + engine_cache_dir=engine_cache_dir, + engine_cache_size=1 << 40, + immutable_weights=immutable_weights, + strip_engine_weights=strip_engine_weights, + refit_identical_engine_weights=refit_identical_engine_weights, + ) + + if strip_engine_weights: + trt_gm = refit_module_weights(trt_gm, exp_program) + + end.record() + torch.cuda.synchronize() + compile_times[i // interval].append(start.elapsed_time(end)) + + # inference + torch.cuda.synchronize() + start.record() + out = trt_gm(*inputs) + end.record() + torch.cuda.synchronize() + inference_times[i // interval].append(start.elapsed_time(end)) + + results[i // interval].append(out) + + torch.cuda.empty_cache() + + cos_sim = cosine_similarity(torch.stack(results[0]), torch.stack(results[1])) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(torch.stack(results[1]), torch.stack(results[2])) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + compile_times[1][0] > compile_times[1][1], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[1][0]} ms, time taken with engine caching: {compile_times[1][1]} ms", + ) + + assertions.assertTrue( + compile_times[1][0] > compile_times[1][2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[1][0]} ms, time taken with engine caching: {compile_times[1][2]} ms", + ) + + assertions.assertTrue( + compile_times[2][0] > compile_times[2][1], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[2][0]} ms, time taken with engine caching: {compile_times[2][1]} ms", + ) + + assertions.assertTrue( + compile_times[2][0] > compile_times[2][2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[2][0]} ms, time taken with engine caching: {compile_times[2][2]} ms", + ) + + assertions.assertTrue( + compile_times[0][2] > compile_times[1][2], + msg=f"Engine caching is slower than recompiling a non-refittable engine. Recompile a non-refittable engine: {compile_times[0][2]} ms, time taken with engine caching: {compile_times[1][2]} ms", + ) + + assertions.assertTrue( + compile_times[0][2] > compile_times[2][2], + msg=f"Engine caching is slower than recompiling a non-refittable engine. Recompile a non-refittable engine: {compile_times[0][2]} ms, time taken with engine caching: {compile_times[2][2]} ms", + ) + + @unittest.skip("benchmark on llama2") + def test_caching_llama2_model(self): + import torch + from torch_tensorrt.dynamo._refit import refit_module_weights + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + StoppingCriteriaList, + ) + from transformers.generation.stopping_criteria import ( + EosTokenCriteria, + MaxLengthCriteria, + ) + + def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): + """ + Exports the LLM model into an ExportedProgram with dynamic shapes. + In the case of guard failures due to some PyTorch kernel implements, we also + try to re-export the graph by expressing them as runtime assert nodes + """ + with torch.no_grad(): + # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 + seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) + try: + print("Trying to export the model using torch.export.export()..") + # strict=False only enables aotautograd tracing and excludes dynamo. + ep = torch.export.export( + model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False + ) + except: + print( + "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" + ) + # This API is used to express the constraint violation guards as asserts in the graph. + ep = torch.export._trace._export( + model, + (inputs,), + dynamic_shapes=({1: seq_len},), + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + return ep + + def generate(model, input_seq, max_tokens, eos_token_id): + """ + Greedy decoding of the model. This generates up to max_tokens. + """ + # Max length of output seq = current input_seq length + max_tokens allowed to generate + max_output_seq_length = input_seq.shape[1] + max_tokens + stopping_criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=max_output_seq_length), + EosTokenCriteria(eos_token_id=eos_token_id), + ] + ) + + while True: + outputs = model(input_seq) + logits = outputs.logits + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) + # TODO: Handle batch in this check + if stopping_criteria(input_seq, logits).item(): + break + + return input_seq + + MAX_TOKENS = 32 + DEVICE = torch.device("cuda:0") + + llama_path = "meta-llama/Llama-2-7b-chat-hf" + with torch.no_grad(): + model = AutoModelForCausalLM.from_pretrained( + llama_path, use_cache=False, attn_implementation="eager" + ).eval() + + tokenizer = AutoTokenizer.from_pretrained(llama_path) + + prompt = "What is dynamic programming?" + model_inputs = tokenizer(prompt, return_tensors="pt") + input_ids = model_inputs.input_ids + + llama2_ep = export_llm(model, input_ids, max_seq_len=64) + + engine_cache_dir = "/tmp/test_caching_llama2_model" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + timing_cache_path = os.path.join( + engine_cache_dir, "llama2_timing_cache_original.bin" + ) + + def remove_timing_cache(path=timing_cache_path): + if os.path.exists(path): + os.remove(path) + + input_ids = input_ids.to(DEVICE) + + # warm up + trt_gm = torch_trt.dynamo.compile( + llama2_ep, + inputs=[input_ids], + use_python_runtime=True, + enabled_precisions={torch.float32}, + debug=False, + min_block_size=1, + truncate_double=True, + device=DEVICE, + disable_tf32=True, + cache_built_engines=False, + reuse_cached_engines=False, + strip_engine_weights=False, + refit_identical_engine_weights=False, + timing_cache_path=timing_cache_path, + ) + torch.cuda.empty_cache() + + compile_times = [[] for _ in range(3)] + inference_times = [[] for _ in range(3)] + results = [[] for _ in range(3)] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + interval = 3 + for i in range(interval * 3): + if i < interval: + # non-refittable + immutable_weights = True + strip_engine_weights = False + refit_identical_engine_weights = False + cache_built_engines = reuse_cached_engines = False + elif i < interval * 2: + # REFIT w/ engine caching + immutable_weights = False + strip_engine_weights = False + refit_identical_engine_weights = False + cache_built_engines = reuse_cached_engines = True + else: + # REFIT_IDENTICAL w/ engine caching + immutable_weights = False + strip_engine_weights = False + refit_identical_engine_weights = True + cache_built_engines = reuse_cached_engines = True + + if i % interval == 0: + remove_timing_cache() + + torch._dynamo.reset() + + torch.cuda.synchronize() + start.record() + + trt_gm = torch_trt.dynamo.compile( + llama2_ep, + inputs=[input_ids], + use_python_runtime=True, + enabled_precisions={torch.float32}, + debug=False, + min_block_size=1, + truncate_double=True, + device=DEVICE, + disable_tf32=True, + cache_built_engines=cache_built_engines, + reuse_cached_engines=reuse_cached_engines, + engine_cache_dir=engine_cache_dir, + engine_cache_size=1 << 40, + immutable_weights=immutable_weights, + strip_engine_weights=strip_engine_weights, + refit_identical_engine_weights=refit_identical_engine_weights, + timing_cache_path=timing_cache_path, + ) + + if strip_engine_weights: + trt_gm = refit_module_weights(trt_gm, llama2_ep) + + end.record() + torch.cuda.synchronize() + + compile_times[i // interval].append(start.elapsed_time(end)) + + # inference + torch.cuda.synchronize() + start.record() + + trt_gen_tokens = generate( + trt_gm, input_ids, MAX_TOKENS, tokenizer.eos_token_id + ) + # trt_gen_text = tokenizer.batch_decode( + # trt_gen_tokens, + # skip_special_tokens=True, + # clean_up_tokenization_spaces=False, + # )[0], + results[i // interval].append(trt_gen_tokens) + + end.record() + torch.cuda.synchronize() + + inference_times[i // interval].append(start.elapsed_time(end)) + + torch.cuda.empty_cache() + + cos_sim = cosine_similarity(torch.stack(results[0]), torch.stack(results[1])) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(torch.stack(results[1]), torch.stack(results[2])) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + compile_times[1][0] > compile_times[1][1], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[1][0]} ms, time taken with engine caching: {compile_times[1][1]} ms", + ) + + assertions.assertTrue( + compile_times[1][0] > compile_times[1][2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[1][0]} ms, time taken with engine caching: {compile_times[1][2]} ms", + ) + + assertions.assertTrue( + compile_times[2][0] > compile_times[2][1], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[2][0]} ms, time taken with engine caching: {compile_times[2][1]} ms", + ) + + assertions.assertTrue( + compile_times[2][0] > compile_times[2][2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[2][0]} ms, time taken with engine caching: {compile_times[2][2]} ms", + ) + + assertions.assertTrue( + compile_times[0][2] > compile_times[1][2], + msg=f"Engine caching is slower than recompiling a non-refittable engine. Recompile a non-refittable engine: {compile_times[0][2]} ms, time taken with engine caching: {compile_times[1][2]} ms", + ) + + assertions.assertTrue( + compile_times[0][2] > compile_times[2][2], + msg=f"Engine caching is slower than recompiling a non-refittable engine. Recompile a non-refittable engine: {compile_times[0][2]} ms, time taken with engine caching: {compile_times[2][2]} ms", + ) From 90bf67927ee6e2fc77a5acd34d25678f06c78f68 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 5 Nov 2024 18:59:49 -0800 Subject: [PATCH 18/52] update --- py/torch_tensorrt/dynamo/_refit.py | 15 ++++++--------- py/torch_tensorrt/dynamo/_settings.py | 1 + py/torch_tensorrt/dynamo/backend/backends.py | 4 ++++ .../dynamo/conversion/_TRTInterpreter.py | 4 ++-- .../models/test_weight_stripped_engine.py | 18 +++++++----------- 5 files changed, 20 insertions(+), 22 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index f996af809c..ca379a9ada 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -466,16 +466,13 @@ def refit_module_weights( serialization_config = engine.create_serialization_config() serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) serialized_engine = engine.serialize_with_config(serialization_config) - engine = runtime.deserialize_cuda_engine(serialized_engine) - if isinstance(compiled_submodule, PythonTorchTensorRTModule): - compiled_submodule.engine = engine - - if isinstance(compiled_submodule, TorchTensorRTModule): - new_engine_info = list(engine_info) - new_engine_info[ENGINE_IDX] = bytes(serialized_engine) - refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) - compiled_submodule.engine = refitted_engine + if isinstance( + compiled_submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) + ): + compiled_submodule.engine = None # Clear the engine for TorchTensorRTModule, otherwise it won't be updated + compiled_submodule.serialized_engine = bytes(serialized_engine) + compiled_submodule.setup_engine() elif inline_module: new_engine_info = list(engine_info) diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index ba94aa10cb..05f6f1c0e6 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -139,6 +139,7 @@ class CompilationSettings: "engine_capability", "hardware_compatible", "refit_identical_engine_weights", + "strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default? "immutable_weights", "enable_weight_streaming", ) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index e15ed0495f..fa808aa20b 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -111,6 +111,10 @@ def _pretraced_backend( logger.warning( "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt" ) + if settings.strip_engine_weights: + logger.warning( + "strip_engine_weights arg is not supported for torch.compile()" + ) trt_compiled = compile_module( gm, torchtrt_inputs, diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 1a8ffae7fe..02c0407ee7 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -547,7 +547,7 @@ def _save_weight_mapping(self) -> None: torch.cuda.empty_cache() def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None: - # TODO: Waiting for TRT's feature to cache the weight-stripped engine + # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine # if not self.compilation_settings.strip_engine_weights: # # set EXCLUDE_WEIGHTS flag to strip weights # runtime = trt.Runtime(TRT_LOGGER) @@ -624,7 +624,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: ) serialized_engine = engine.serialize() - # TODO: Waiting for TRT's feature to load the weight-stripped engine + # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine # # EXCLUDE_WEIGHTS flag must be cleared # serialization_config = engine.create_serialization_config() # serialization_config.clear_flag( diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 630bf10f4e..39d46267e4 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -94,13 +94,13 @@ def test_three_ways_to_compile_weight_stripped_engine(self): ) gm2_output = gm2(*example_inputs) - # 3. Compile with torch.compile using tensorrt backend - gm3 = torch.compile( - pyt_model, - backend="tensorrt", - options=settings, - ) - gm3_output = gm3(*example_inputs) + # 3. Compile with torch.compile using tensorrt backend, which is not supported to set strip_engine_weights=True + # gm3 = torch.compile( + # pyt_model, + # backend="tensorrt", + # options=settings, + # ) + # gm3_output = gm3(*example_inputs) assertions.assertEqual( gm1_output.sum(), 0, msg="gm1_output should be all zeros" @@ -110,10 +110,6 @@ def test_three_ways_to_compile_weight_stripped_engine(self): gm2_output.sum(), 0, msg="gm2_output should be all zeros" ) - assertions.assertEqual( - gm3_output.sum(), 0, msg="gm3_output should be all zeros" - ) - def test_weight_stripped_engine_sizes(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) From a8a34f6a56944ddd933db008836e81840c20a533 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 6 Nov 2024 15:28:24 -0800 Subject: [PATCH 19/52] increase ENGINE_CACHE_SIZE --- py/torch_tensorrt/dynamo/_defaults.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 3256429be3..ee29e95b72 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -37,7 +37,7 @@ CACHE_BUILT_ENGINES = False REUSE_CACHED_ENGINES = False ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache") -ENGINE_CACHE_SIZE = 1073741824 +ENGINE_CACHE_SIZE = 5368709120 # 5GB CUSTOM_ENGINE_CACHE = None USE_EXPLICIT_TYPING = False USE_FP32_ACC = False From 285bc90cfa38716cb341d15c5ec216fbbc1988fe Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 7 Nov 2024 13:10:24 -0800 Subject: [PATCH 20/52] skip some tests --- tests/py/dynamo/models/test_weight_stripped_engine.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 39d46267e4..fb76038f22 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -136,12 +136,12 @@ def test_weight_stripped_engine_sizes(self): ) assertions.assertTrue( len(bytes(weight_included_engine)) > len(bytes(weight_stripped_engine)), - msg=f"Weight-stripped engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, weight stripped engine size: {len(bytes(weight_stripped_engine))}", + msg=f"Weight-stripped engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, weight-stripped engine size: {len(bytes(weight_stripped_engine))}", ) assertions.assertTrue( - len(bytes(weight_stripped_engine)) + len(bytes(weight_included_engine)) > len(bytes(weight_stripped_refit_identical_engine)), - msg=f"Weight-stripped refit-identical engine size is not smaller than the weight-stripped engine size. Weight-stripped engine size: {len(bytes(weight_stripped_engine))}, weight-stripped refit-identical engine size: {len(bytes(weight_stripped_refit_identical_engine))}", + msg=f"Weight-stripped refit-identical engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, weight-stripped refit-identical engine size: {len(bytes(weight_stripped_refit_identical_engine))}", ) def test_weight_stripped_engine_results(self): @@ -200,6 +200,9 @@ def test_weight_stripped_engine_results(self): msg=f"refitted_output doesn't match with compiled_model_output. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + @unittest.skip( + "For now, torch-trt will save weighted engine if strip_engine_weights is False. In the near future, we plan to save weight-stripped engine regardless of strip_engine_weights, which is pending on TRT's feature development: NVBug #4914602" + ) def test_engine_caching_saves_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) From 2d152cf978fce2ec240fa7f0e2289cf9a92e18bb Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 7 Nov 2024 14:45:48 -0800 Subject: [PATCH 21/52] fix tests --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 7edb50fdb4..27b501f6e1 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -277,7 +277,7 @@ def embedding_bag_validator( node: Node, settings: Optional[CompilationSettings] = None ) -> bool: # Embedding bag op is not refitable - if not settings or not settings.immutable_weights: + if settings and not settings.immutable_weights: return False if not one_user_validator(node): @@ -944,7 +944,7 @@ def aten_ops_slice( def refit_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: # cumsum op is not refitable - if not settings or not settings.immutable_weights: + if settings and not settings.immutable_weights: return False return True From d4616082a79a8174cac8f6bc79963d66c55c35ff Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 8 Nov 2024 11:23:58 -0800 Subject: [PATCH 22/52] try fixing cumsum --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 27b501f6e1..693e90b0e8 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -12,6 +12,7 @@ from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + ConverterPriority, dynamo_tensorrt_converter, has_static_shapes_in_args, ) @@ -953,6 +954,7 @@ def refit_validator(node: Node, settings: Optional[CompilationSettings] = None) torch.ops.aten.cumsum.default, capability_validator=refit_validator, supports_dynamic_shapes=True, + priority=ConverterPriority.HIGH, ) @enforce_tensor_types( { From 23d68d5122951419ca43d865e8aabb1a222f7eaa Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 8 Nov 2024 14:12:09 -0800 Subject: [PATCH 23/52] fix windows cross compile, TODO: whether windows support stripping engine? --- py/torch_tensorrt/dynamo/_compiler.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 730d47a254..007f5632a3 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -63,7 +63,6 @@ def cross_compile_for_windows( Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - make_refittable: bool = _defaults.MAKE_REFITTABLE, debug: bool = _defaults.DEBUG, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, @@ -93,6 +92,7 @@ def cross_compile_for_windows( custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE, use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, use_fp32_acc: bool = _defaults.USE_FP32_ACC, + immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING, **kwargs: Any, ) -> torch.fx.GraphModule: @@ -132,7 +132,6 @@ def cross_compile_for_windows( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - refit (bool): Enable refitting debug (bool): Enable debuggable engine capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels @@ -164,6 +163,7 @@ def cross_compile_for_windows( custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored. use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. enable_weight_streaming (bool): Enable weight streaming. **kwargs: Any, Returns: @@ -193,14 +193,17 @@ def cross_compile_for_windows( if "refit" in kwargs.keys(): warnings.warn( - "Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.", + "`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + DeprecationWarning, + stacklevel=2, + ) + + if "make_refittable" in kwargs.keys(): + warnings.warn( + "`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) - if make_refittable: - raise ValueError("Use flag make_refittable only. Flag refit is deprecated.") - else: - make_refittable = kwargs["refit"] engine_capability = EngineCapability._from(engine_capability) @@ -275,7 +278,6 @@ def cross_compile_for_windows( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "make_refittable": make_refittable, "engine_capability": engine_capability, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, @@ -286,6 +288,7 @@ def cross_compile_for_windows( "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, + "immutable_weights": immutable_weights, "enable_cross_compile_for_windows": True, "enable_weight_streaming": enable_weight_streaming, } From a928f673e6a8744437aec245c4be314369bb324f Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 13 Nov 2024 15:33:40 -0800 Subject: [PATCH 24/52] CI debug test 1 --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 693e90b0e8..4eec5ee74d 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -946,6 +946,9 @@ def aten_ops_slice( def refit_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: # cumsum op is not refitable if settings and not settings.immutable_weights: + print("TEST::: cumsum op is not mutable") + print("settings.immutable_weights:", settings.immutable_weights) + print("settings:", settings) return False return True From 02625ca0eb28154fd792e74e7365620b7f1e83e8 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 14 Nov 2024 10:47:42 -0800 Subject: [PATCH 25/52] CI debug test 2 --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 4eec5ee74d..05c514e8a6 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -12,7 +12,6 @@ from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( - ConverterPriority, dynamo_tensorrt_converter, has_static_shapes_in_args, ) @@ -955,9 +954,8 @@ def refit_validator(node: Node, settings: Optional[CompilationSettings] = None) @dynamo_tensorrt_converter( torch.ops.aten.cumsum.default, - capability_validator=refit_validator, + # capability_validator=refit_validator, supports_dynamic_shapes=True, - priority=ConverterPriority.HIGH, ) @enforce_tensor_types( { From c462e40adba1de7801db402803e15c9fec515550 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 15 Nov 2024 17:53:10 -0800 Subject: [PATCH 26/52] CI debug test 3 --- .../dynamo/conversion/aten_ops_converters.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 05c514e8a6..acc1b94b98 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -945,16 +945,13 @@ def aten_ops_slice( def refit_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: # cumsum op is not refitable if settings and not settings.immutable_weights: - print("TEST::: cumsum op is not mutable") - print("settings.immutable_weights:", settings.immutable_weights) - print("settings:", settings) return False return True @dynamo_tensorrt_converter( torch.ops.aten.cumsum.default, - # capability_validator=refit_validator, + capability_validator=refit_validator, supports_dynamic_shapes=True, ) @enforce_tensor_types( @@ -1018,7 +1015,6 @@ def zero_output_validator( torch.ops.aten.as_strided.default, capability_validator=zero_output_validator, ) -@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default) def aten_ops_as_strided( ctx: ConversionContext, target: Target, @@ -2066,7 +2062,6 @@ def aten_ops_div( @dynamo_tensorrt_converter( torch.ops.aten.pow.Tensor_Scalar, supports_dynamic_shapes=True ) -@dynamo_tensorrt_converter(operator.pow, supports_dynamic_shapes=True) def aten_ops_pow( ctx: ConversionContext, target: Target, @@ -3336,7 +3331,6 @@ def aten_ops_copy( @dynamo_tensorrt_converter( torch.ops.aten.remainder.Tensor, supports_dynamic_shapes=True ) -@dynamo_tensorrt_converter(operator.mod, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), From 3d68039a30aa1a72d77f61dad252a2325db3219c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 20 Nov 2024 06:53:49 +0000 Subject: [PATCH 27/52] reduce -n to 4 for converter tests on CI --- .github/workflows/build-test-tensorrt-linux.yml | 4 ++-- .github/workflows/build-test-tensorrt-windows.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build-test-tensorrt-linux.yml b/.github/workflows/build-test-tensorrt-linux.yml index 3f4abb9add..998daee824 100644 --- a/.github/workflows/build-test-tensorrt-linux.yml +++ b/.github/workflows/build-test-tensorrt-linux.yml @@ -129,7 +129,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ popd tests-py-dynamo-fe: @@ -314,4 +314,4 @@ jobs: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }} - cancel-in-progress: true \ No newline at end of file + cancel-in-progress: true diff --git a/.github/workflows/build-test-tensorrt-windows.yml b/.github/workflows/build-test-tensorrt-windows.yml index b6eb1d765c..b8cb3040e1 100644 --- a/.github/workflows/build-test-tensorrt-windows.yml +++ b/.github/workflows/build-test-tensorrt-windows.yml @@ -132,7 +132,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ popd tests-py-dynamo-fe: @@ -298,4 +298,4 @@ jobs: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }} - cancel-in-progress: true \ No newline at end of file + cancel-in-progress: true From 2e7ef3b631337f48053afaf3eeec7241bc618ae5 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 20 Nov 2024 13:24:07 -0800 Subject: [PATCH 28/52] reduce -n to 4 for converter tests on CI --- .github/workflows/build-test-linux.yml | 2 +- .github/workflows/build-test-windows.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index a660fc4ef2..277b1ef793 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -137,7 +137,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ popd tests-py-dynamo-fe: diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index 0201ab5ff2..71a1f96f01 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -119,7 +119,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ popd tests-py-dynamo-fe: From 9ff165c2528fdfb7728b99b3a174796365382e6c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 21 Nov 2024 16:22:31 -0800 Subject: [PATCH 29/52] simplify test_different_args_dont_share_cached_engine --- .../dynamo/models/test_weight_stripped_engine.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index fb76038f22..e647d623b5 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -411,13 +411,24 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): ) def test_different_args_dont_share_cached_engine(self): - pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 4, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + pyt_model = MyModel().eval().to("cuda") engine_cache_dir = "/tmp/test_different_args_dont_share_cached_engine" if os.path.exists(engine_cache_dir): shutil.rmtree(engine_cache_dir) - inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + inputs = [torch.rand((4, 3, 32, 32)).to("cuda")] for i in range(2): if i == 0: From 8ca8e2dc17d21081bd6846a6594b06484b98e780 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 22 Nov 2024 02:14:01 +0000 Subject: [PATCH 30/52] reduce -n to 2 --- .github/workflows/build-test-linux.yml | 2 +- .github/workflows/build-test-windows.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index 277b1ef793..28ae8e10d6 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -137,7 +137,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 2 conversion/ popd tests-py-dynamo-fe: diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index 71a1f96f01..4e352608ad 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -119,7 +119,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 2 conversion/ popd tests-py-dynamo-fe: From f9f2a7038f716be16ca61c48df140b7d5c4b97be Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 22 Nov 2024 05:55:26 +0000 Subject: [PATCH 31/52] reduce -n to 1 --- .github/workflows/build-test-linux.yml | 2 +- .github/workflows/build-test-windows.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index 28ae8e10d6..7942681608 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -137,7 +137,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 2 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 1 conversion/ popd tests-py-dynamo-fe: diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index 4e352608ad..d0d50cbe75 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -119,7 +119,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 2 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 1 conversion/ popd tests-py-dynamo-fe: From c69c61adb914fd12bd8bd542e2cc0f0835f666ce Mon Sep 17 00:00:00 2001 From: Evan Li Date: Sat, 23 Nov 2024 00:42:27 +0000 Subject: [PATCH 32/52] revert -n back to 4 and chunk converter --- .github/workflows/build-test-linux.yml | 2 +- .github/workflows/build-test-windows.yml | 2 +- .../dynamo/conversion/aten_ops_converters.py | 24 ++++++++ .../dynamo/conversion/impl/slice/ops.py | 55 +++++++++++++++++++ 4 files changed, 81 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index 7942681608..277b1ef793 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -137,7 +137,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 1 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ popd tests-py-dynamo-fe: diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index d0d50cbe75..71a1f96f01 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -119,7 +119,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 1 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ popd tests-py-dynamo-fe: diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index acc1b94b98..f76bea94b9 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -3651,3 +3651,27 @@ def aten_ops_full( fill_value=args[1], dtype=kwargs.get("dtype", None), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.chunk.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_chunk( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.slice.chunk( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args_bounds_check(args, 2, 0), + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index b58435b489..04eab08c47 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -324,6 +324,61 @@ def expand( return layer.get_output(0) +def chunk( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + chunks: int, + dim: int, +) -> TRTTensor: + if chunks <= 0: + raise RuntimeError( + f"chunk expects `chunks` to be greater than 0, got: {chunks}" + ) + + shape = input.shape + dim = get_positive_dim(dim, len(shape)) + + if dim >= len(shape): + raise RuntimeError( + f"chunk expects `dim` to be less than the length of input shape, got: {dim}" + ) + + dynamic_shape = has_dynamic_shape(input.shape) + if dynamic_shape > 0: + # Check whether slice target dim is dynamic shape dim + assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" + + size_dim = shape[dim] + chunk_size = math.ceil(size_dim / chunks) + result = [] + start = 0 + end = min(start + chunk_size, size_dim) + cnt = 0 + + while start < end: + result.append( + slice_op( + ctx, + target, + source_ir, + f"{name}_slice_{cnt}", + input, + dim, + start, + end, + 1, + ) + ) + start = end + end = min(start + chunk_size, size_dim) + cnt += 1 + + return result + + def cumsum( ctx: ConversionContext, target: Target, From 05b560d9a48f16dfd6aac4fbbf232798b9b8b307 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 28 Nov 2024 00:07:19 +0000 Subject: [PATCH 33/52] change to opt-in feature --- .../dynamo/engine_caching_bert_example.py | 1 + examples/dynamo/engine_caching_example.py | 5 + .../dynamo/mutable_torchtrt_module_example.py | 2 + examples/dynamo/refit_engine_example.py | 7 +- py/torch_tensorrt/dynamo/_compiler.py | 101 ++++++++++++++++-- py/torch_tensorrt/dynamo/_defaults.py | 2 +- py/torch_tensorrt/dynamo/_refit.py | 4 + py/torch_tensorrt/dynamo/backend/backends.py | 2 +- .../dynamo/conversion/_TRTInterpreter.py | 13 +-- .../runtime/_MutableTorchTensorRTModule.py | 7 +- .../py/dynamo/conversion/test_cumsum_aten.py | 4 + .../conversion/test_embedding_bag_aten.py | 4 + tests/py/dynamo/models/test_engine_cache.py | 19 +++- tests/py/dynamo/models/test_model_refit.py | 14 +++ .../models/test_weight_stripped_engine.py | 14 +++ .../runtime/test_mutable_torchtrt_module.py | 8 ++ 16 files changed, 187 insertions(+), 20 deletions(-) diff --git a/examples/dynamo/engine_caching_bert_example.py b/examples/dynamo/engine_caching_bert_example.py index 9cddefd509..1148d4f792 100644 --- a/examples/dynamo/engine_caching_bert_example.py +++ b/examples/dynamo/engine_caching_bert_example.py @@ -52,6 +52,7 @@ def compile_bert(iterations=3): "truncate_double": True, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache", diff --git a/examples/dynamo/engine_caching_example.py b/examples/dynamo/engine_caching_example.py index 20388e9372..fb4c341077 100644 --- a/examples/dynamo/engine_caching_example.py +++ b/examples/dynamo/engine_caching_example.py @@ -62,6 +62,8 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): # engines are saved to disk tied to a hash of their corresponding PyTorch subgraph. If # in a subsequent compilation, either as part of this session or a new session, the cache will # pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude. +# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``), +# the engine must be refittable (``immutable_weights=False``). See :ref:`refit_engine_example` for more details. def torch_compile(iterations=3): @@ -95,6 +97,7 @@ def torch_compile(iterations=3): "enabled_precisions": enabled_precisions, "debug": debug, "min_block_size": min_block_size, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, }, @@ -154,6 +157,7 @@ def dynamo_compile(iterations=3): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_size=1 << 30, # 1GB @@ -264,6 +268,7 @@ def torch_compile_my_cache(iterations=3): "enabled_precisions": enabled_precisions, "debug": debug, "min_block_size": min_block_size, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "custom_engine_cache": engine_cache, diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py index 3ea9fab9a5..8b62855c32 100644 --- a/examples/dynamo/mutable_torchtrt_module_example.py +++ b/examples/dynamo/mutable_torchtrt_module_example.py @@ -31,6 +31,7 @@ settings = { "use_python": False, "enabled_precisions": {torch.float32}, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -79,6 +80,7 @@ "use_python_runtime": True, "enabled_precisions": {torch.float16}, "debug": True, + "immutable_weights": False, } model_id = "runwayml/stable-diffusion-v1-5" diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py index 44f78abbc0..66a1a70964 100644 --- a/examples/dynamo/refit_engine_example.py +++ b/examples/dynamo/refit_engine_example.py @@ -46,7 +46,10 @@ # Make a refittable Compilation Program # --------------------------------------- # -# The inital step is to compile a module and save it as with a normal. +# The inital step is to compile a module and save it as with a normal. Note that there is an +# additional parameter `immutable_weights` that is set to `False`. This parameter is used to +# indicate that the engine being built should support weight refitting later. Engines built without +# these setttings will not be able to be refit. # # In this case we are going to compile a ResNet18 model with randomly initialized weights and save it. @@ -66,6 +69,8 @@ debug=debug, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, + immutable_weights=False, + reuse_cached_engines=False, ) # Output is a torch.fx.GraphModule # Save the graph module as an exported program diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 007f5632a3..93fbd675ec 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -92,6 +92,8 @@ def cross_compile_for_windows( custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE, use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, use_fp32_acc: bool = _defaults.USE_FP32_ACC, + refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS, + strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS, immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING, **kwargs: Any, @@ -163,6 +165,8 @@ def cross_compile_for_windows( custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored. use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. + refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. + strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required. immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. enable_weight_streaming (bool): Enable weight streaming. **kwargs: Any, @@ -193,17 +197,44 @@ def cross_compile_for_windows( if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["refit"] if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["make_refittable"] + + if refit_identical_engine_weights: + if immutable_weights: + raise ValueError( + "`immutable_weights` must be False when `refit_identical_engine_weights` is True." + ) + + if ( + not immutable_weights + and not refit_identical_engine_weights + and enable_weight_streaming + ): + raise ValueError( + "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" + ) engine_capability = EngineCapability._from(engine_capability) @@ -288,6 +319,8 @@ def cross_compile_for_windows( "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, + "refit_identical_engine_weights": refit_identical_engine_weights, + "strip_engine_weights": strip_engine_weights, "immutable_weights": immutable_weights, "enable_cross_compile_for_windows": True, "enable_weight_streaming": enable_weight_streaming, @@ -475,17 +508,44 @@ def compile( if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["refit"] if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["make_refittable"] + + if refit_identical_engine_weights: + if immutable_weights: + raise ValueError( + "`immutable_weights` must be False when `refit_identical_engine_weights` is True." + ) + + if ( + not immutable_weights + and not refit_identical_engine_weights + and enable_weight_streaming + ): + raise ValueError( + "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" + ) if ( "enable_cross_compile_for_windows" in kwargs.keys() @@ -965,18 +1025,47 @@ def convert_exported_program_to_serialized_trt_engine( DeprecationWarning, stacklevel=2, ) + if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["refit"] + if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["make_refittable"] + + if refit_identical_engine_weights: + if immutable_weights: + raise ValueError( + "`immutable_weights` must be False when `refit_identical_engine_weights` is True." + ) + + if ( + not immutable_weights + and not refit_identical_engine_weights + and enable_weight_streaming + ): + raise ValueError( + "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" + ) if arg_inputs is None and inputs is None: raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 1341ca739f..76630a75a5 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -43,7 +43,7 @@ USE_FP32_ACC = False REFIT_IDENTICAL_ENGINE_WEIGHTS = False STRIP_ENGINE_WEIGHTS = False -IMMUTABLE_WEIGHTS = False +IMMUTABLE_WEIGHTS = True ENABLE_WEIGHT_STREAMING = False ENABLE_CROSS_COMPILE_FOR_WINDOWS = False diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index ca379a9ada..98e6b627ab 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -285,6 +285,10 @@ def refit_module_weights( assert settings is not None + assert ( + not settings.immutable_weights + ), "Refitting is not enabled. Please recompile the engine with immutable_weights=False." + if settings.debug: set_log_level(logger.parent, logging.DEBUG) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index fa808aa20b..c8a30e656b 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -112,7 +112,7 @@ def _pretraced_backend( "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt" ) if settings.strip_engine_weights: - logger.warning( + logger.error( "strip_engine_weights arg is not supported for torch.compile()" ) trt_compiled = compile_module( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 7ffc02ca3d..d7c0ea449e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -291,6 +291,8 @@ def _populate_trt_builder_config( # non-refittable engine if self.compilation_settings.strip_engine_weights: _LOGGER.warning("strip_engine_weights will be ignored.") + if self.compilation_settings.refit_identical_engine_weights: + _LOGGER.warning("refit_identical_engine_weights will be ignored.") else: # refittable engine if self.compilation_settings.refit_identical_engine_weights: @@ -496,16 +498,15 @@ def _save_weight_mapping(self) -> None: suffix = sd_weight_name_list[-1] # Retrieve each weight name(s) in state_dict if layer_type == "CONSTANT": - if "embedding" in suffix: - sd_weight_name = f"{sd_weight_name}.weight" - elif "weight" in suffix or "mm_other" in suffix: - # Linear layer weight + if ( + "embedding" in suffix + or "weight" in suffix + or "mm_other" in suffix + ): sd_weight_name = f"{sd_weight_name}.weight" elif "running_mean" in suffix: - # Linear layer weight sd_weight_name = f"{sd_weight_name}.running_mean" elif "running_var" in suffix: - # Linear layer weight sd_weight_name = f"{sd_weight_name}.running_var" elif "bias" in suffix: sd_weight_name = f"{sd_weight_name}.bias" diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index f51707768e..134d84cf6d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -65,6 +65,7 @@ def __init__( Union[torch.dtype, dtype] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, debug: bool = _defaults.DEBUG, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, @@ -102,7 +103,7 @@ def __init__( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - refit (bool): Enable refitting + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. debug (bool): Enable debuggable engine capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels @@ -150,6 +151,9 @@ def __init__( self.kwarg_inputs: dict[str, Any] = {} device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} + assert ( + not immutable_weights + ), "`immutable_weights` has to be False for a MutableTorchTensorRTModule." compilation_options = { "enabled_precisions": ( enabled_precisions @@ -176,6 +180,7 @@ def __init__( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, + "immutable_weights": immutable_weights, "engine_capability": engine_capability, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, diff --git a/tests/py/dynamo/conversion/test_cumsum_aten.py b/tests/py/dynamo/conversion/test_cumsum_aten.py index 4143401bd4..8ab699468d 100644 --- a/tests/py/dynamo/conversion/test_cumsum_aten.py +++ b/tests/py/dynamo/conversion/test_cumsum_aten.py @@ -24,6 +24,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, + immutable_weights=True, ) @parameterized.expand( @@ -43,6 +44,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, + immutable_weights=True, ) @parameterized.expand( @@ -63,6 +65,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, + immutable_weights=True, ) @parameterized.expand( @@ -92,6 +95,7 @@ def forward(self, x): self.run_test_with_dynamic_shape( Cumsum(), inputs, + immutable_weights=True, ) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 3fef3d70cf..1f119bd77e 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -148,6 +148,7 @@ def forward(self, weight, indices): precision=weight.dtype, enable_passes=True, propagate_shapes=True, + immutable_weights=True, ) @parameterized.expand( @@ -345,6 +346,7 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, + immutable_weights=True, ) @parameterized.expand( @@ -409,6 +411,7 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, + immutable_weights=True, ) @parameterized.expand( @@ -490,6 +493,7 @@ def forward(self, weights, indices, offsets, per_sample_weights=None): min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, + immutable_weights=True, ) # use the inputs with different shape to inference: if per_sample_weights is None: diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index 5044654d81..68451674c5 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -74,7 +74,7 @@ def test_reexport_is_equal(self): ), ) settings1 = CompilationSettings( - cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1) @@ -89,7 +89,7 @@ def test_reexport_is_equal(self): ), ) settings2 = CompilationSettings( - cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2) @@ -111,7 +111,7 @@ def test_input_shape_change_is_not_equal(self): ), ) settings1 = CompilationSettings( - cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1) @@ -126,7 +126,7 @@ def test_input_shape_change_is_not_equal(self): ), ) settings2 = CompilationSettings( - cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2) @@ -148,6 +148,7 @@ def test_engine_settings_is_not_equal(self): ), ) settings1 = CompilationSettings( + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True, enabled_precisions={torch.float32}, @@ -165,6 +166,7 @@ def test_engine_settings_is_not_equal(self): ), ) settings2 = CompilationSettings( + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True, enabled_precisions={torch.float32, torch.float16}, @@ -223,6 +225,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_dir=engine_cache_dir, @@ -286,6 +289,7 @@ def test_dynamo_compile_with_custom_engine_cache(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, custom_engine_cache=custom_engine_cache, @@ -332,6 +336,7 @@ def test_dynamo_compile_change_input_shape(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True, ) @@ -386,6 +391,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": engine_cache_dir, @@ -449,6 +455,7 @@ def test_torch_compile_with_custom_engine_cache(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "custom_engine_cache": custom_engine_cache, @@ -498,6 +505,7 @@ def test_torch_trt_compile_change_input_shape(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": True, "reuse_cached_engines": True, "custom_engine_cache": custom_engine_cache, @@ -540,6 +548,7 @@ def forward(self, x): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": True, "reuse_cached_engines": True, "custom_engine_cache": custom_engine_cache, @@ -628,6 +637,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, cache_built_engines=False, reuse_cached_engines=False, strip_engine_weights=False, @@ -858,6 +868,7 @@ def remove_timing_cache(path=timing_cache_path): enabled_precisions={torch.float32}, debug=False, min_block_size=1, + immutable_weights=False, truncate_double=True, device=DEVICE, disable_tf32=True, diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 331db1d4fd..bb61ac2d43 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -55,6 +55,7 @@ def test_mapping(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) settings = trt_gm._run_on_acc_0.settings runtime = trt.Runtime(TRT_LOGGER) @@ -106,6 +107,7 @@ def test_refit_one_engine_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -155,6 +157,7 @@ def test_refit_one_engine_no_map_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) trt_gm._run_on_acc_0.weight_name_map = None @@ -205,6 +208,7 @@ def test_refit_one_engine_with_wrong_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) # Manually Deleted all batch norm layer. This suppose to fail the fast refit trt_gm._run_on_acc_0.weight_name_map = { @@ -261,6 +265,7 @@ def test_refit_one_engine_bert_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -313,6 +318,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) torchtrt.save(trt_gm, trt_ep_path) trt_gm = torch.export.load(trt_ep_path) @@ -358,6 +364,7 @@ def test_refit_one_engine_python_runtime_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -427,6 +434,7 @@ def forward(self, x): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, torch_executed_ops=torch_executed_ops, reuse_cached_engines=False, ) @@ -477,6 +485,7 @@ def test_refit_one_engine_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -527,6 +536,7 @@ def test_refit_one_engine_bert_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -579,6 +589,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) torchtrt.save(trt_gm, trt_ep_path) trt_gm = torch.export.load(trt_ep_path) @@ -624,6 +635,7 @@ def test_refit_one_engine_python_runtime_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -693,6 +705,7 @@ def forward(self, x): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, torch_executed_ops=torch_executed_ops, reuse_cached_engines=False, ) @@ -746,6 +759,7 @@ def forward(self, x): enabled_precisions={torch.float}, debug=True, min_block_size=1, + immutable_weights=False, ) num_pyt_segments = len( diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index e647d623b5..67cfd167ed 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -26,6 +26,7 @@ def test_three_ways_to_compile(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "strip_engine_weights": False, "refit_identical_engine_weights": False, } @@ -76,6 +77,7 @@ def test_three_ways_to_compile_weight_stripped_engine(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "strip_engine_weights": True, "refit_identical_engine_weights": False, } @@ -117,12 +119,14 @@ def test_weight_stripped_engine_sizes(self): weight_included_engine = convert_exported_program_to_serialized_trt_engine( exp_program, example_inputs, + immutable_weights=False, strip_engine_weights=False, refit_identical_engine_weights=False, ) weight_stripped_engine = convert_exported_program_to_serialized_trt_engine( exp_program, example_inputs, + immutable_weights=False, strip_engine_weights=True, refit_identical_engine_weights=False, ) @@ -130,6 +134,7 @@ def test_weight_stripped_engine_sizes(self): convert_exported_program_to_serialized_trt_engine( exp_program, example_inputs, + immutable_weights=False, strip_engine_weights=True, refit_identical_engine_weights=True, ) @@ -162,6 +167,7 @@ def test_weight_stripped_engine_results(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, strip_engine_weights=True, refit_identical_engine_weights=False, ) @@ -187,6 +193,7 @@ def test_weight_stripped_engine_results(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": False, "reuse_cached_engines": False, "refit_identical_engine_weights": False, @@ -226,6 +233,7 @@ def test_engine_caching_saves_weight_stripped_engine(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, strip_engine_weights=False, # engine cache will save the stripped engine even if this is False refit_identical_engine_weights=True, cache_built_engines=True, @@ -291,6 +299,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_dir=engine_cache_dir, @@ -371,6 +380,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": engine_cache_dir, @@ -444,6 +454,7 @@ def forward(self, x): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": True, "reuse_cached_engines": True, "engine_cache_dir": engine_cache_dir, @@ -478,6 +489,7 @@ def forward(self, x): ir="dynamo", inputs=tuple(inputs), min_block_size=1, + immutable_weights=False, use_python_runtime=True, strip_engine_weights=True, refit_identical_engine_weights=False, @@ -517,6 +529,7 @@ def test_two_TRTRuntime_in_refitting(self): use_python_runtime=use_python_runtime, debug=False, min_block_size=1, + immutable_weights=False, strip_engine_weights=True, refit_identical_engine_weights=False, ) @@ -549,6 +562,7 @@ def test_refit_identical_engine_weights(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, strip_engine_weights=True, refit_identical_engine_weights=True, ) diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index fd9fa4e1e0..f2bcaf7ede 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -49,6 +49,7 @@ def test_resnet18(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -88,6 +89,7 @@ def test_save(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -121,6 +123,7 @@ def test_resnet18_modify_attribute(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -161,6 +164,7 @@ def test_resnet18_modify_attribute_no_refit(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -239,6 +243,7 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -299,6 +304,7 @@ def set_weights(self): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -361,6 +367,7 @@ def set_layer(self): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -429,6 +436,7 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) From 7feea9706b77e81368dbfd71ee8b35876fcacd5c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 28 Nov 2024 00:11:40 +0000 Subject: [PATCH 34/52] fix conflict --- tests/py/dynamo/models/test_weight_stripped_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 67cfd167ed..6f9d10c505 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -234,7 +234,7 @@ def test_engine_caching_saves_weight_stripped_engine(self): debug=False, min_block_size=1, immutable_weights=False, - strip_engine_weights=False, # engine cache will save the stripped engine even if this is False + strip_engine_weights=False, refit_identical_engine_weights=True, cache_built_engines=True, reuse_cached_engines=True, From d1521c332ff165ce7991fc7724166b3fe0756f28 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 28 Nov 2024 00:15:08 +0000 Subject: [PATCH 35/52] fix typo --- py/torch_tensorrt/dynamo/_compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 93fbd675ec..48285e363d 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -216,7 +216,7 @@ def cross_compile_for_windows( ) if immutable_weights: raise ValueError( - "Use flag `immutable_weights` only. Flag `refit` is deprecated." + "Use flag `immutable_weights` only. Flag `make_refittable` is deprecated." ) else: immutable_weights = not kwargs["make_refittable"] @@ -527,7 +527,7 @@ def compile( ) if immutable_weights: raise ValueError( - "Use flag `immutable_weights` only. Flag `refit` is deprecated." + "Use flag `immutable_weights` only. Flag `make_refittable` is deprecated." ) else: immutable_weights = not kwargs["make_refittable"] @@ -1047,7 +1047,7 @@ def convert_exported_program_to_serialized_trt_engine( ) if immutable_weights: raise ValueError( - "Use flag `immutable_weights` only. Flag `refit` is deprecated." + "Use flag `immutable_weights` only. Flag `make_refittable` is deprecated." ) else: immutable_weights = not kwargs["make_refittable"] From 0b345be6d824730f0704433e40960142ced0b748 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 3 Dec 2024 11:02:32 -0800 Subject: [PATCH 36/52] small fix --- py/torch_tensorrt/dynamo/_refit.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 98e6b627ab..f1041682f8 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -177,10 +177,6 @@ def _refit_single_trt_engine_with_gm( mapping.update(constant_mapping_with_type) - # Debug Use - # correct = construct_refit_mapping(new_gm, input_list, settings) - # comparison = {k: (np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2), correct[k][0], mapping[k][0]) for k in mapping if k in correct} - for layer_name in weight_list: if layer_name not in mapping: logger.warning(f"{layer_name} is not found in weight mapping.") From 4a7e95794de452168f4c2ce27269b3c549117131 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 10:45:39 -0800 Subject: [PATCH 37/52] update to manylinux2_28-builder --- .github/scripts/generate_binary_build_matrix.py | 10 +++++----- py/ci/Dockerfile.ci | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 4ba7e0faeb..26bb447b4f 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -152,10 +152,10 @@ def initialize_globals(channel: str, build_python_only: bool) -> None: "12.4": "pytorch/manylinux2_28-builder:cuda12.4", "12.6": "pytorch/manylinux2_28-builder:cuda12.6", **{ - gpu_arch: f"pytorch/manylinux-builder:rocm{gpu_arch}" + gpu_arch: f"pytorch/manylinux2_28-builder:rocm{gpu_arch}" for gpu_arch in ROCM_ARCHES }, - CPU: "pytorch/manylinux-builder:cpu", + CPU: "pytorch/manylinux2_28-builder:cpu", XPU: "pytorch/manylinux2_28-builder:xpu", # TODO: Migrate CUDA_AARCH64 image to manylinux2_28_aarch64-builder:cuda12.4 CPU_AARCH64: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64", @@ -163,7 +163,7 @@ def initialize_globals(channel: str, build_python_only: bool) -> None: } LIBTORCH_CONTAINER_IMAGES = { **{ - (gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux-builder:cuda{gpu_arch}" + (gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux2_28-builder:cuda{gpu_arch}" for gpu_arch in CUDA_ARCHES }, **{ @@ -171,14 +171,14 @@ def initialize_globals(channel: str, build_python_only: bool) -> None: for gpu_arch in CUDA_ARCHES }, **{ - (gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux-builder:rocm{gpu_arch}" + (gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux2_28-builder:rocm{gpu_arch}" for gpu_arch in ROCM_ARCHES }, **{ (gpu_arch, CXX11_ABI): f"pytorch/libtorch-cxx11-builder:rocm{gpu_arch}" for gpu_arch in ROCM_ARCHES }, - (CPU, PRE_CXX11_ABI): "pytorch/manylinux-builder:cpu", + (CPU, PRE_CXX11_ABI): "pytorch/manylinux2_28-builder:cpu", (CPU, CXX11_ABI): "pytorch/libtorch-cxx11-builder:cpu", } diff --git a/py/ci/Dockerfile.ci b/py/ci/Dockerfile.ci index eddf12cefb..2a690ce2ea 100644 --- a/py/ci/Dockerfile.ci +++ b/py/ci/Dockerfile.ci @@ -1,4 +1,4 @@ -FROM pytorch/manylinux-builder:cuda12.4 +FROM pytorch/manylinux2_28-builder:cuda12.4 RUN yum install -y ninja-build From 6e840ba59709e33fa16429afcce20700a2829a07 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 11:03:48 -0800 Subject: [PATCH 38/52] remove cuda12.6 tests --- .github/scripts/generate_binary_build_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 26bb447b4f..44856f6647 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -24,7 +24,7 @@ "release": ["3.9", "3.10", "3.11", "3.12"], } CUDA_ARCHES_DICT = { - "nightly": ["11.8", "12.4", "12.6"], + "nightly": ["11.8", "12.4"], "test": ["11.8", "12.1", "12.4"], "release": ["11.8", "12.1", "12.4"], } From 9a8473a05af42f0378a528f5b09b36bc2bd24e59 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 12:42:18 -0800 Subject: [PATCH 39/52] remove one_user_validator for native_layer_norm --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1bd82968fb..94d886eea9 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -133,7 +133,6 @@ def aten_ops_batch_norm_legit_no_training( @dynamo_tensorrt_converter( torch.ops.aten.native_layer_norm.default, - capability_validator=one_user_validator, supports_dynamic_shapes=True, ) @enforce_tensor_types( From 6a077675113cacb1a1030101ff35ac8cfce677ed Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 13:13:12 -0800 Subject: [PATCH 40/52] clear tests --- .../models/test_weight_stripped_engine.py | 41 ++++--------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 6f9d10c505..0c79ba7a3f 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -39,19 +39,13 @@ def test_three_ways_to_compile(self): ) gm1_output = gm1(*example_inputs) - # 2. Compile with torch_trt.compile using dynamo backend - gm2 = torch_trt.compile( - pyt_model, ir="dynamo", inputs=example_inputs, **settings - ) - gm2_output = gm2(*example_inputs) - - # 3. Compile with torch.compile using tensorrt backend - gm3 = torch.compile( + # 2. Compile with torch.compile using tensorrt backend + gm2 = torch.compile( pyt_model, backend="tensorrt", options=settings, ) - gm3_output = gm3(*example_inputs) + gm2_output = gm2(*example_inputs) pyt_model_output = pyt_model(*example_inputs) @@ -63,14 +57,9 @@ def test_three_ways_to_compile(self): gm1_output, gm2_output, 1e-2, 1e-2 ), "gm2_output is not correct" - assert torch.allclose( - gm2_output, gm3_output, 1e-2, 1e-2 - ), "gm3_output is not correct" - def test_three_ways_to_compile_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) - exp_program = torch.export.export(pyt_model, example_inputs) settings = { "use_python_runtime": False, @@ -82,36 +71,24 @@ def test_three_ways_to_compile_weight_stripped_engine(self): "refit_identical_engine_weights": False, } - # 1. Compile with torch_trt.dynamo.compile - gm1 = torch_trt.dynamo.compile( - exp_program, - example_inputs, - **settings, - ) - gm1_output = gm1(*example_inputs) - - # 2. Compile with torch_trt.compile using dynamo backend - gm2 = torch_trt.compile( + # 1. Compile with torch_trt.compile using dynamo backend + gm1 = torch_trt.compile( pyt_model, ir="dynamo", inputs=example_inputs, **settings ) - gm2_output = gm2(*example_inputs) + gm1_output = gm1(*example_inputs) - # 3. Compile with torch.compile using tensorrt backend, which is not supported to set strip_engine_weights=True - # gm3 = torch.compile( + # 2. Compile with torch.compile using tensorrt backend, which is not supported to set strip_engine_weights=True + # gm2 = torch.compile( # pyt_model, # backend="tensorrt", # options=settings, # ) - # gm3_output = gm3(*example_inputs) + # gm2_output = gm2(*example_inputs) assertions.assertEqual( gm1_output.sum(), 0, msg="gm1_output should be all zeros" ) - assertions.assertEqual( - gm2_output.sum(), 0, msg="gm2_output should be all zeros" - ) - def test_weight_stripped_engine_sizes(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) From ed3424a4c91971d325ad67ffd3624effbd0a2383 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 14:48:52 -0800 Subject: [PATCH 41/52] remove the whole chunk --- .../dynamo/conversion/aten_ops_converters.py | 24 --- tests/py/dynamo/conversion/test_chunk_aten.py | 187 ------------------ 2 files changed, 211 deletions(-) delete mode 100644 tests/py/dynamo/conversion/test_chunk_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 94d886eea9..4d2f97de1c 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -3644,27 +3644,3 @@ def aten_ops_full( fill_value=args[1], dtype=kwargs.get("dtype", None), ) - - -@dynamo_tensorrt_converter(torch.ops.aten.chunk.default) -@enforce_tensor_types( - { - 0: (TRTTensor,), - } -) -def aten_ops_chunk( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.slice.chunk( - ctx, - target, - SourceIR.ATEN, - name, - args[0], - args[1], - args_bounds_check(args, 2, 0), - ) diff --git a/tests/py/dynamo/conversion/test_chunk_aten.py b/tests/py/dynamo/conversion/test_chunk_aten.py deleted file mode 100644 index eb06c04201..0000000000 --- a/tests/py/dynamo/conversion/test_chunk_aten.py +++ /dev/null @@ -1,187 +0,0 @@ -import unittest - -import torch -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input - -from .harness import DispatchTestCase - - -class TestChunkConverter(DispatchTestCase): - @parameterized.expand( - [ - ((1,), 3, 0), - ((3,), 3, 0), - ((4,), 3, 0), - ((6,), 3, 0), - ((3,), 1, -1), - ((3,), 3, -1), - ((3,), 4, -1), - ] - ) - def test_chunk_1D(self, shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input = [torch.randn(shape)] - self.run_test( - TestChunk(), - input, - use_dynamo_tracer=True, - ) - - @parameterized.expand( - [ - ((3, 4), 1, 0), - ((3, 4), 3, 0), - ((3, 4), 4, 0), - ((3, 4), 2, -2), - ((3, 4), 6, -2), - ((3, 4), 3, 1), - ((3, 4), 4, 1), - ((3, 4), 5, -1), - ] - ) - def test_chunk_2D(self, shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input = [torch.randn(shape)] - self.run_test( - TestChunk(), - input, - use_dynamo_tracer=True, - ) - - @parameterized.expand( - [ - ((3, 4, 2), 1, 0), - ((3, 4, 2), 3, -3), - ((3, 4, 2), 3, 1), - ((3, 4, 2), 4, 1), - ((3, 4, 2), 6, -2), - ((3, 4, 2), 1, 2), - ((3, 4, 2), 3, -1), - ((3, 4, 2), 4, -1), - ] - ) - def test_chunk_3D(self, shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input = [torch.randn(shape)] - self.run_test( - TestChunk(), - input, - use_dynamo_tracer=True, - ) - - -#######################Dynamic cases####################### -# The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed -@unittest.skip( - "Pending aten.split dynamic input torch.export guard bug. Issue- https://github.com/pytorch/pytorch/issues/134663" -) -class TestChunkDynamicConverter(DispatchTestCase): - @parameterized.expand( - [ - ((1,), (1,), (3,), 3, 0), - ((3,), (3,), (4,), 3, 0), - ((4,), (4,), (6,), 3, 0), - ((6,), (6,), (9,), 3, 0), - ((3,), (3,), (4,), 1, -1), - ((3,), (3,), (4,), 3, -1), - ((3,), (3,), (4,), 4, -1), - ] - ) - def test_chunk_1D(self, min_shape, opt_shape, max_shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input_specs = [ - Input( - min_shape=min_shape, - opt_shape=opt_shape, - max_shape=max_shape, - ), - ] - self.run_test_with_dynamic_shape( - TestChunk(), - input_specs, - use_dynamo_tracer=True, - ) - - @parameterized.expand( - [ - ((3, 4), (3, 4), (4, 4), 1, 0), - ((3, 4), (3, 4), (4, 4), 3, 0), - ((3, 4), (3, 4), (4, 4), 4, 0), - ((3, 4), (3, 4), (4, 4), 2, -2), - ((3, 4), (3, 4), (4, 4), 6, -2), - ((3, 4), (3, 4), (4, 4), 3, 1), - ((3, 4), (3, 4), (4, 4), 4, 1), - ((3, 4), (3, 4), (4, 4), 5, -1), - ] - ) - def test_chunk_2D(self, min_shape, opt_shape, max_shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input_specs = [ - Input( - min_shape=min_shape, - opt_shape=opt_shape, - max_shape=max_shape, - ), - ] - self.run_test_with_dynamic_shape( - TestChunk(), - input_specs, - use_dynamo_tracer=True, - ) - - @parameterized.expand( - [ - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 0), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -3), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, 1), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, 1), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 6, -2), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 2), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -1), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, -1), - ] - ) - def test_chunk_3D(self, min_shape, opt_shape, max_shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input_specs = [ - Input( - min_shape=min_shape, - opt_shape=opt_shape, - max_shape=max_shape, - ), - ] - self.run_test_with_dynamic_shape( - TestChunk(), - input_specs, - use_dynamo_tracer=True, - ) - - -if __name__ == "__main__": - run_tests() From ef54239226a287876d39c3b5ae239c23f7ffce7e Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 14:50:18 -0800 Subject: [PATCH 42/52] add cuda12.6 back and export D_GLIBCXX_USE_CXX11_ABI=1 --- .github/scripts/generate_binary_build_matrix.py | 2 +- packaging/env_vars.txt | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 44856f6647..26bb447b4f 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -24,7 +24,7 @@ "release": ["3.9", "3.10", "3.11", "3.12"], } CUDA_ARCHES_DICT = { - "nightly": ["11.8", "12.4"], + "nightly": ["11.8", "12.4", "12.6"], "test": ["11.8", "12.1", "12.4"], "release": ["11.8", "12.1", "12.4"], } diff --git a/packaging/env_vars.txt b/packaging/env_vars.txt index 46f906b1ff..844f712560 100644 --- a/packaging/env_vars.txt +++ b/packaging/env_vars.txt @@ -1,2 +1,3 @@ export CI_BUILD="1" -export RELEASE="1" \ No newline at end of file +export RELEASE="1" +export D_GLIBCXX_USE_CXX11_ABI=1 \ No newline at end of file From f16656260789c4b7dcaf4575ba46a121d9e92a0a Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 15:06:29 -0800 Subject: [PATCH 43/52] fix env --- packaging/env_vars.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packaging/env_vars.txt b/packaging/env_vars.txt index 844f712560..22bab28553 100644 --- a/packaging/env_vars.txt +++ b/packaging/env_vars.txt @@ -1,3 +1,3 @@ export CI_BUILD="1" export RELEASE="1" -export D_GLIBCXX_USE_CXX11_ABI=1 \ No newline at end of file +export D_GLIBCXX_USE_CXX11_ABI="1" \ No newline at end of file From 80aae71694371ff926dee4c6a3456d9d1581a1cb Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 15:33:01 -0800 Subject: [PATCH 44/52] fix container --- .github/scripts/generate_binary_build_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 26bb447b4f..b36339c22f 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -150,7 +150,7 @@ def initialize_globals(channel: str, build_python_only: bool) -> None: "11.8": "pytorch/manylinux2_28-builder:cuda11.8", "12.1": "pytorch/manylinux2_28-builder:cuda12.1", "12.4": "pytorch/manylinux2_28-builder:cuda12.4", - "12.6": "pytorch/manylinux2_28-builder:cuda12.6", + "12.6": "pytorch/manylinux-builder:cuda12.6", **{ gpu_arch: f"pytorch/manylinux2_28-builder:rocm{gpu_arch}" for gpu_arch in ROCM_ARCHES From 676c9ce8f29b50d56d715a841cff74c11ebca036 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 16:02:20 -0800 Subject: [PATCH 45/52] fix env --- .github/scripts/generate_binary_build_matrix.py | 2 +- packaging/env_vars.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index b36339c22f..26bb447b4f 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -150,7 +150,7 @@ def initialize_globals(channel: str, build_python_only: bool) -> None: "11.8": "pytorch/manylinux2_28-builder:cuda11.8", "12.1": "pytorch/manylinux2_28-builder:cuda12.1", "12.4": "pytorch/manylinux2_28-builder:cuda12.4", - "12.6": "pytorch/manylinux-builder:cuda12.6", + "12.6": "pytorch/manylinux2_28-builder:cuda12.6", **{ gpu_arch: f"pytorch/manylinux2_28-builder:rocm{gpu_arch}" for gpu_arch in ROCM_ARCHES diff --git a/packaging/env_vars.txt b/packaging/env_vars.txt index 22bab28553..a01c4ac4d8 100644 --- a/packaging/env_vars.txt +++ b/packaging/env_vars.txt @@ -1,3 +1,3 @@ export CI_BUILD="1" export RELEASE="1" -export D_GLIBCXX_USE_CXX11_ABI="1" \ No newline at end of file +export D_GLIBCXX_USE_CXX11_ABI="0" \ No newline at end of file From bf2edc627ed3b4dcc3a2d0f73c351e9ae304a7f7 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 16:30:50 -0800 Subject: [PATCH 46/52] fix env --- .github/workflows/build-tensorrt-linux.yml | 2 +- packaging/env_vars.txt | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build-tensorrt-linux.yml b/.github/workflows/build-tensorrt-linux.yml index 7581c38ae8..ec4b4af872 100644 --- a/.github/workflows/build-tensorrt-linux.yml +++ b/.github/workflows/build-tensorrt-linux.yml @@ -191,7 +191,7 @@ jobs: run: | set -x source "${BUILD_ENV_FILE}" - ${CONDA_RUN} python setup.py bdist_wheel + ${CONDA_RUN} python setup.py --use-cxx11-abi bdist_wheel - name: Run Post-Script if: ${{ inputs.post-script != '' }} diff --git a/packaging/env_vars.txt b/packaging/env_vars.txt index a01c4ac4d8..46f906b1ff 100644 --- a/packaging/env_vars.txt +++ b/packaging/env_vars.txt @@ -1,3 +1,2 @@ export CI_BUILD="1" -export RELEASE="1" -export D_GLIBCXX_USE_CXX11_ABI="0" \ No newline at end of file +export RELEASE="1" \ No newline at end of file From 627d510d7b483c650afceadf08fa2fd76e6eab11 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 17:00:32 -0800 Subject: [PATCH 47/52] fix env --- .github/workflows/build-tensorrt-linux.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-tensorrt-linux.yml b/.github/workflows/build-tensorrt-linux.yml index ec4b4af872..9436f1ecd4 100644 --- a/.github/workflows/build-tensorrt-linux.yml +++ b/.github/workflows/build-tensorrt-linux.yml @@ -191,7 +191,7 @@ jobs: run: | set -x source "${BUILD_ENV_FILE}" - ${CONDA_RUN} python setup.py --use-cxx11-abi bdist_wheel + ${CONDA_RUN} python setup.py bdist_wheel --use-cxx11-abi - name: Run Post-Script if: ${{ inputs.post-script != '' }} From b393b6f1aadb940e9e843067e4d1f45f7bc55b90 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 17:24:27 -0800 Subject: [PATCH 48/52] fix env --- .github/workflows/linux-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/linux-test.yml b/.github/workflows/linux-test.yml index 6ddc601f2c..26444695d2 100644 --- a/.github/workflows/linux-test.yml +++ b/.github/workflows/linux-test.yml @@ -43,7 +43,7 @@ on: default: true script: description: 'Script to utilize' - default: "python setup.py bdist_wheel" + default: "python setup.py bdist_wheel --use-cxx11-abi" type: string continue-on-error: description: "Prevents a job from failing when a step fails. Set to true to allow a job to pass when exec script step fails." From 78d72b6a6d0b8896206e8dc14ddcd77103bc94ad Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 10 Dec 2024 23:14:43 -0800 Subject: [PATCH 49/52] fix env --- .github/workflows/build-tensorrt-linux.yml | 2 +- .github/workflows/linux-test.yml | 2 +- packaging/env_vars.txt | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build-tensorrt-linux.yml b/.github/workflows/build-tensorrt-linux.yml index 9436f1ecd4..7581c38ae8 100644 --- a/.github/workflows/build-tensorrt-linux.yml +++ b/.github/workflows/build-tensorrt-linux.yml @@ -191,7 +191,7 @@ jobs: run: | set -x source "${BUILD_ENV_FILE}" - ${CONDA_RUN} python setup.py bdist_wheel --use-cxx11-abi + ${CONDA_RUN} python setup.py bdist_wheel - name: Run Post-Script if: ${{ inputs.post-script != '' }} diff --git a/.github/workflows/linux-test.yml b/.github/workflows/linux-test.yml index 26444695d2..6ddc601f2c 100644 --- a/.github/workflows/linux-test.yml +++ b/.github/workflows/linux-test.yml @@ -43,7 +43,7 @@ on: default: true script: description: 'Script to utilize' - default: "python setup.py bdist_wheel --use-cxx11-abi" + default: "python setup.py bdist_wheel" type: string continue-on-error: description: "Prevents a job from failing when a step fails. Set to true to allow a job to pass when exec script step fails." diff --git a/packaging/env_vars.txt b/packaging/env_vars.txt index 46f906b1ff..9442a7fe98 100644 --- a/packaging/env_vars.txt +++ b/packaging/env_vars.txt @@ -1,2 +1,3 @@ export CI_BUILD="1" -export RELEASE="1" \ No newline at end of file +export RELEASE="1" +export CXX11_ABI="1" \ No newline at end of file From a5d3c18458f7b236e50cf32393efac1150d3f811 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 11 Dec 2024 10:18:26 -0800 Subject: [PATCH 50/52] export USE_CXX11_ABI=1 for cuda12.6 --- packaging/env_vars.txt | 2 +- setup.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/packaging/env_vars.txt b/packaging/env_vars.txt index 9442a7fe98..44a2350e0f 100644 --- a/packaging/env_vars.txt +++ b/packaging/env_vars.txt @@ -1,3 +1,3 @@ export CI_BUILD="1" export RELEASE="1" -export CXX11_ABI="1" \ No newline at end of file +export USE_CXX11_ABI="1" \ No newline at end of file diff --git a/setup.py b/setup.py index 0b8f47fb6f..fee4bf4c45 100644 --- a/setup.py +++ b/setup.py @@ -141,7 +141,9 @@ def load_dep_info(): CXX11_ABI = True if (cxx11_abi_env_var := os.environ.get("USE_CXX11_ABI")) is not None: - if cxx11_abi_env_var == "1": + if ( + cxx11_abi_env_var == "1" and __cuda_version__ == "12.6" + ): # Only use CXX11_ABI for CUDA 12.6 CXX11_ABI = True if platform.uname().processor == "aarch64": From 4f02da8f32a7f2a9b4882e51445ac643d59a99c6 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 11 Dec 2024 11:48:20 -0800 Subject: [PATCH 51/52] remove chunk --- .../dynamo/conversion/impl/slice/ops.py | 55 ------------------- 1 file changed, 55 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 3d9f962fef..3274d78c2b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -324,61 +324,6 @@ def expand( return layer.get_output(0) -def chunk( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, - chunks: int, - dim: int, -) -> TRTTensor: - if chunks <= 0: - raise RuntimeError( - f"chunk expects `chunks` to be greater than 0, got: {chunks}" - ) - - shape = input.shape - dim = get_positive_dim(dim, len(shape)) - - if dim >= len(shape): - raise RuntimeError( - f"chunk expects `dim` to be less than the length of input shape, got: {dim}" - ) - - dynamic_shape = has_dynamic_shape(input.shape) - if dynamic_shape > 0: - # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" - - size_dim = shape[dim] - chunk_size = math.ceil(size_dim / chunks) - result = [] - start = 0 - end = min(start + chunk_size, size_dim) - cnt = 0 - - while start < end: - result.append( - slice_op( - ctx, - target, - source_ir, - f"{name}_slice_{cnt}", - input, - dim, - start, - end, - 1, - ) - ) - start = end - end = min(start + chunk_size, size_dim) - cnt += 1 - - return result - - def cumsum( ctx: ConversionContext, target: Target, From 7d7423a156a0dd2108c1a550edb7190306090116 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 11 Dec 2024 16:23:57 -0800 Subject: [PATCH 52/52] resolve comments --- packaging/env_vars.txt | 3 +-- py/ci/Dockerfile.ci | 2 +- py/torch_tensorrt/dynamo/_compiler.py | 12 ++++++------ setup.py | 4 +--- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/packaging/env_vars.txt b/packaging/env_vars.txt index 44a2350e0f..46f906b1ff 100644 --- a/packaging/env_vars.txt +++ b/packaging/env_vars.txt @@ -1,3 +1,2 @@ export CI_BUILD="1" -export RELEASE="1" -export USE_CXX11_ABI="1" \ No newline at end of file +export RELEASE="1" \ No newline at end of file diff --git a/py/ci/Dockerfile.ci b/py/ci/Dockerfile.ci index 2a690ce2ea..823c8bb7a1 100644 --- a/py/ci/Dockerfile.ci +++ b/py/ci/Dockerfile.ci @@ -1,4 +1,4 @@ -FROM pytorch/manylinux2_28-builder:cuda12.4 +FROM pytorch/manylinux2_28-builder:cuda12.6 RUN yum install -y ninja-build diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 48285e363d..88e66b0f3c 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -197,7 +197,7 @@ def cross_compile_for_windows( if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", + "`refit` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted.", DeprecationWarning, stacklevel=2, ) @@ -210,7 +210,7 @@ def cross_compile_for_windows( if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", + "`make_refittable` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted", DeprecationWarning, stacklevel=2, ) @@ -508,7 +508,7 @@ def compile( if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", + "`refit` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted", DeprecationWarning, stacklevel=2, ) @@ -521,7 +521,7 @@ def compile( if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", + "`make_refittable` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted", DeprecationWarning, stacklevel=2, ) @@ -1028,7 +1028,7 @@ def convert_exported_program_to_serialized_trt_engine( if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", + "`refit` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted", DeprecationWarning, stacklevel=2, ) @@ -1041,7 +1041,7 @@ def convert_exported_program_to_serialized_trt_engine( if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", + "`make_refittable` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted", DeprecationWarning, stacklevel=2, ) diff --git a/setup.py b/setup.py index fee4bf4c45..0b8f47fb6f 100644 --- a/setup.py +++ b/setup.py @@ -141,9 +141,7 @@ def load_dep_info(): CXX11_ABI = True if (cxx11_abi_env_var := os.environ.get("USE_CXX11_ABI")) is not None: - if ( - cxx11_abi_env_var == "1" and __cuda_version__ == "12.6" - ): # Only use CXX11_ABI for CUDA 12.6 + if cxx11_abi_env_var == "1": CXX11_ABI = True if platform.uname().processor == "aarch64":