
Commit 56a426a

Enabled Cuda Graph
1 parent 2eefd73

File tree

3 files changed: +54 -45 lines changed

examples/apps/flux-demo.py (+2 -2)
@@ -95,8 +95,8 @@ def load_lora(path):
 
 
 generate_image(["Test"], 2)
-load_lora("")
-generate_image(["A golden retriever holding a sign to code"], 2)
+# load_lora("")
+# generate_image(["A golden retriever holding a sign to code"], 2)
 
 # Create Gradio interface
 with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py (+6 -3)
@@ -16,6 +16,7 @@
     to_torch_device,
     to_torch_tensorrt_device,
 )
+from torch_tensorrt.runtime._cudagraphs import get_cuda_graph_module
 
 logger = logging.getLogger(__name__)
 
@@ -335,14 +336,16 @@ def compile(self) -> None:
         )
         self.original_model.to("cpu")
         torch.cuda.empty_cache()
-        # torch_tensorrt.runtime.set_cudagraphs_mode(self.enable_cuda_graph)
-        # if self.enable_cuda_graph:
-        #     self.gm = torch_tensorrt.runtime.enable_cudagraphs(self.gm)
+        if self.enable_cuda_graph:
+            self._enable_cuda_graph()
         if self.enable_weight_streaming:
             self.weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(self.gm)
             requested_budget = int(16 * 2 << 20)
             self.weight_streaming_ctx.device_budget = requested_budget
 
+    def _enable_cuda_graph(self) -> None:
+        self.gm = get_cuda_graph_module(self.gm)
+
     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
 
         if not self.arg_inputs and not self.kwarg_inputs:
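
With this change, compile() swaps self.gm for a CUDA Graphs-enabled module via the new _enable_cuda_graph() helper whenever the flag is set. A minimal usage sketch for the mutable-module path follows; note that passing enable_cuda_graph as a constructor keyword is an assumption, since the diff only shows the self.enable_cuda_graph attribute being read:

import torch
import torch_tensorrt

# Placeholder model and input; any torch.nn.Module works the same way.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().cuda()
x = torch.randn(4, 16).cuda()

# Assumption: enable_cuda_graph is accepted as a module setting here; only
# the attribute read (self.enable_cuda_graph) is visible in this diff.
mutable_module = torch_tensorrt.MutableTorchTensorRTModule(
    model, enable_cuda_graph=True
)
# Compilation happens on first use; compile() then calls _enable_cuda_graph(),
# which re-wraps self.gm via get_cuda_graph_module().
out = mutable_module(x)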

py/torch_tensorrt/runtime/_cudagraphs.py (+46 -40)
@@ -69,52 +69,58 @@ def __init__(self, compiled_module: torch.nn.Module) -> None:
         self.old_mode = _PY_RT_CUDAGRAPHS
         self.compiled_module = compiled_module
 
-    def __enter__(self) -> torch.nn.Module:
-        global _PY_RT_CUDAGRAPHS
-
-        num_torch_module = 0
-        num_trt_module = 0
-        for name, module in self.compiled_module.named_children():
-            # need to disable cudagraphs if any model requires output allocator
-            if (
-                hasattr(module, "requires_output_allocator")
-                and module.requires_output_allocator
-            ):
-                raise RuntimeError(
-                    "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs."
-                )
-            if "_run_on_acc" in name:
-                num_trt_module += 1
-            elif "_run_on_gpu" in name:
-                num_torch_module += 1
-
-        if num_torch_module > 0:
-            # Set whole cudagraphs mode and returns wrapped module
-            _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS
-            # Set new mode for C++
-            if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
-                torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)
-
-            logger.debug(
-                "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
-            )
-            return CudaGraphsTorchTensorRTModule(self.compiled_module)
-        else:
-            if num_trt_module > 0:
-                logger.debug("No graph breaks detected, using runtime cudagraphs mode")
-            else:
-                logger.debug(
-                    "Please consider dynamo if there is graph breaks. Using runtime cudagraphs mode"
-                )
-            # Enable cudagraphs for TRT submodule
-            set_cudagraphs_mode(True)
-            return self.compiled_module
+    def __enter__(self) -> torch.nn.Module | torch.fx.GraphModule:
+        return get_cuda_graph_module(self.compiled_module)
 
     def __exit__(self, *args: Any) -> None:
         # Set cudagraphs back to old mode
         set_cudagraphs_mode(self.old_mode)
 
 
+def get_cuda_graph_module(
+    compiled_module: torch.fx.GraphModule,
+) -> torch.nn.Module | torch.fx.GraphModule:
+    global _PY_RT_CUDAGRAPHS
+
+    num_torch_module = 0
+    num_trt_module = 0
+    for name, module in compiled_module.named_children():
+        # need to disable cudagraphs if any model requires output allocator
+        if (
+            hasattr(module, "requires_output_allocator")
+            and module.requires_output_allocator
+        ):
+            raise RuntimeError(
+                "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs."
+            )
+        if "_run_on_acc" in name:
+            num_trt_module += 1
+        elif "_run_on_gpu" in name:
+            num_torch_module += 1
+
+    if num_torch_module > 0:
+        # Set whole cudagraphs mode and return the wrapped module
+        _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS
+        # Set new mode for C++
+        if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
+            torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)
+
+        logger.debug(
+            "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
+        )
+        return CudaGraphsTorchTensorRTModule(compiled_module)
+    else:
+        if num_trt_module > 0:
+            logger.debug("No graph breaks detected, using runtime cudagraphs mode")
+        else:
+            logger.debug(
+                "Please consider dynamo if there are graph breaks. Using runtime cudagraphs mode"
+            )
+        # Enable cudagraphs for TRT submodules
+        set_cudagraphs_mode(True)
+        return compiled_module
+
+
 def enable_cudagraphs(
     compiled_module: Union[torch.fx.GraphModule, torch.nn.Module],
 ) -> _CudagraphsContextManager:
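
For callers, the context-manager entry point is unchanged: __enter__ now simply delegates to the new get_cuda_graph_module() free function. A short sketch of the typical runtime flow, using the public torch_tensorrt.runtime.enable_cudagraphs API with a placeholder model and input shape:

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval().cuda()
x = torch.randn(4, 16).cuda()

# Compile with the dynamo frontend, then run under CUDA Graphs.
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[x])

with torch_tensorrt.runtime.enable_cudagraphs(trt_gm) as cudagraphs_module:
    # get_cuda_graph_module() either enables runtime cudagraphs on a
    # pure-TRT module or, when torch subgraphs ("_run_on_gpu") remain,
    # returns the module wrapped in CudaGraphsTorchTensorRTModule.
    out = cudagraphs_module(x)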
