Skip to content

Commit 9217ae0

Browse files
committed
Enabled weight streaming and CudaGraph. Supported MTTM saving with dynamic shapes.
1 parent 56a426a commit 9217ae0

File tree

4 files changed

+84
-20
lines changed

4 files changed

+84
-20
lines changed

examples/apps/flux-demo.py

+5-3
Original file line number | Diff line number | Diff line change
@@ -41,13 +41,14 @@
4141
"use_fp32_acc": True,
4242
"use_explicit_typing": True,
4343
"debug": False,
44-
"use_python_runtime": True,
44+
"use_python_runtime": False,
4545
"immutable_weights": False,
4646
# "cache_built_engines": True,
4747
# "reuse_cached_engines": True,
4848
# "timing_cache_path": "/home/engine_cache/flux.bin",
4949
# "engine_cache_size": 40 * 1 << 30,
50-
# "enable_weight_streaming": False,
50+
# "enable_weight_streaming": True,
51+
# "weight_streaming_budget": 8 * 1 << 30
5152
# "enable_cuda_graph": True,
5253
}
5354

@@ -68,6 +69,7 @@ def generate_image(prompt, inference_step, batch_size=2):
6869

6970
generate_image(["Test"], 2)
7071
torch.cuda.empty_cache()
72+
# torch_tensorrt.MutableTorchTensorRTModule.save(trt_gm, "weight_streaming_Flux.pkl")
7173

7274

7375
def model_change(model):
@@ -83,7 +85,7 @@ def model_change(model):
8385
def load_lora(path):
8486

8587
pipe.load_lora_weights(
86-
"/home/TensorRT/examples/apps/NGRVNG.safetensors",
88+
path,
8789
adapter_name="lora1",
8890
)
8991
pipe.set_adapters(["lora1"], adapter_weights=[1])

py/torch_tensorrt/dynamo/_compiler.py

-9
Original file line number | Diff line number | Diff line change
@@ -551,15 +551,6 @@ def compile(
551551
"`immutable_weights` must be False when `refit_identical_engine_weights` is True."
552552
)
553553

554-
if (
555-
not immutable_weights
556-
and not refit_identical_engine_weights
557-
and enable_weight_streaming
558-
):
559-
raise ValueError(
560-
"TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
561-
)
562-
563554
if (
564555
"enable_cross_compile_for_windows" in kwargs.keys()
565556
and kwargs["enable_cross_compile_for_windows"]

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

+71-7
Original file line number | Diff line number | Diff line change
@@ -63,10 +63,11 @@ def __init__(
6363
*,
6464
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
6565
use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
66-
enable_cuda_graph: bool = True,
66+
enable_cuda_graph: bool = False,
6767
immutable_weights: bool = False,
6868
strict: bool = True,
6969
allow_complex_guards_as_runtime_asserts: bool = False,
70+
weight_streaming_budget: Optional[int] = None,
7071
**kwargs: Any,
7172
) -> None:
7273
"""
@@ -130,7 +131,6 @@ def __init__(
130131
self.arg_inputs: tuple[Any, ...] = tuple()
131132
self.kwarg_inputs: dict[str, Any] = {}
132133
self.additional_settings = kwargs
133-
self.enable_cuda_graph = enable_cuda_graph
134134
self.strict = strict
135135
self.allow_complex_guards_as_runtime_asserts = (
136136
allow_complex_guards_as_runtime_asserts
@@ -143,6 +143,7 @@ def __init__(
143143

144144
self.arg_dynamic_shapes: Optional[tuple[Any]] = None
145145
self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
146+
self.serializable_dynamic_shapes_dims: dict[str, tuple[str, int, int]] = {}
146147
self.run_info: Optional[tuple[Any, ...]] = None
147148
self.state_dict_metadata: dict[str, torch.Size] = {}
148149
self._store_state_dict_metadata()
@@ -151,6 +152,15 @@ def __init__(
151152
if "enable_weight_streaming" in kwargs
152153
else False
153154
)
155+
self.weight_streaming_ctx = None
156+
self.weight_streaming_budget = weight_streaming_budget
157+
if self.enable_weight_streaming:
158+
if weight_streaming_budget is None:
159+
logger.warning(
160+
"Weight stremaing budget is not set. Using auto weight streaming budget"
161+
)
162+
self.enable_cuda_graph = enable_cuda_graph
163+
154164
cls = self.__class__
155165
self.__class__ = type(
156166
self.original_model.__class__.__name__,
@@ -339,9 +349,20 @@ def compile(self) -> None:
339349
if self.enable_cuda_graph:
340350
self._enable_cuda_graph()
341351
if self.enable_weight_streaming:
342-
self.weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(self.gm)
343-
requested_budget = int(16 * 2 << 20)
344-
self.weight_streaming_ctx.device_budget = requested_budget
352+
self.set_weight_streaming_ctx(self.weight_streaming_budget)
353+
354+
def set_weight_streaming_ctx(self, requested_budget: Optional[int] = None) -> None:
    """
    Create the weight streaming context for ``self.gm`` and apply a device budget.

    Args:
        requested_budget: Budget to assign to the context's ``device_budget``.
            When ``None``, the value returned by the context's
            ``get_automatic_weight_streaming_budget()`` is used instead.
    """
    # Build the context once and keep it on the instance before querying it.
    ctx = torch_tensorrt.runtime.weight_streaming(self.gm)
    self.weight_streaming_ctx = ctx

    if requested_budget is None:
        requested_budget = ctx.get_automatic_weight_streaming_budget()

    ctx.device_budget = requested_budget
345366

346367
def _enable_cuda_graph(self) -> None:
347368
self.gm = get_cuda_graph_module(self.gm)
@@ -465,7 +486,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
465486
self._store_state_dict_metadata()
466487
self.refit_state.set_state(RefitFlag.LIVE)
467488

468-
# weight_streaming_ctx = self.weight_streaming_ctx if self.enable_weight_streaming else None
489+
weight_streaming_ctx = (
490+
self.weight_streaming_ctx if self.enable_weight_streaming else None
491+
)
469492
result = self.gm(*args, **kwargs)
470493
# Storing inputs and outputs for verification when the state is unknown
471494
self.run_info = (args, kwargs, result)
@@ -605,6 +628,45 @@ def _check_tensor_shapes_with_dynamic_shapes(
605628

606629
return True
607630

631+
def serialize_dynamic_shapes(self) -> None:
632+
dims = self.serializable_dynamic_shapes_dims
633+
634+
def resursivly_serialize_dynamic_shape(obj: Any) -> None:
635+
if isinstance(obj, dict):
636+
for axis, v in obj.items():
637+
if isinstance(v, torch.export.dynamic_shapes._Dim):
638+
name = str(v).split("'")[1].split(".")[-1]
639+
# We use string of the hash to be the unique identifier of Dim object
640+
dims.setdefault(str(hash(v)), (name, v.min, v.max))
641+
obj[axis] = str(hash(v))
642+
else:
643+
resursivly_serialize_dynamic_shape(v)
644+
if isinstance(obj, (tuple, list)):
645+
for v in obj:
646+
resursivly_serialize_dynamic_shape(v)
647+
648+
resursivly_serialize_dynamic_shape(self.arg_dynamic_shapes)
649+
resursivly_serialize_dynamic_shape(self.kwarg_dynamic_shapes)
650+
651+
def deserialize_dynamic_shapes(self) -> None:
652+
dims = self.serializable_dynamic_shapes_dims
653+
654+
def resursivly_deserialize_dynamic_shape(obj: Any) -> None:
655+
if isinstance(obj, dict):
656+
for axis, v in obj.items():
657+
if isinstance(v, str):
658+
obj[axis] = torch.export.Dim(
659+
dims[v][0], min=dims[v][1], max=dims[v][2]
660+
)
661+
else:
662+
resursivly_deserialize_dynamic_shape(v)
663+
if isinstance(obj, (tuple, list)):
664+
for v in obj:
665+
resursivly_deserialize_dynamic_shape(v)
666+
667+
resursivly_deserialize_dynamic_shape(self.arg_dynamic_shapes)
668+
resursivly_deserialize_dynamic_shape(self.kwarg_dynamic_shapes)
669+
608670
@staticmethod
609671
def save(module: Any, path: str) -> None:
610672
# Cast the object back to MutableTorchTensorRTModule to save
@@ -616,7 +678,8 @@ def save(module: Any, path: str) -> None:
616678
exp_program = module.exp_program
617679
module.pytorch_model = None
618680
module.exp_program = None
619-
torch.save(module, path)
681+
module.serialize_dynamic_shapes()
682+
torch.save(module, path, pickle_protocol=4)
620683
# Restore deleted attributes
621684
module.exp_program = exp_program
622685
module.pytorch_model = _make_refit_change_trigger(
@@ -650,6 +713,7 @@ def load(path: str) -> Any:
650713
(cls, module.original_model.__class__),
651714
{},
652715
)
716+
module.deserialize_dynamic_shapes()
653717
module.init_finished = True
654718
return module
655719

tests/py/dynamo/runtime/test_mutable_torchtrt_module.py

+8-1
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ def test_check_input_shape_dynamic():
7575

7676

7777
@pytest.mark.unit
78-
def test_model_complex_dynamic_shape():
78+
def test_model_complex_dynamic_shape_with_saving():
7979
device = "cuda:0"
8080

8181
class Model(torch.nn.Module):
@@ -111,6 +111,13 @@ def forward(self, a, b, c=None):
111111
# Run inference
112112
trt_gm(*inputs, **kwargs)
113113

114+
try:
115+
save_path = os.path.join(tempfile.gettempdir(), "mutable_module.pkl")
116+
torch_trt.MutableTorchTensorRTModule.save(mutable_module, save_path)
117+
model = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
118+
except Exception as e:
119+
assert "Module saving and reloading with dynamic shape failed."
120+
114121
inputs_2 = [torch.rand(10, 9).to(device)]
115122
kwargs_2 = {
116123
"b": torch.rand(9, 30).to(device),

0 commit comments

Comments
 (0)