Skip to content

Commit

Permalink
fix: Record cudagraphs when weight streaming budget has changed
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Dec 1, 2024
1 parent 1b40bcc commit 7bb66da
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
3 changes: 3 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
if (profile_execution) {
enable_profiling();
}
// Indicates to reevaluate the runtime settings
has_context_changed = true;

return result;
}

Expand Down
1 change: 1 addition & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key;
bool has_context_changed = false;

// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
Expand Down
8 changes: 6 additions & 2 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,15 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}

// Whether cudagraphs needs to record the graph on this pass
bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
bool need_cudagraphs_record =
(CUDAGRAPHS_MODE &&
(!_cudagraphs_validate_shapes(inputs, compiled_engine) || compiled_engine->has_context_changed));

if (!CUDAGRAPHS_MODE) {
if (!CUDAGRAPHS_MODE || compiled_engine->has_context_changed) {
compiled_engine->cudagraph.reset();
}
// Reset the flag
compiled_engine->has_context_changed = false;

// this is a buffer to store shape tensor input addresses throughout the runtime scope
std::list<std::vector<int64_t>> inputShapeTensorValues;
Expand Down
13 changes: 10 additions & 3 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(
self.engine = None
self.weight_name_map = weight_name_map
self.target_platform = Platform.current_platform()
self.has_context_changed = False

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()
Expand All @@ -126,6 +127,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
del self.context
budget_bytes = self._set_device_memory_budget(budget_bytes)
self.context = self.engine.create_execution_context()
# Indicates to reevaluate the runtime settings
self.has_context_changed = True

return budget_bytes

def _set_device_memory_budget(self, budget_bytes: int) -> int:
Expand Down Expand Up @@ -247,18 +251,21 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
self._check_initialized()

cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
need_cudagraphs_record = (
cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
need_cudagraphs_record = cudagraphs_enabled and (
not self.cudagraphs_validate_shapes(inputs) or self.has_context_changed
)

if need_cudagraphs_record:
self._input_buffers = [None] * len(self.input_names)
self._output_buffers = [None] * len(self.output_names)

if not cudagraphs_enabled and self.cudagraph:
if self.cudagraph and (not cudagraphs_enabled or self.has_context_changed):
self.cudagraph.reset()
self.cudagraph = None

# Reset the flag
self.has_context_changed = False

# If in safe mode, check at each iteration for for whether a switch is required
if (
torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
Expand Down

0 comments on commit 7bb66da

Please sign in to comment.