[PT2E][X86] Migrate fusion passes in Inductor to torchao #2140
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2140
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 8e4532f with merge base 137b079. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @jerryzh168 @jansel Could you please review this PR? We would like to hear your comments, especially on (1) whether it sounds OK to you that we copy Inductor code into Torchao together with Inductor's internal utilities, and (2) whether it is OK to keep duplicate passes for now. Thanks!
I think out-of-tree passes are fine. Do we need a better registration system so the changes can be local to a specific torch.compile() call rather than mutating globals? cc @eellison
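For illustration only, here is a minimal sketch of what such per-call scoping could look like (hypothetical, not part of this PR; it assumes torch._inductor.config.patch can temporarily override pre_grad_custom_pass, and scoped_quant_fusion is an illustrative name):

import contextlib

import torch
import torch._inductor.config as inductor_config

from torchao.prototype.inductor.fx_passes.quantization import quant_lift_up


@contextlib.contextmanager
def scoped_quant_fusion():
    # Install the pre-grad pass only for the duration of this context,
    # instead of mutating the global Inductor config for the whole process.
    with inductor_config.patch(pre_grad_custom_pass=quant_lift_up):
        yield


# Hypothetical usage: the first call triggers compilation, so it must run
# inside the scope for the pass to take effect.
# with scoped_quant_fusion():
#     compiled = torch.compile(model)
#     compiled(*example_inputs)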
Should this be in prototype? I think under torchao/quantization/pt2e might be better. Also, the folder name can probably be something like inductor_passes to be more specific.
I'd recommend: torchao/quantization/pt2e/inductor_passes/x86.py
Thanks. I have moved it as you suggested.
I also feel hiding the compile API in
global FUSION_PATH_REGISTERED
if not FUSION_PATH_REGISTERED:
    global torch
    import torch._inductor.config

    from torchao.prototype.inductor.fx_passes.quantization import (
        _register_quantization_weight_pack_pass,
        quant_lift_up,
    )

    torch._inductor.config.pre_grad_custom_pass = quant_lift_up
    _register_quantization_weight_pack_pass()
    FUSION_PATH_REGISTERED = True
Can this part happen during import of x86_inductor_quantizer?
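For reference, a minimal sketch of what import-time registration could look like (hypothetical; it reuses the names from the diff above, while _FUSION_PASS_REGISTERED and _register_x86_fusion_passes are illustrative names):

# Hypothetical module-level registration, e.g. in x86_inductor_quantizer.py,
# so the passes are installed once as a side effect of importing the quantizer.
import torch._inductor.config

from torchao.prototype.inductor.fx_passes.quantization import (
    _register_quantization_weight_pack_pass,
    quant_lift_up,
)

_FUSION_PASS_REGISTERED = False


def _register_x86_fusion_passes():
    global _FUSION_PASS_REGISTERED
    if not _FUSION_PASS_REGISTERED:
        torch._inductor.config.pre_grad_custom_pass = quant_lift_up
        _register_quantization_weight_pack_pass()
        _FUSION_PASS_REGISTERED = True


_register_x86_fusion_passes()  # runs at import time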
Thanks. I have modified it per your suggestion.
Thanks for your comments. We will just keep the current implementation then.
Yeah, I think this might be cleaner with something like
Hi @Xia-Weiwen, will you add the registration system in PyTorch first and then refine this PR?
No. I plan to keep the current implementation. When a new registration system is added to Inductor by the Meta Inductor team, I will switch to it in another PR.
)

torch._inductor.config.pre_grad_custom_pass = quant_lift_up
_register_quantization_weight_pack_pass()
I'm a bit concerned about this. Not sure how we should handle it, but it seems that:
- The patterns from quantization.py in TorchAO will be registered here once.
- Inside torch.compile, when freezing is turned on, the same patterns from pytorch/torch/_inductor/fx_passes/quantization.py inside Torch Inductor will be registered again.
Thanks for the comments. As we discussed offline and as I explained in the summary above, duplicate passes will be applied only once, because once a pattern has been applied, it is gone.
    quant_lift_up,
)

torch._inductor.config.pre_grad_custom_pass = quant_lift_up
Be careful to check whether any other pre_grad_custom_pass has been registered before; see pytorch/pytorch#151876. cc @Valentine233, who is working on it.
Thanks for the comment. It is potentially unsafe. I will modify this part after pre_grad_custom_pass is refactored, probably in another PR.
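For reference, one possible way to avoid clobbering a previously registered pass would be to chain it (a hypothetical sketch, not what this PR implements; it assumes pre_grad_custom_pass is a callable that takes the pre-grad FX graph, and _install_quant_lift_up is an illustrative name):

import torch._inductor.config

from torchao.prototype.inductor.fx_passes.quantization import quant_lift_up


def _install_quant_lift_up():
    existing = torch._inductor.config.pre_grad_custom_pass

    if existing is None:
        # Nothing registered yet; install quant_lift_up directly.
        torch._inductor.config.pre_grad_custom_pass = quant_lift_up
        return

    def _chained_pre_grad_pass(graph):
        # Run the previously registered pass first, then quant_lift_up.
        existing(graph)
        quant_lift_up(graph)

    torch._inductor.config.pre_grad_custom_pass = _chained_pre_grad_pass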
Thanks for the comments. I have moved the registration out of the lowering function. Please review again. Thanks.
Summary
In this PR, we migrate the fusion passes of quantized ops for the X86Inductor backend from the PyTorch Inductor source code to Torchao. This is the first step in migrating quantization-related fusion passes from PyTorch core to Torchao.
With this PR landed, we can add fusion passes for new ops in Torchao instead of in PyTorch core, so we want this PR merged early.
We plan to do the migration in the following steps:
(Steps 2 and 3 have no dependency on each other and can be reordered.)
Fusion passes need to be registered with Inductor before calling torch.compile, and it would be less user-friendly to ask users to register them in their own code. So, we decided to put the registration inside the lowering function. In other words, this PR wraps the registration in the API lower_pt2e_quantized_to_x86. Users need to call lower_pt2e_quantized_to_x86 instead of torch.compile to get a model lowered with torch.compile. For eager mode, users use the same API. The API is designed as below: the compile flag indicates whether to use torch.compile or not (eager mode). For eager mode, users just set compile=False.
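A hypothetical usage sketch based on the description above (the import path and the exact signature of lower_pt2e_quantized_to_x86 are assumptions and may differ from the actual PR):

# `converted_model` is assumed to be a PT2E-quantized model produced by
# convert_pt2e with the X86 Inductor quantizer.
from torchao.quantization.pt2e import lower_pt2e_quantized_to_x86  # assumed import path

# Lower and compile with torch.compile:
lowered = lower_pt2e_quantized_to_x86(converted_model, compile=True)

# Or keep the quantized model in eager mode:
lowered_eager = lower_pt2e_quantized_to_x86(converted_model, compile=False)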
Test plan
We copied related UTs from https://github.com/pytorch/pytorch/blob/main/test/inductor/test_mkldnn_pattern_matcher.py
The test cases are run only with torch nightly since some torch features, such as onednn.qconv_pointwise, are only available in nightly.
Use the following command to run the tests:
Explanation of implementation