Skip to content

Commit

Permalink
feat: Runtime output buffer optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Nov 14, 2024
1 parent c24ef24 commit 1bd9fdc
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 54 deletions.
4 changes: 4 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ int64_t TRTEngine::get_automatic_device_memory_budget() {
return cuda_engine->getWeightStreamingAutomaticBudget();
}

void TRTEngine::set_pre_allocated_outputs(bool enable) {
use_pre_allocated_outputs = enable;
}

std::string TRTEngine::to_str() const {
// clang-format off
std::stringstream ss;
Expand Down
4 changes: 4 additions & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ struct TRTEngine : torch::CustomClassHolder {
bool set_device_memory_budget(int64_t budget);
int64_t get_streamable_device_memory_budget();
int64_t get_automatic_device_memory_budget();
void set_pre_allocated_outputs(bool enable);
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';

Expand All @@ -85,6 +86,9 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key;
bool cudagraphs_enabled = false;
bool use_pre_allocated_outputs = true;
std::vector<at::Tensor> pre_allocated_outputs;

// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
Expand Down
64 changes: 45 additions & 19 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "torch/csrc/jit/runtime/custom_operator.h"
#include "torch/torch.h"

#include <ATen/record_function.h>
#include "core/runtime/TRTEngineProfiler.h"
#include "core/runtime/runtime.h"
#include "core/util/prelude.h"
Expand Down Expand Up @@ -60,9 +61,8 @@ RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_de
return new_target_device_opt.value();
}

bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
// Validate whether the current input shapes to the engine
// invalidate the existing cudagraphs object
bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
// Validate whether the current input shapes to the engine has changed

// Populate the shape key for the inputs
// x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
Expand All @@ -83,15 +83,32 @@ bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_

auto new_shape_key = new_shape_key_ss.str();

// Compare the shape key to the original key and invalidate shapes if they do not match
// Compare the shape key to the original key
if (new_shape_key != compiled_engine->shape_key) {
LOG_DEBUG("Resetting Cudagraph on New Shape Key " << new_shape_key);
LOG_DEBUG("Input shape changed " << compiled_engine->shape_key << " -> " << new_shape_key);
compiled_engine->shape_key = new_shape_key;
compiled_engine->cudagraph.reset();
return false;
return true;
}

return true;
return false;
}

std::vector<at::Tensor> create_output_tensors(c10::intrusive_ptr<TRTEngine> compiled_engine) {
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
for (auto output_indices : compiled_engine->out_binding_map) {
// out_binding_map stores TRT_IDX: PYT_IDX
auto pyt_idx = output_indices.second;

std::string name = compiled_engine->out_binding_names[pyt_idx];
auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);

auto dims = core::util::toVec(out_shape);
auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
}

return outputs;
}

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
Expand All @@ -114,10 +131,15 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
compiled_engine->cudagraph.enable_debug_mode();
}

bool shape_changed = _validate_shapes(inputs, compiled_engine);

// Whether cudagraphs needs to record the graph on this pass
bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
bool need_cudagraphs_record =
(((!compiled_engine->cudagraphs_enabled) && CUDAGRAPHS_MODE) || (CUDAGRAPHS_MODE && shape_changed));
compiled_engine->cudagraphs_enabled = CUDAGRAPHS_MODE;

