feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167
Conversation
```python
    name: str = "",
    settings: CompilationSettings = CompilationSettings(),  # Assumes engine was built with default compilation settings if object not passed
    weight_name_map: Optional[dict[Any, Any]] = None,
    graph_module: torch.fx.GraphModule = None,
```
@narendasan I tried to do refitting for the C++ runtime like for the Python runtime, but it didn't work. Any suggestions? Should I do it in C++ or Python?
Doesn't refit already work on both APIs?
Also, why do we need the graph module in this module?
In this PR I moved the refitting part into TRTModule, so it only works for the Python runtime.

The graph module is used for refitting.
```diff
@@ -619,27 +609,32 @@ def run(
             builder_config, self.compilation_settings.timing_cache_path
         )

-        serialized_engine = self.builder.build_serialized_network(
+        # if strip_engine_weights is true, the serialized engine needs to be refitted before use
+        maybe_unrefitted_serialized_engine = self.builder.build_serialized_network(
```
Why is this a "maybe unrefitted" engine?
Please see the design in the comment below. If compilation_settings.strip_engine_weights is true, the engine needs to be refitted; otherwise it doesn't. Hence the "maybe".
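To make the "maybe" concrete, a hedged sketch of that control flow (names follow the diff above and this reply; it is illustrative, not the exact PR code):

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_maybe_unrefitted_engine(builder, network, builder_config, strip_engine_weights):
    # Illustrative sketch of the flow discussed above, not the PR's code.
    plan = builder.build_serialized_network(network, builder_config)
    if strip_engine_weights:
        # Built with STRIP_PLAN: the plan carries no weights and must be
        # refitted with the real weights before it can run inference.
        runtime = trt.Runtime(TRT_LOGGER)
        engine = runtime.deserialize_cuda_engine(plan)
        refitter = trt.Refitter(engine, TRT_LOGGER)
        # ... supply weights via refitter.set_named_weights(name, weights) ...
        assert refitter.refit_cuda_engine(), "refit failed"
        return engine
    return plan
```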
```python
        ), "weight-stripped engines must be refittable, please set make_refittable=True"

        # Refit the weights
        refitter = trt.Refitter(self.engine, TRT_LOGGER)
```
Can you use this function?
TensorRT/py/torch_tensorrt/dynamo/_refit.py (line 138 in fa02fd3):

```python
def _refit_single_trt_engine_with_gm(
```
The function requires `input_list`, which is not provided in the caller.
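For context, a hedged sketch of what the refit step around `trt.Refitter` looks like (the `weight_name_map`-driven lookup and the helper's signature are assumptions, not the actual `_refit_single_trt_engine_with_gm` implementation):

```python
import tensorrt as trt

def refit_engine_from_state_dict(engine, state_dict, weight_name_map, logger):
    # Hypothetical helper: refit a (weight-stripped) engine from a torch state
    # dict, assuming weight_name_map maps TRT weight names to state-dict keys.
    refitter = trt.Refitter(engine, logger)
    for trt_name in refitter.get_all_weights():
        tensor = state_dict[weight_name_map[trt_name]]
        weights = trt.Weights(tensor.detach().contiguous().cpu().numpy())
        refitter.set_named_weights(trt_name, weights)
    assert refitter.refit_cuda_engine(), "engine refit failed"
```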
```diff
@@ -121,6 +124,52 @@ def setup_engine(self) -> None:
         self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
         self.context = self.engine.create_execution_context()

+        if self.settings.strip_engine_weights:
```
We likely shouldn't be doing the refit in these modules.

I think for weight stripping there are 3 workflows:

1. A user just wants a weight-stripped engine. They should use `convert_exported_program_to_trt_engine` with the setting `strip_weights`. The choice of `make_refittable` can be used to decide between `kREFIT` and `kREFIT_IDENTICAL` (though it might not be entirely clear, so we might want to think about that setting). See the sketch after this list.
2. We want to utilize weight stripping to have a lighter-weight cache. Here this choice is opaque to the user. The user's choice of `make_refittable` controls whether we use `kREFIT` or `kREFIT_IDENTICAL`. But once the engine is loaded, or we pull from cache, we immediately refit (prior to passing the engine to the TRTModule), same as we do today.
3. The user wants a stripped-weights compiled program (I'm not sure why, or if this is a real use case). Here, this is basically the same as lazy engine loading. We would require that users run `refit_engine_weights` before executing.
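As an illustration of workflow 1, a hedged user-side sketch (assuming `model` and `example_input` are defined; the entry point shown is `torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine`, and the kwarg names are assumptions taken from this thread):

```python
import torch
import torch_tensorrt

ep = torch.export.export(model, (example_input,))

# Workflow 1 (sketch): explicitly request a weight-stripped, refittable engine.
# Kwarg names follow this thread and may differ in the final API.
serialized_engine = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    ep,
    inputs=[example_input],
    strip_engine_weights=True,
    make_refittable=True,  # decides kREFIT vs kREFIT_IDENTICAL per the discussion
)
```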
Got it. The original idea/design is in the comment below. I'll move the refitting part back to TRTInterpreter.run().

> The choice of `make_refittable` can be used to decide between `kREFIT` and `kREFIT_IDENTICAL`.

Do you mean we use `make_refittable` to control both `kREFIT` and `kREFIT_IDENTICAL`?
@zewenli98 do you have a design for this feature?
@narendasan Ok, at first the overall design was like this, in TRTInterpreter.run():

```
if compilation_settings.strip_engine_weights is True:
    if engine_cache not hit:
        1. build a weight-stripped engine
        2. save the weight-stripped engine if engine_cache is set
        3. return the weight-stripped engine (not yet refit)
    else:
        load and return the weight-stripped engine (not yet refit)
else:
    if engine_cache not hit:
        1. build a weight-included engine
        2. save the weight-included engine if engine_cache is set
        3. return the weight-included engine (doesn't need to refit)
    else:
        load and return the weight-included engine (not yet refit)
```

Then, in TRTModule, refit if necessary before inference.
@narendasan The design was updated. From the users' perspective, they are able to set

Besides, users can specify

For the 3 workflows mentioned above,

Please see more details in the tests.
I think that we need to separate the runtime and the compiler, so I'm willing to spend the time serializing and deserializing. I think we should frame this PR around moving TRTInterpreter to default to building weight-stripped engines. There will be 3 kinds of engines now.
The first 2 need separate cache entries, so we need to be able to hash on the weights in the case that the model is being built with that flag. We should look to prefer case 1 in the long term, as it allows us to reuse the most work; case 2 would be the next preference. Case 2 should produce faster engines than case 1, so there remains a need to support it.
The case for type 3 engines now is only valid if building a non-refittable engine is faster than building a refit-identical engine and then refitting the weights. If it is not by a significant enough margin, I propose we remove that workflow and just have the other two. So, assuming that we can remove type 3 engines, some of the open questions are:
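On the point above about hashing on the weights: one hedged sketch of folding a weights digest into the engine-cache key (the helper name and structure are assumptions, not the PR's actual cache code):

```python
import hashlib

def engine_cache_key(graph_hash: str, weights: dict, refit_identical: bool) -> str:
    # Hypothetical: a kREFIT_IDENTICAL plan is only valid for bitwise-identical
    # weights, so the weights must participate in the cache key.
    h = hashlib.sha256(graph_hash.encode())
    if refit_identical:
        for name in sorted(weights):
            h.update(name.encode())
            h.update(weights[name].detach().cpu().numpy().tobytes())
    return h.hexdigest()
```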
Are you referring to
My current design is: If users specify
I also thought about it earlier. The TRT doc says "if the refit weights are not identical to the build-time weights, behavior is undefined... This enables use of a single set of weights with different inference backends, or with TensorRT plans for multiple GPU architectures."
Will investigate it.
@narendasan I tested building ResNet18 and VGG16 via the two paths: (1)
@narendasan I just confirmed with the TRT team; the conclusion is
I think we can rename it. On top of this, in summary, the 3 workflows mentioned above would be:
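For illustration, the flag combinations the three workflows would map to under the renamed settings (`immutable_weights`, `refit_identical_engine_weights`, and `strip_engine_weights` are names used in later commits; treat them as assumptions here):

```python
# Hypothetical examples of the three workflows with the renamed settings.

# 1. Fully refittable engine (kREFIT), weights kept in the plan:
settings_1 = dict(immutable_weights=False, refit_identical_engine_weights=False)

# 2. Refittable only with identical weights (kREFIT_IDENTICAL), e.g. for caching:
settings_2 = dict(immutable_weights=False, refit_identical_engine_weights=True)

# 3. Weight-stripped plan that must be refitted before first use:
settings_3 = dict(immutable_weights=False, strip_engine_weights=True)
```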
I think we should remove non-refittable engines then, and we can add them back as a non-default workflow later if there's some reason to.
I still don't know what the use case for this is.
We should think about a solution for this, since the behavior is undefined.
Force-pushed from 04cebc2 to a928f67.
```python
if "make_refittable" in kwargs.keys():
    warnings.warn(
        "`make_refittable` is deprecated. All engines are refittable now. If you want to disable refitting, please open an issue on the Github repo with reasons.",
```
Are we not giving users the option to build non-refittable engines with this change?
We do provide the option, but renamed it to `immutable_weights`.
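For illustration, a hedged sketch of how such a deprecation shim might map the old kwarg onto the new name (`resolve_refit_settings` is a hypothetical helper, not the PR's code; `immutable_weights` is the new name per this thread):

```python
import warnings

def resolve_refit_settings(**kwargs):
    # Hypothetical helper: map the deprecated `make_refittable` kwarg
    # onto the renamed `immutable_weights` setting.
    if "make_refittable" in kwargs:
        warnings.warn(
            "`make_refittable` is deprecated; use `immutable_weights` instead.",
            DeprecationWarning,
        )
        # make_refittable=True meant a refittable engine, i.e. weights are NOT immutable.
        kwargs.setdefault("immutable_weights", not kwargs.pop("make_refittable"))
    return kwargs
```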
```diff
         builder_config.set_flag(trt.BuilderFlag.REFIT)

+        builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)
```
Should this happen only if `strip_engine_weights` is True?
It looks like you were reviewing an outdated commit (not submitted after that review, lol). To clarify the intention here: previously we planned to embed weight stripping into the compilation workflow regardless of `strip_engine_weights`, because it would benefit engine caching; afterwards, we would add the weights back if users want a weighted engine. However, due to the lack of TRT support, as we discussed, that plan was dropped. So for now the latest commit added the if condition.
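A minimal sketch of the conditional flag setup being discussed (assuming a `builder_config` from `create_builder_config()`; the setting names follow later commits in this thread and may differ in the final API):

```python
import tensorrt as trt

def configure_refit_flags(builder_config, settings):
    # Hedged sketch; setting names (immutable_weights, refit_identical_engine_weights,
    # strip_engine_weights) are assumptions from this thread.
    if settings.immutable_weights:
        pass  # no refit flag: weights are baked in and the engine cannot be refitted
    elif settings.refit_identical_engine_weights:
        builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
    else:
        builder_config.set_flag(trt.BuilderFlag.REFIT)

    # Strip weights from the serialized plan only when explicitly requested.
    if settings.strip_engine_weights:
        builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)
```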
```diff
@@ -220,7 +220,7 @@ def _from(
         return dtype.f32
     elif t == np.float64:
         return dtype.f64
-    elif t == np.bool:
+    elif t == np.bool_:
```
Do we need both np.bool_ and np.bool?
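For context on the question: the `np.bool` alias was deprecated in NumPy 1.20 and removed in 1.24 (NumPy 2.0 later reintroduced it as an alias of `np.bool_`), so comparing against `np.bool_` is the portable choice:

```python
import numpy as np

# np.bool_ is the NumPy boolean scalar type and exists across versions.
assert np.dtype(np.bool_) == np.dtype(bool)
# On NumPy 1.24-1.26, referencing np.bool raises AttributeError.
```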
```diff
@@ -111,6 +111,10 @@ def _pretraced_backend(
         logger.warning(
             "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
         )
+        if settings.strip_engine_weights:
+            logger.warning(
```
Should we halt compilation in this case instead of warning?
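A sketch of the stricter alternative being suggested, slotting into `_pretraced_backend` in place of the warning (hypothetical, not the PR's current behavior; the message text is an assumption):

```python
if settings.strip_engine_weights:
    raise ValueError(
        "strip_engine_weights is not supported for torch.compile with "
        "backend='torch_tensorrt'"
    )
```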
```diff
@@ -1049,7 +1058,7 @@ def aten_ops_permute(


 def to_copy_dtype_validator(
-    placeholder_only: bool, settings: CompilationSettings = None
+    placeholder_only: bool, settings: Optional[CompilationSettings] = None
```
Why the change to Optional here?
Lint pointed it out. I think it's because the default is None.
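For context, a parameter whose default is `None` needs `Optional` in its annotation, which strict type checkers enforce (a minimal typing illustration, not the real validator; the class and body are stand-ins):

```python
from typing import Optional

class CompilationSettings:  # stand-in for the real settings class
    pass

# Without Optional, `settings: CompilationSettings = None` is rejected by
# strict mypy, since None is not a CompilationSettings instance.
def to_copy_dtype_validator(
    placeholder_only: bool, settings: Optional[CompilationSettings] = None
) -> bool:
    return settings is not None or placeholder_only
```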
Force-pushed from 4524e94 to 3d68039.
@zewenli98 / @peri044
...
Thanks @keehyuna! Does this error happen with the trt.BuilderFlag.REFIT_IDENTICAL flag?
There is no problem when trt.BuilderFlag.REFIT_IDENTICAL is used. It was tested on the main branch.
Description
Fixes #3146
Type of change
Checklist: