Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167

Open
wants to merge 36 commits into
base: main
Choose a base branch
from

Conversation

zewenli98
Copy link
Collaborator

@zewenli98 zewenli98 commented Sep 19, 2024

Description

  1. Supported weight-stripped engine
  2. Added REFIT_IDENTICAL flag

Fixes #3146

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@zewenli98 zewenli98 self-assigned this Sep 19, 2024
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Sep 19, 2024
Comment on lines 79 to 82
name: str = "",
settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed
weight_name_map: Optional[dict[Any, Any]] = None,
graph_module: torch.fx.GraphModule = None,
Copy link
Collaborator Author

@zewenli98 zewenli98 Sep 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@narendasan I tried to do refitting for C++ runtime like for Python runtime but didn't work. Any suggestions? should I do in C++ or Python?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesnt refit already work on both apis?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also why do we need the graph module in this module?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. In this PR I moved the refitting part into TRTModule, so only works for Python runtime.

  2. graph module is used for refitting

@@ -619,27 +609,32 @@ def run(
builder_config, self.compilation_settings.timing_cache_path
)

serialized_engine = self.builder.build_serialized_network(
# if strip_engine_weights is true, the serialized engine need to be refitted before using
maybe_unrefitted_serialized_engine = self.builder.build_serialized_network(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this maybe unrefitted engine?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please see the design in the comment below. If compilation_settings.strip_engine_weights is true, it needs to be refitted, else it doesn't. so it's maybe

), "weight-stripped engines must be refittable, please set make_refittable=True"

# Refit the weights
refitter = trt.Refitter(self.engine, TRT_LOGGER)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use this function?

def _refit_single_trt_engine_with_gm(

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function requires input_list which is not provided in the caller.

@@ -121,6 +124,52 @@ def setup_engine(self) -> None:
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
self.context = self.engine.create_execution_context()

if self.settings.strip_engine_weights:
Copy link
Collaborator

@narendasan narendasan Sep 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We likely shouldnt be doing the refit in these modules

I think for weight stripping there are 3 workflows.

  1. a user just wants a weight stripped engine. They should use convert_exported_program_to_trt_engine with settings strip_weights. The choice of make_refittable can be used to decide between kREFIT and kREFIT_IDENTICAL (though it might not be entirely clear so we might want to think about that setting).
  2. We want to utilize weight stripping to have a lighter weight cache. Here this choice is opaque to the user. The user choice of make_refittable controls if we use kREFIT or kREFIT_IDENTICAL. But once the engine is loaded or we pull from cache we immediately refit (prior to passing the engine to the TRTModule). Same as we do today
  3. The user wants a stripped weights compiled program (im not sure why or if this is a real usecase). Here, this is basically the same as lazy engine loading. We would require that users need to run through refit_engine_weights before executing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. The very beginning idea/design is commented below. I'll move the refitting part back to TRTInterpreter.run()

The choice of make_refittable can be used to decide between kREFIT and kREFIT_IDENTICAL

Do you mean we use make_refittable to control both kREFIT and kREFIT_IDENTICAL?

Copy link
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zewenli98 do you have a design for this feature?

@zewenli98
Copy link
Collaborator Author

@narendasan Ok, at first the overall design was like:

In TRTInterpreter.run():

if compilation_settings.strip_engine_weights is True:
    if engine_cache not hit:
        1. build a weight-stripped engine
        2. save the weight-stripped engine if engine_cache is set
        3. return the weight-stripped engine (not yet refit)
    else:
        load and return the weight-stripped engine (not yet refit)
else:
    if engine_cache not hit:
        1. build a weight-included engine
        2. save the weight-included engine if engine_cache is set
        3. return the weight-included engine (don't need to refit)
    else:
        load and return the weight-included engine (not yet refit)

Then, in TRTModule, refit if necessary before inference.
The reason that I didn't put the refitting part into TRTInterpreter.run() is that I want to avoid repeated de/serializations of TRT engines: (1) deserialize in TRTInterpreter.run() for refitting and then serialize (2) deserialize in TRTModule again.

@narendasan narendasan closed this Sep 20, 2024
@narendasan narendasan reopened this Sep 20, 2024
@zewenli98
Copy link
Collaborator Author

@narendasan The design was updated.

From the users' perspective, they are able to set make_refittable and refit_identical_engine_weights.
make_refittable for general refitting and refit_identical_engine_weights only for refitting with identical weights

if self.compilation_settings.make_refittable:
    if version.parse(trt.__version__) >= version.parse("10.0"):
        if self.compilation_settings.refit_identical_engine_weights:
            builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
        else:
            builder_config.set_flag(trt.BuilderFlag.REFIT)
    else:
        builder_config.set_flag(trt.BuilderFlag.REFIT)

Besides, users can specify strip_engine_weights. If strip_engine_weights is True, TRTInterpreter.run() will return weight-stripped engine. Otherwise, return general engine (with weights).

For the 3 workflows mentioned above,

  1. controlling the args above, users can call convert_exported_program_to_trt_engine specifying strip_engine_weights=True to get weight-stripped engine.

  2. For engine caching, the implementation of weight-stripped engine is opaque to users, which means engine caching mechanism will (1) save weight-stripped engine no matter what settings users specify (make_refittable is required to be true) and then (2) load and refit the weight-stripped engine while reusing cached engines.
    If strip_engine_weights is True, the engine will not be refitted. Instead, just returns weight-stripped engine.

  3. If users specify strip_engine_weights=True, calling torch.compile() or torch_trt.dynamo.compile() will return weight-stripped compiled program. If running the compiled program with inputs, all the results will be zeros. Then, calling refit_module_weights will make weights back, e.g.:

from torch_tensorrt.dynamo._refit import refit_module_weights
refitted_trt_gm = refit_module_weights(trt_gm, exp_program)
refitted_output = refitted_trt_gm(*inputs)

Please see more details in the tests.

@narendasan
Copy link
Collaborator

The reason that I didn't put the refitting part into TRTInterpreter.run() is that I want to avoid repeated de/serializations of TRT engines: (1) deserialize in TRTInterpreter.run() for refitting and then serialize (2) deserialize in TRTModule again.

I think that we need to separate the runtime and the compiler so im willing to spend the time serializing and deserializing.

I think we should frame PR this around moving TRTInterpreter to default to building weight stripped engines. There will be 3 kinds of engines now.

  1. weight strip + refittable (strip_weights + kREFIT) - should move towards this being the default
  2. weight strip + refittable with original weights (strip_weights + kREFIT_INDIVIDUAL)
  3. non_refittable

The first 2 need separate cache entries. So we need to be able to hash on the weights in the case that the model is being built with kREFIT_INDIVIDUAL

We should look to prefer case 1 in the long term as it allows us to reuse the most work, case 2 would be the next preference. Case 2 should produce faster engines than Case 1 so there remains a need to support kREFIT_IDENTICAL

Do you mean we use make_refittable to control both kREFIT and kREFIT_IDENTICAL?

The case for type 3 engines now is only valid if building a non refittable engine is faster than building a refit_identical engine then refitting the weights. If it is not by a significant enough margin I propose we remove that workflow and just have refit or refit_identical engines.

So assuming that we can remove type 3 engines, make_refittable really means "allows the weights to be changed" (we can change the name if needed here), since now both engines are refittable they just have different weight constraints.

@narendasan
Copy link
Collaborator

Some of the open questions are:

  • how we determine if the weights have been refit prior to running the engine. Can TRT tell us without an error?
  • How can we tell if a user is trying to refit an engine with different weights to an engine built with REFIT_IDENTICAL?
  • If building strip weights refit identical + refit is slower than just building?

@zewenli98
Copy link
Collaborator Author

The reason that I didn't put the refitting part into TRTInterpreter.run() is that I want to avoid repeated de/serializations of TRT engines: (1) deserialize in TRTInterpreter.run() for refitting and then serialize (2) deserialize in TRTModule again.

I think that we need to separate the runtime and the compiler so im willing to spend the time serializing and deserializing.

I think we should frame PR this around moving TRTInterpreter to default to building weight stripped engines. There will be 3 kinds of engines now.

  1. weight strip + refittable (strip_weights + kREFIT) - should move towards this being the default
  2. weight strip + refittable with original weights (strip_weights + kREFIT_INDIVIDUAL)
  3. non_refittable

The first 2 need separate cache entries. So we need to be able to hash on the weights in the case that the model is being built with kREFIT_INDIVIDUAL

We should look to prefer case 1 in the long term as it allows us to reuse the most work, case 2 would be the next preference. Case 2 should produce faster engines than Case 1 so there remains a need to support kREFIT_IDENTICAL

Are you referring to kREFIT_IDENTICAL or kREFIT_INDIVIDUAL? The updated design only considered kREFIT_IDENTICAL. kREFIT_INDIVIDUAL is for fine-grained control which is not yet to be considered.

@zewenli98
Copy link
Collaborator Author

zewenli98 commented Sep 20, 2024

  • how we determine if the weights have been refit prior to running the engine. Can TRT tell us without an error?

My current design is: If users specify strip_engine_weights=True in compile, the weights will not be refitted. They will get a weight-stripped engine.
However, if they get an engine somewhere, they can call get_missing_weights() to see if there's any weight not gets refitted.

  • How can we tell if a user is trying to refit an engine with different weights to an engine built with REFIT_IDENTICAL?

I also thought about it earlier. The TRT doc says "if the refit weights are not identical to the build-time weights, behavior is undefined... This enables use of a single set of weights with different inference backends, or with TensorRT plans for multiple GPU architectures."
My understanding is that we cannot tell if weights are identical in build time and refitting, from the perspective of engine itself, because weight-stripped engine doesn't compare weights in build time and refitting phase, or give any prompts. So users need to be clear what they are refitting.

  • If building strip weights refit identical + refit is slower than just building?

will investigate on it.

@zewenli98
Copy link
Collaborator Author

zewenli98 commented Sep 23, 2024

The case for type 3 engines now is only valid if building a non refittable engine is faster than building a refit_identical engine then refitting the weights. If it is not by a significant enough margin I propose we remove that workflow and just have refit or refit_identical engines.

@narendasan I tested on building Resnet18 and vgg16 via the two paths: (1) strip weights + refit_identical + refit (2) non-refittable, build time of the two ways are almost same (diff < 1%), and engine sizes are also almost same (diff < 0.1%). I'm not sure if there are other benefits from non-refittable engines even though the build time, engine size, and performance are the same, like in deployment weights are not allowed to be changed in terms of safety?

@zewenli98
Copy link
Collaborator Author

@narendasan I just confirmed with TRT team, the conclusion is engine built with STRIP_PLAN + REFIT_IDENTICAL + refit is almost same as non-refittable engine. Do you prefer to remove non-refittable engine path?
If yes, the paths would be:

  1. weight strip + refittable (strip_weights + kREFIT) - default
  2. weight strip + refittable with original weights (strip_weights + kREFIT_IDENTICAL)

So assuming that we can remove type 3 engines, make_refittable really means "allows the weights to be changed" (we can change the name if needed here), since now both engines are refittable they just have different weight constraints.

I think we can rename make_refittable to refit_mode: str: Union["general", "identical"] (may be easier to extend in the future?) or refit_identical_weights: bool. Then, we can remove refit_identical_engine_weights arg which has been committed in this PR.

On top of this, STRIP_PLAN will be always on while building engines. we have strip_engine_weights arg to allow users to control if they want to get weight-stripped engines.

In summary, the 3 workflows mentioned above would be:

  1. Users just want a weight stripped engine. They can call convert_exported_program_to_trt_engine specifying strip_engine_weights=True to get weight-stripped engine. It is also supported if the engine is loaded from engine cache.

  2. We want to utilize weight stripping to have a lighter weight engine cache. The implementation of weight-stripped engine is opaque to users. However, if users specify kREFIT or kREFIT_IDENTICAL, they would be considered as different engine and cached twice.

  3. Users want a stripped weights compiled program. They just need to call torch.compile() or torch_trt.dynamo.compile() with strip_engine_weights=True. If running the compiled program with inputs immediately, all the results will always be zeros. Calling refit_module_weights() will make weights back

@narendasan
Copy link
Collaborator

I think we should remove non-refittable then and we can add it back as a non default workflow later if theres some reason to.

Users want a stripped weights compiled program. They just need to call torch.compile() or torch_trt.dynamo.compile() with strip_engine_weights=True. If running the compiled program with inputs immediately, all the results will always be zeros. Calling refit_module_weights() will make weights back

I still dont know what the usecase for this is

@narendasan
Copy link
Collaborator

How can we tell if a user is trying to refit an engine with different weights to an engine built with REFIT_IDENTICAL?

I also thought about it earlier. The TRT doc says "if the refit weights are not identical to the build-time weights, behavior is undefined... This enables use of a single set of weights with different inference backends, or with TensorRT plans for multiple GPU architectures."
My understanding is that we cannot tell if weights are identical in build time and refitting, from the perspective of engine itself, because weight-stripped engine doesn't compare weights in build time and refitting phase, or give any prompts. So users need to be clear what they are refitting.

We should think about a solution for this since behavior is undefined


if "make_refittable" in kwargs.keys():
warnings.warn(
"`make_refittable` is deprecated. All engines are refittable now. If you want to disable refitting, please open an issue on the Github repo with reasons.",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we not giving users the option to build non-refittable engines with this change ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do provide the option but rename to immutable_weights.

builder_config.set_flag(trt.BuilderFlag.REFIT)

builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this happen only if strip_engine_weights is True ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like you were reviewing an outdated commit (not submit after that review lol). To clarify the intention here, previously we planned to embed weight stripping into the compilation workflow whatever strip_engine_weights is because it would benefit engine caching. Afterwards, we will add weights back if users want a weighted engine. However, due to the lack of TRT support, as we discussed, it was deprecated. So for now the latest commit added the if condition.

@@ -220,7 +220,7 @@ def _from(
return dtype.f32
elif t == np.float64:
return dtype.f64
elif t == np.bool:
elif t == np.bool_:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need both np.bool_ and np.bool ?

@@ -111,6 +111,10 @@ def _pretraced_backend(
logger.warning(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
)
if settings.strip_engine_weights:
logger.warning(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we halt compilation in this case instead of warning ?

@@ -1049,7 +1058,7 @@ def aten_ops_permute(


def to_copy_dtype_validator(
placeholder_only: bool, settings: CompilationSettings = None
placeholder_only: bool, settings: Optional[CompilationSettings] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the change to optional here ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lint pointed it out. I think it's because the default is None

@github-actions github-actions bot added the component: converters Issues re: Specific op converters label Nov 20, 2024
@github-actions github-actions bot removed the component: converters Issues re: Specific op converters label Nov 20, 2024
@keehyuna
Copy link
Collaborator

@zewenli98 / @peri044
I could reproduce "misaligned address cuda error" with weight streaming test.
This problem started when refit build flag is add as default. When I remove trt.BuilderFlag.REFIT, test passed.
I will check further.

def _populate_trt_builder_config(
    self,
    strict_type_constraints: bool = False,
    algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
    tactic_sources: Optional[int] = None,
) -> trt.IBuilderConfig:

...
if self.compilation_settings.immutable_weights:
# non-refittable engine
if self.compilation_settings.strip_engine_weights:
_LOGGER.warning("strip_engine_weights will be ignored.")
else:
# refittable engine
if self.compilation_settings.refit_identical_engine_weights:
builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
else:
builder_config.set_flag(trt.BuilderFlag.REFIT) <----- refit flag is set as default

@zewenli98
Copy link
Collaborator Author

Thanks @keehyuna! Does this error happen with trt.BuilderFlag.REFIT_IDENTICAL flag?

@keehyuna
Copy link
Collaborator

Thanks @keehyuna! Does this error happen with trt.BuilderFlag.REFIT_IDENTICAL flag?

This is no problem when trt.BuilderFlag.REFIT_IDENTICAL is used. It was tested on main branch.

@github-actions github-actions bot added the component: converters Issues re: Specific op converters label Nov 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: runtime component: tests Issues re: Tests component: torch_compile
Projects
None yet
Development

Successfully merging this pull request may close these issues.

✨[Feature] Weight specific engine caching
6 participants