if (!CUDAGRAPHS_MODE) {
if (!CUDAGRAPHS_MODE || shape_changed) {
compiled_engine->cudagraph.reset();
}

Expand Down Expand Up @@ -178,6 +200,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

{ // Input Setup
std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
RECORD_FUNCTION("process input", std::vector<c10::IValue>());
if (compiled_engine->profile_execution) {
input_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
Expand Down Expand Up @@ -259,23 +282,20 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

{ // Output Setup
std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
RECORD_FUNCTION("process output", std::vector<c10::IValue>());
if (compiled_engine->profile_execution) {
output_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
}
if ((false == compiled_engine->use_pre_allocated_outputs) || shape_changed) {
outputs = create_output_tensors(compiled_engine);
} else {
outputs = compiled_engine->pre_allocated_outputs;
}

for (auto output_indices : compiled_engine->out_binding_map) {
// out_binding_map stores TRT_IDX: PYT_IDX
auto pyt_idx = output_indices.second;

std::string name = compiled_engine->out_binding_names[pyt_idx];
auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);

auto dims = core::util::toVec(out_shape);
auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());

if (need_cudagraphs_record) {
// If we are recording the cuda graph then we need to update the persistent output buffer
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
Expand Down Expand Up @@ -311,6 +331,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
std::unique_lock<std::mutex> lock(compiled_engine->mu);

{ // Engine Execution (execute on engine stream)
RECORD_FUNCTION("Trt runtime", std::vector<c10::IValue>());
c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);

std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
Expand Down Expand Up @@ -345,6 +366,11 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}
} // End engine exeuction (resets to caller stream)

// Create output buffer for next execution of graph or trt context.
if (compiled_engine->use_pre_allocated_outputs) {
compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
}

// Block caller stream until engine execution is complete
at::cuda::CUDAEvent trt_exec_complete;
trt_exec_complete.record(compiled_engine->engine_stream);
Expand Down
1 change: 1 addition & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("set_pre_allocated_outputs", &TRTEngine::set_pre_allocated_outputs)
.def_property(
"device_memory_budget",
&TRTEngine::get_device_memory_budget,
Expand Down
88 changes: 53 additions & 35 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def __init__(
self.engine = None
self.weight_name_map = weight_name_map
self.target_platform = Platform.current_platform()
self.cudagraphs_enabled = False
self.pre_allocated_outputs: List[torch.Tensor] = []
self.use_pre_allocated_outputs = False

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()
Expand Down Expand Up @@ -171,7 +174,7 @@ def setup_engine(self) -> None:
self.engine.get_tensor_shape(input_name) for input_name in self.input_names
]
self.output_dtypes = [
dtype._from(self.engine.get_tensor_dtype(output_name))
dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype)
for output_name in self.output_names
]
self.output_shapes = [
Expand Down Expand Up @@ -232,6 +235,19 @@ def __del__(self) -> None:
if self.cudagraph:
self.cudagraph.reset()

def create_output_tensors(self) -> List[torch.Tensor]:
# create output tensors
outputs: List[torch.Tensor] = []

for o, _ in enumerate(self.output_names):
output = torch.empty(
size=self.output_shapes[o],
dtype=self.output_dtypes[o],
device=torch.cuda.current_device(),
)
outputs.append(output)
return outputs

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
# Ensure inputs are available in all scopes and cast symbolic integers to Tensors
contiguous_inputs: List[torch.Tensor] = [
Expand All @@ -247,19 +263,25 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
self._check_initialized()

cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
need_cudagraphs_record = (
cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
)
shape_changed = self.validate_input_shapes(inputs)
# Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
if not self.cudagraphs_enabled and cudagraphs_enabled:
need_cudagraphs_record = True
else:
need_cudagraphs_record = cudagraphs_enabled and shape_changed
self.cudagraphs_enabled = cudagraphs_enabled

if need_cudagraphs_record:
if self.cudagraph:
self.cudagraph.reset()
self._input_buffers = [None] * len(self.input_names)
self._output_buffers = [None] * len(self.output_names)

if not cudagraphs_enabled and self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None

# If in safe mode, check at each iteration for for whether a switch is required
# If in safe mode, check at each iteration for whether a switch is required
if (
torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
):
Expand Down Expand Up @@ -350,14 +372,14 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
self.context.set_tensor_address(
input_name, contiguous_inputs[i].data_ptr()
)

# Check if input shapes can be inferred.
uninferred_input_names = self.context.infer_shapes()
if uninferred_input_names:
logger.warning(
f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
This could happen if the input tensor addresses/shapes haven't been configured correctly"
)
if shape_changed:
# Check if input shapes can be inferred.
uninferred_input_names = self.context.infer_shapes()
if uninferred_input_names:
logger.warning(
f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
This could happen if the input tensor addresses/shapes haven't been configured correctly"
)

with (
torch.autograd.profiler.record_function(
Expand All @@ -366,24 +388,20 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
if self.profiling_enabled
else nullcontext()
):
# create output tensors
outputs: List[torch.Tensor] = []

for o, output_name in enumerate(self.output_names):
shape = tuple(self.context.get_tensor_shape(output_name))

if DYNAMIC_DIM in shape:
if not self.use_pre_allocated_outputs or shape_changed:
self.output_shapes = [
tuple(self.context.get_tensor_shape(output_name))
for output_name in self.output_names
]
if DYNAMIC_DIM in self.output_shapes:
raise ValueError(
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
)
outputs = self.create_output_tensors()
else:
outputs = self.pre_allocated_outputs

output = torch.empty(
size=shape,
dtype=self.output_dtypes[o].to(torch.dtype),
device=torch.cuda.current_device(),
)

outputs.append(output)
for o, output_name in enumerate(self.output_names):

if need_cudagraphs_record:
self._output_buffers[o] = outputs[o].clone()
Expand Down Expand Up @@ -444,6 +462,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .

self._caller_stream.wait_stream(self._engine_stream)

if self.use_pre_allocated_outputs:
self.pre_allocated_outputs = self.create_output_tensors()

if cudagraphs_enabled:
for idx, o in enumerate(outputs):
o.copy_(self._output_buffers[idx])
Expand Down Expand Up @@ -485,10 +506,9 @@ def get_layer_info(self) -> str:
)
return engine_json

def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
"""
Validates the input shapes of the forward function
versus the version currently active for the
Validates the input shapes of the forward function has changed
"""
# Representation of input shapes to a given model
# Shapes are concatenated as so:
Expand All @@ -498,10 +518,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
# If the new shape key differs from the existing one,
# invalidate the old shape key and remove the CUDAGraph
if new_shape_key != self.shape_key:
logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
self.shape_key = new_shape_key
if self.cudagraph:
self.cudagraph.reset()
return False
return True

return True
return False
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def setup_engine(self) -> None:
if self.engine is not None:
return
self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())
self.engine.set_pre_allocated_outputs(True)

def encode_metadata(self, metadata: Any) -> str:
metadata = copy.deepcopy(metadata)
Expand Down

0 comments on commit 1bd9fdc

Please sign in to comment.