diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index adc21bd496..ba78c30a90 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -296,6 +296,10 @@ int64_t TRTEngine::get_automatic_device_memory_budget() { return cuda_engine->getWeightStreamingAutomaticBudget(); } +void TRTEngine::set_pre_allocated_outputs(bool enable) { + use_pre_allocated_outputs = enable; +} + std::string TRTEngine::to_str() const { // clang-format off std::stringstream ss; diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 4895fd006e..41db51158b 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -75,6 +75,7 @@ struct TRTEngine : torch::CustomClassHolder { bool set_device_memory_budget(int64_t budget); int64_t get_streamable_device_memory_budget(); int64_t get_automatic_device_memory_budget(); + void set_pre_allocated_outputs(bool enable); friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine); static const char BINDING_DELIM = '%'; @@ -85,6 +86,9 @@ struct TRTEngine : torch::CustomClassHolder { std::vector input_buffers = {}; std::vector output_buffers = {}; std::string shape_key; + bool cudagraphs_enabled = false; + bool use_pre_allocated_outputs = true; + std::vector pre_allocated_outputs; // TODO: Implement a call method // c10::List Run(c10::List inputs); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index a7908468f4..f7ba509494 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -5,6 +5,7 @@ #include "torch/csrc/jit/runtime/custom_operator.h" #include "torch/torch.h" +#include #include "core/runtime/TRTEngineProfiler.h" #include "core/runtime/runtime.h" #include "core/util/prelude.h" @@ -60,9 +61,8 @@ RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_de return new_target_device_opt.value(); } -bool _cudagraphs_validate_shapes(std::vector inputs, c10::intrusive_ptr compiled_engine) { - // Validate whether the current input shapes to the engine - // invalidate the existing cudagraphs object +bool _validate_shapes(std::vector inputs, c10::intrusive_ptr compiled_engine) { + // Validate whether the current input shapes to the engine has changed // Populate the shape key for the inputs // x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) @@ -83,15 +83,32 @@ bool _cudagraphs_validate_shapes(std::vector inputs, c10::intrusive_ auto new_shape_key = new_shape_key_ss.str(); - // Compare the shape key to the original key and invalidate shapes if they do not match + // Compare the shape key to the original key if (new_shape_key != compiled_engine->shape_key) { - LOG_DEBUG("Resetting Cudagraph on New Shape Key " << new_shape_key); + LOG_DEBUG("Input shape changed " << compiled_engine->shape_key << " -> " << new_shape_key); compiled_engine->shape_key = new_shape_key; - compiled_engine->cudagraph.reset(); - return false; + return true; } - return true; + return false; +} + +std::vector create_output_tensors(c10::intrusive_ptr compiled_engine) { + std::vector outputs(compiled_engine->num_io.second); + for (auto output_indices : compiled_engine->out_binding_map) { + // out_binding_map stores TRT_IDX: PYT_IDX + auto pyt_idx = output_indices.second; + + std::string name = compiled_engine->out_binding_names[pyt_idx]; + auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str()); + LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape); + + auto dims = core::util::toVec(out_shape); + auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); + outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous()); + } + + return outputs; } std::vector execute_engine(std::vector inputs, c10::intrusive_ptr compiled_engine) { @@ -114,10 +131,15 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->cudagraph.enable_debug_mode(); } + bool shape_changed = _validate_shapes(inputs, compiled_engine); + // Whether cudagraphs needs to record the graph on this pass - bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine))); + // Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change + bool need_cudagraphs_record = + (((!compiled_engine->cudagraphs_enabled) && CUDAGRAPHS_MODE) || (CUDAGRAPHS_MODE && shape_changed)); + compiled_engine->cudagraphs_enabled = CUDAGRAPHS_MODE; - if (!CUDAGRAPHS_MODE) { + if (!CUDAGRAPHS_MODE || shape_changed) { compiled_engine->cudagraph.reset(); } @@ -178,6 +200,7 @@ std::vector execute_engine(std::vector inputs, c10::intr { // Input Setup std::unique_ptr input_profiler_guard; + RECORD_FUNCTION("process input", std::vector()); if (compiled_engine->profile_execution) { input_profiler_guard = std::make_unique(compiled_engine->input_profile_path); @@ -259,23 +282,20 @@ std::vector execute_engine(std::vector inputs, c10::intr { // Output Setup std::unique_ptr output_profiler_guard; + RECORD_FUNCTION("process output", std::vector()); if (compiled_engine->profile_execution) { output_profiler_guard = std::make_unique(compiled_engine->output_profile_path); } + if ((false == compiled_engine->use_pre_allocated_outputs) || shape_changed) { + outputs = create_output_tensors(compiled_engine); + } else { + outputs = compiled_engine->pre_allocated_outputs; + } for (auto output_indices : compiled_engine->out_binding_map) { - // out_binding_map stores TRT_IDX: PYT_IDX auto pyt_idx = output_indices.second; - std::string name = compiled_engine->out_binding_names[pyt_idx]; - auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str()); - LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape); - - auto dims = core::util::toVec(out_shape); - auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); - outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous()); - if (need_cudagraphs_record) { // If we are recording the cuda graph then we need to update the persistent output buffer compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); @@ -311,6 +331,7 @@ std::vector execute_engine(std::vector inputs, c10::intr std::unique_lock lock(compiled_engine->mu); { // Engine Execution (execute on engine stream) + RECORD_FUNCTION("Trt runtime", std::vector()); c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream); std::unique_ptr enqueue_profiler_guard; @@ -345,6 +366,11 @@ std::vector execute_engine(std::vector inputs, c10::intr } } // End engine exeuction (resets to caller stream) + // Create output buffer for next execution of graph or trt context. + if (compiled_engine->use_pre_allocated_outputs) { + compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine); + } + // Block caller stream until engine execution is complete at::cuda::CUDAEvent trt_exec_complete; trt_exec_complete.record(compiled_engine->engine_stream); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 1a0a371562..2918cee367 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -86,6 +86,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file) .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info) .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) + .def("set_pre_allocated_outputs", &TRTEngine::set_pre_allocated_outputs) .def_property( "device_memory_budget", &TRTEngine::get_device_memory_budget, diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index e31d73f337..17a38c716d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -107,6 +107,9 @@ def __init__( self.engine = None self.weight_name_map = weight_name_map self.target_platform = Platform.current_platform() + self.cudagraphs_enabled = False + self.pre_allocated_outputs: List[torch.Tensor] = [] + self.use_pre_allocated_outputs = False if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() @@ -171,7 +174,7 @@ def setup_engine(self) -> None: self.engine.get_tensor_shape(input_name) for input_name in self.input_names ] self.output_dtypes = [ - dtype._from(self.engine.get_tensor_dtype(output_name)) + dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype) for output_name in self.output_names ] self.output_shapes = [ @@ -232,6 +235,19 @@ def __del__(self) -> None: if self.cudagraph: self.cudagraph.reset() + def create_output_tensors(self) -> List[torch.Tensor]: + # create output tensors + outputs: List[torch.Tensor] = [] + + for o, _ in enumerate(self.output_names): + output = torch.empty( + size=self.output_shapes[o], + dtype=self.output_dtypes[o], + device=torch.cuda.current_device(), + ) + outputs.append(output) + return outputs + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: # Ensure inputs are available in all scopes and cast symbolic integers to Tensors contiguous_inputs: List[torch.Tensor] = [ @@ -247,11 +263,17 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . self._check_initialized() cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() - need_cudagraphs_record = ( - cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs) - ) + shape_changed = self.validate_input_shapes(inputs) + # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change + if not self.cudagraphs_enabled and cudagraphs_enabled: + need_cudagraphs_record = True + else: + need_cudagraphs_record = cudagraphs_enabled and shape_changed + self.cudagraphs_enabled = cudagraphs_enabled if need_cudagraphs_record: + if self.cudagraph: + self.cudagraph.reset() self._input_buffers = [None] * len(self.input_names) self._output_buffers = [None] * len(self.output_names) @@ -259,7 +281,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . self.cudagraph.reset() self.cudagraph = None - # If in safe mode, check at each iteration for for whether a switch is required + # If in safe mode, check at each iteration for whether a switch is required if ( torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE ): @@ -350,14 +372,14 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . self.context.set_tensor_address( input_name, contiguous_inputs[i].data_ptr() ) - - # Check if input shapes can be inferred. - uninferred_input_names = self.context.infer_shapes() - if uninferred_input_names: - logger.warning( - f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \ - This could happen if the input tensor addresses/shapes haven't been configured correctly" - ) + if shape_changed: + # Check if input shapes can be inferred. + uninferred_input_names = self.context.infer_shapes() + if uninferred_input_names: + logger.warning( + f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \ + This could happen if the input tensor addresses/shapes haven't been configured correctly" + ) with ( torch.autograd.profiler.record_function( @@ -366,24 +388,20 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . if self.profiling_enabled else nullcontext() ): - # create output tensors - outputs: List[torch.Tensor] = [] - - for o, output_name in enumerate(self.output_names): - shape = tuple(self.context.get_tensor_shape(output_name)) - - if DYNAMIC_DIM in shape: + if not self.use_pre_allocated_outputs or shape_changed: + self.output_shapes = [ + tuple(self.context.get_tensor_shape(output_name)) + for output_name in self.output_names + ] + if DYNAMIC_DIM in self.output_shapes: raise ValueError( "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." ) + outputs = self.create_output_tensors() + else: + outputs = self.pre_allocated_outputs - output = torch.empty( - size=shape, - dtype=self.output_dtypes[o].to(torch.dtype), - device=torch.cuda.current_device(), - ) - - outputs.append(output) + for o, output_name in enumerate(self.output_names): if need_cudagraphs_record: self._output_buffers[o] = outputs[o].clone() @@ -444,6 +462,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . self._caller_stream.wait_stream(self._engine_stream) + if self.use_pre_allocated_outputs: + self.pre_allocated_outputs = self.create_output_tensors() + if cudagraphs_enabled: for idx, o in enumerate(outputs): o.copy_(self._output_buffers[idx]) @@ -485,10 +506,9 @@ def get_layer_info(self) -> str: ) return engine_json - def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: + def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: """ - Validates the input shapes of the forward function - versus the version currently active for the + Validates the input shapes of the forward function has changed """ # Representation of input shapes to a given model # Shapes are concatenated as so: @@ -498,10 +518,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: # If the new shape key differs from the existing one, # invalidate the old shape key and remove the CUDAGraph if new_shape_key != self.shape_key: - logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}") + logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}") self.shape_key = new_shape_key - if self.cudagraph: - self.cudagraph.reset() - return False + return True - return True + return False diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 1bebe20fda..99f863f1da 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -207,6 +207,7 @@ def setup_engine(self) -> None: if self.engine is not None: return self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) + self.engine.set_pre_allocated_outputs(True) def encode_metadata(self, metadata: Any) -> str: metadata = copy.deepcopy(metadata)