diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 5a5c1ad83d..52ac58dae5 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -307,6 +307,9 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) { if (profile_execution) { enable_profiling(); } + // Indicates to reevaluate the runtime settings + has_context_changed = true; + return result; } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 88fb7ab275..9d8b7a91ba 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -102,6 +102,7 @@ struct TRTEngine : torch::CustomClassHolder { std::vector input_buffers = {}; std::vector output_buffers = {}; std::string shape_key; + bool has_context_changed = false; // TODO: Implement a call method // c10::List Run(c10::List inputs); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index a7908468f4..7aef810818 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -115,11 +115,15 @@ std::vector execute_engine(std::vector inputs, c10::intr } // Whether cudagraphs needs to record the graph on this pass - bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine))); + bool need_cudagraphs_record = + (CUDAGRAPHS_MODE && + (!_cudagraphs_validate_shapes(inputs, compiled_engine) || compiled_engine->has_context_changed)); - if (!CUDAGRAPHS_MODE) { + if (!CUDAGRAPHS_MODE || compiled_engine->has_context_changed) { compiled_engine->cudagraph.reset(); } + // Reset the flag + compiled_engine->has_context_changed = false; // this is a buffer to store shape tensor input addresses throughout the runtime scope std::list> inputShapeTensorValues; diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index e31d73f337..f55eb63e54 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -107,6 +107,7 @@ def __init__( self.engine = None self.weight_name_map = weight_name_map self.target_platform = Platform.current_platform() + self.has_context_changed = False if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() @@ -126,6 +127,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int: del self.context budget_bytes = self._set_device_memory_budget(budget_bytes) self.context = self.engine.create_execution_context() + # Indicates to reevaluate the runtime settings + self.has_context_changed = True + return budget_bytes def _set_device_memory_budget(self, budget_bytes: int) -> int: @@ -247,18 +251,21 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . self._check_initialized() cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() - need_cudagraphs_record = ( - cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs) + need_cudagraphs_record = cudagraphs_enabled and ( + not self.cudagraphs_validate_shapes(inputs) or self.has_context_changed ) if need_cudagraphs_record: self._input_buffers = [None] * len(self.input_names) self._output_buffers = [None] * len(self.output_names) - if not cudagraphs_enabled and self.cudagraph: + if self.cudagraph and (not cudagraphs_enabled or self.has_context_changed): self.cudagraph.reset() self.cudagraph = None + # Reset the flag + self.has_context_changed = False + # If in safe mode, check at each iteration for for whether a switch is required if ( torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE