@@ -63,10 +63,11 @@ def __init__(
         *,
         device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
         use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
-        enable_cuda_graph: bool = True,
+        enable_cuda_graph: bool = False,
         immutable_weights: bool = False,
         strict: bool = True,
         allow_complex_guards_as_runtime_asserts: bool = False,
+        weight_streaming_budget: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
         """
@@ -130,7 +131,6 @@ def __init__(
         self.arg_inputs: tuple[Any, ...] = tuple()
         self.kwarg_inputs: dict[str, Any] = {}
         self.additional_settings = kwargs
-        self.enable_cuda_graph = enable_cuda_graph
         self.strict = strict
         self.allow_complex_guards_as_runtime_asserts = (
             allow_complex_guards_as_runtime_asserts
@@ -143,6 +143,7 @@ def __init__(

         self.arg_dynamic_shapes: Optional[tuple[Any]] = None
         self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
+        self.serializable_dynamic_shapes_dims: dict[str, tuple[str, int, int]] = {}
         self.run_info: Optional[tuple[Any, ...]] = None
         self.state_dict_metadata: dict[str, torch.Size] = {}
         self._store_state_dict_metadata()
@@ -151,6 +152,15 @@ def __init__(
             if "enable_weight_streaming" in kwargs
             else False
         )
+        self.weight_streaming_ctx = None
+        self.weight_streaming_budget = weight_streaming_budget
+        if self.enable_weight_streaming:
+            if weight_streaming_budget is None:
+                logger.warning(
+                    "Weight streaming budget is not set. Using automatic weight streaming budget"
+                )
+        self.enable_cuda_graph = enable_cuda_graph
+
         cls = self.__class__
         self.__class__ = type(
             self.original_model.__class__.__name__,
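
With this change, `enable_cuda_graph` defaults to `False` and a `weight_streaming_budget` (in bytes) can be requested at construction time. A minimal usage sketch, assuming the public `torch_tensorrt.MutableTorchTensorRTModule` entry point and a toy `torch.nn.Linear` stand-in for a real model:

```python
import torch
import torch_tensorrt

# Toy stand-in model; any nn.Module would work the same way.
model = torch.nn.Linear(128, 64).eval().cuda()

# enable_weight_streaming is forwarded through **kwargs to the compile
# settings; weight_streaming_budget is the new constructor argument.
# Omitting the budget while streaming is enabled logs a warning and
# falls back to the automatic budget.
mutable_module = torch_tensorrt.MutableTorchTensorRTModule(
    model,
    enable_weight_streaming=True,
    weight_streaming_budget=16 << 20,  # 16 MiB; illustrative value
)
```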
@@ -339,9 +349,20 @@ def compile(self) -> None:
         if self.enable_cuda_graph:
             self._enable_cuda_graph()
         if self.enable_weight_streaming:
-            self.weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(self.gm)
-            requested_budget = int(16 * 2 << 20)
-            self.weight_streaming_ctx.device_budget = requested_budget
+            self.set_weight_streaming_ctx(self.weight_streaming_budget)
+
+    def set_weight_streaming_ctx(self, requested_budget: Optional[int] = None) -> None:
+        """
+        Set the weight streaming budget. If the budget is not set, the automatic
+        weight streaming budget is used.
+        """
+        self.weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(self.gm)
+        requested_budget = (
+            requested_budget
+            if requested_budget is not None
+            else self.weight_streaming_ctx.get_automatic_weight_streaming_budget()
+        )
+        self.weight_streaming_ctx.device_budget = requested_budget

     def _enable_cuda_graph(self) -> None:
         self.gm = get_cuda_graph_module(self.gm)
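
The old code hard-coded a budget of `int(16 * 2 << 20)` bytes; the new `set_weight_streaming_ctx` lets the budget be re-applied or changed after compilation, with `None` selecting the runtime's automatic budget. A sketch continuing the construction example above (it assumes the module has been compiled, e.g. by a first forward call, so that `self.gm` exists):

```python
# First call triggers compile(), which builds self.gm and applies the
# budget passed at construction.
x = torch.randn(8, 128).cuda()
mutable_module(x)

# Re-apply the automatic budget computed by the runtime...
mutable_module.set_weight_streaming_ctx(None)

# ...or request an explicit one. The previous hard-coded default was
# int(16 * 2 << 20) bytes, i.e. 32 MiB:
assert int(16 * 2 << 20) == 32 * 1024 * 1024
mutable_module.set_weight_streaming_ctx(16 * 2 << 20)
```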
@@ -465,7 +486,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
             self._store_state_dict_metadata()
             self.refit_state.set_state(RefitFlag.LIVE)

-        # weight_streaming_ctx = self.weight_streaming_ctx if self.enable_weight_streaming else None
+        weight_streaming_ctx = (
+            self.weight_streaming_ctx if self.enable_weight_streaming else None
+        )
         result = self.gm(*args, **kwargs)
         # Storing inputs and outputs for verification when the state is unknown
         self.run_info = (args, kwargs, result)
@@ -605,6 +628,45 @@ def _check_tensor_shapes_with_dynamic_shapes(

         return True

+    def serialize_dynamic_shapes(self) -> None:
+        dims = self.serializable_dynamic_shapes_dims
+
+        def recursively_serialize_dynamic_shape(obj: Any) -> None:
+            if isinstance(obj, dict):
+                for axis, v in obj.items():
+                    if isinstance(v, torch.export.dynamic_shapes._Dim):
+                        name = str(v).split("'")[1].split(".")[-1]
+                        # The string of the hash serves as the unique identifier of the Dim object
+                        dims.setdefault(str(hash(v)), (name, v.min, v.max))
+                        obj[axis] = str(hash(v))
+                    else:
+                        recursively_serialize_dynamic_shape(v)
+            if isinstance(obj, (tuple, list)):
+                for v in obj:
+                    recursively_serialize_dynamic_shape(v)
+
+        recursively_serialize_dynamic_shape(self.arg_dynamic_shapes)
+        recursively_serialize_dynamic_shape(self.kwarg_dynamic_shapes)
+
+    def deserialize_dynamic_shapes(self) -> None:
+        dims = self.serializable_dynamic_shapes_dims
+
+        def recursively_deserialize_dynamic_shape(obj: Any) -> None:
+            if isinstance(obj, dict):
+                for axis, v in obj.items():
+                    if isinstance(v, str):
+                        obj[axis] = torch.export.Dim(
+                            dims[v][0], min=dims[v][1], max=dims[v][2]
+                        )
+                    else:
+                        recursively_deserialize_dynamic_shape(v)
+            if isinstance(obj, (tuple, list)):
+                for v in obj:
+                    recursively_deserialize_dynamic_shape(v)
+
+        recursively_deserialize_dynamic_shape(self.arg_dynamic_shapes)
+        recursively_deserialize_dynamic_shape(self.kwarg_dynamic_shapes)
+
     @staticmethod
     def save(module: Any, path: str) -> None:
         # Cast the object back to MutableTorchTensorRTModule to save
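
`torch.export.Dim` objects are not picklable, so the serializer replaces each one with the string of its hash and records `(name, min, max)` in `serializable_dynamic_shapes_dims`; the deserializer rebuilds equivalent `Dim`s from that table. A standalone sketch of the round trip (names and bounds are illustrative):

```python
import torch

# A dynamic-shapes spec as it might appear in arg_dynamic_shapes:
# axis 1 of input "x" may vary between 1 and 512.
seq = torch.export.Dim("seq", min=1, max=512)
dynamic_shapes = {"x": {1: seq}}

# Serialize: replace the Dim with the string of its hash and record
# (name, min, max) so it can be rebuilt later.
dims = {str(hash(seq)): ("seq", seq.min, seq.max)}
dynamic_shapes["x"][1] = str(hash(seq))  # now picklable

# Deserialize: rebuild an equivalent Dim from the lookup table.
name, lo, hi = dims[dynamic_shapes["x"][1]]
dynamic_shapes["x"][1] = torch.export.Dim(name, min=lo, max=hi)
```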
@@ -616,7 +678,8 @@ def save(module: Any, path: str) -> None:
        exp_program = module.exp_program
        module.pytorch_model = None
        module.exp_program = None
-       torch.save(module, path)
+       module.serialize_dynamic_shapes()
+       torch.save(module, path, pickle_protocol=4)
        # Restore deleted attributes
        module.exp_program = exp_program
        module.pytorch_model = _make_refit_change_trigger(
@@ -650,6 +713,7 @@ def load(path: str) -> Any:
            (cls, module.original_model.__class__),
            {},
        )
+       module.deserialize_dynamic_shapes()
        module.init_finished = True
        return module
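
`serialize_dynamic_shapes()` now runs inside `save()` before pickling, and `deserialize_dynamic_shapes()` runs inside `load()`, so dynamic-shape specs survive the round trip. A usage sketch continuing the examples above, with a hypothetical path:

```python
# save()/load() are staticmethods on the class; the path is illustrative.
torch_tensorrt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reloaded = torch_tensorrt.MutableTorchTensorRTModule.load("mutable_module.pkl")
```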