
[PT2E][X86] Migrate fusion passes in Inductor to torchao #2140


Open
wants to merge 14 commits into
base: main

Conversation


@Xia-Weiwen Xia-Weiwen commented Apr 28, 2025

Summary

In this PR, we migrate the fusion passes for quantized ops of the X86Inductor backend from the PyTorch Inductor source code to torchao. This is the first step in migrating quantization-related fusion passes from PyTorch core to torchao.
Once this PR lands, fusion passes for new ops can be added in torchao instead of in PyTorch core, so we would like it merged early.
We plan to do the migration in the following steps:

  1. Copy fusion passes from PyTorch core to Torchao (this PR)
  2. Deprecate and remove the fusion passes in PyTorch core (TODO)
  3. Switch to new quantize/dequantize ops in Torchao (TODO)

(Steps 2 and 3 do not depend on each other and can be reordered.)

Fusion passes need to be registered with Inductor before torch.compile is called, and it would be less user-friendly to ask users to register them in their own code. So, we decided to put the registration inside the lowering function. In other words, this PR wraps the registration in the API lower_pt2e_quantized_to_x86: users call lower_pt2e_quantized_to_x86 instead of torch.compile to get the lowered model compiled with torch.compile. For eager mode, users use the same API. The API is designed as below:

def lower_pt2e_quantized_to_x86(
    model: torch.fx.GraphModule,
    example_inputs: Optional[tuple[torch.Tensor, ...]] = None,
    compile: bool = True,
    **compile_options: Optional[dict],
) -> torch.fx.GraphModule

The compile flag controls whether torch.compile is used. For eager mode, users simply set compile=False.
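
For illustration, here is a minimal usage sketch under the design above. The PT2E prepare/convert steps follow the standard torch.ao flow; the import path of lower_pt2e_quantized_to_x86 is an assumption for this example.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

# Import path of the new API is an assumption for illustration.
from torchao.quantization.pt2e import lower_pt2e_quantized_to_x86

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(2, 16),)

# Standard PT2E flow: export, prepare, calibrate, convert.
exported = torch.export.export_for_training(model, example_inputs).module()
quantizer = X86InductorQuantizer().set_global(
    get_default_x86_inductor_quantization_config()
)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration
converted = convert_pt2e(prepared)

# Compiled path: the fusion passes are registered inside the API, then torch.compile runs.
compiled = lower_pt2e_quantized_to_x86(converted, example_inputs, compile=True)

# Eager path: the same API, with compilation disabled.
eager = lower_pt2e_quantized_to_x86(converted, compile=False)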

Test plan

We copied the related unit tests from https://github.com/pytorch/pytorch/blob/main/test/inductor/test_mkldnn_pattern_matcher.py
The test cases run only with torch nightly since some required features, such as onednn.qconv_pointwise, are only available in nightly builds.
Use the following command to run the tests:

pytest test/quantization/pt2e/test_x86inductor_fusion.py

Explanation of implementation

  • In this PR, we mostly copy the code from torch Inductor (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/quantization.py) and import Inductor's internal functions, methods, and utilities directly. We think this is the simplest way to register the fusion passes.
  • For now, the fusion passes in torch Inductor co-exist with the passes registered in torchao. This is not an issue because duplicate passes are not applied twice: once a pattern has been fused, it no longer exists in the graph and cannot be matched again (see the sketch after this list).
  • In the future, we will switch to the new quantize/dequantize ops in torchao when they are ready. At that point, the patterns registered in torchao will differ from those in Inductor, and the passes in torch Inductor will be deprecated and eventually removed.
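
To make the duplicate-registration point from the second bullet concrete, here is a toy, library-free sketch (not the actual Inductor pattern matcher): once a pattern has been rewritten, a second application of the same fusion finds nothing to match and leaves the graph unchanged.

# Toy stand-in for a fusion pass: rewrite [dequantize, conv, quantize] -> [qconv].
def fuse_qconv(ops):
    pattern = ["dequantize", "conv", "quantize"]
    fused, i = [], 0
    while i < len(ops):
        if ops[i : i + len(pattern)] == pattern:
            fused.append("qconv")      # pattern matched, replaced by the fused op
            i += len(pattern)
        else:
            fused.append(ops[i])
            i += 1
    return fused

graph = ["dequantize", "conv", "quantize", "relu"]
once = fuse_qconv(graph)    # ['qconv', 'relu']
twice = fuse_qconv(once)    # ['qconv', 'relu'] -- nothing left to match
assert once == twice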


pytorch-bot bot commented Apr 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2140


❌ 1 New Failure

As of commit 8e4532f with merge base 137b079:

NEW FAILURE - The following job has failed:


@facebook-github-bot facebook-github-bot added the CLA Signed label Apr 28, 2025
@Xia-Weiwen Xia-Weiwen added the topic: new feature label and removed the CLA Signed label Apr 28, 2025
@facebook-github-bot facebook-github-bot added the CLA Signed label Apr 28, 2025
@Xia-Weiwen Xia-Weiwen marked this pull request as ready for review April 29, 2025 01:46
@Xia-Weiwen
Collaborator Author

Hi @jerryzh168 @jansel, could you please review this PR? We would especially like your comments on (1) whether it is acceptable that we copy Inductor code here in torchao and use Inductor's internal utilities, and (2) whether it is OK that we keep duplicate passes for now. Thanks!

@jansel

jansel commented Apr 29, 2025

I think out of tree passes are fine.

Do we need a better registration system so the changes can be local to a specific torch.compile() call rather than mutating globals?

cc @eellison

Contributor

@jerryzh168 jerryzh168 Apr 29, 2025


should this be in prototype? I think under torchao/quantization/pt2e might be better?

also the folder name can probably be something like inductor_passes to be more specific

I'd recommend: torchao/quantization/pt2e/inductor_passes/x86.py

Collaborator Author


Thanks. I have moved it as you suggested.

jerryzh168 approved these changes Apr 29, 2025
@jerryzh168
Contributor

I also feel that hiding the compile API inside lower_pt2e_quantized_to_x86 is not a good idea, and the compile stack should allow registering fusion passes out of tree.

Comment on lines 40 to 52
global FUSION_PATH_REGISTERED
if not FUSION_PATH_REGISTERED:
global torch
import torch._inductor.config

from torchao.prototype.inductor.fx_passes.quantization import (
_register_quantization_weight_pack_pass,
quant_lift_up,
)

torch._inductor.config.pre_grad_custom_pass = quant_lift_up
_register_quantization_weight_pack_pass()
FUSION_PATH_REGISTERED = True
Contributor


can this part happen during import of x86_inductor_quantizer?

Collaborator Author

@Xia-Weiwen Xia-Weiwen Apr 30, 2025


Thanks. I have modified it per your suggestion.
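
For reference, a minimal sketch of what import-time registration could look like (the guard name and module placement are illustrative; the pass names come from the snippet above, which was later moved out of the prototype folder):

# Hypothetical module-level registration, e.g. run when the x86 quantizer module is imported.
_FUSION_PASSES_REGISTERED = False

def _register_x86_fusion_passes() -> None:
    """Register the X86 Inductor fusion passes exactly once."""
    global _FUSION_PASSES_REGISTERED
    if _FUSION_PASSES_REGISTERED:
        return
    import torch._inductor.config

    from torchao.prototype.inductor.fx_passes.quantization import (
        _register_quantization_weight_pack_pass,
        quant_lift_up,
    )

    # quant_lift_up runs as the pre-grad custom pass; the weight-pack fusions
    # are registered with Inductor's pattern matcher.
    torch._inductor.config.pre_grad_custom_pass = quant_lift_up
    _register_quantization_weight_pack_pass()
    _FUSION_PASSES_REGISTERED = True

# Executed at import time so users only need to import the quantizer module.
_register_x86_fusion_passes()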

@jerryzh168 jerryzh168 self-requested a review April 29, 2025 19:41
@Xia-Weiwen
Collaborator Author

I think out of tree passes are fine.

Do we need a better registration system so the changes can be local to a specific torch.compile() call rather than mutating globals?

cc @eellison

Thanks for your comments. We will just keep the current implementation then.
As for a new registration system, maybe we can have something similar to pre_grad_custom_pass?

@jansel

jansel commented Apr 30, 2025

Yeah, I think this might be cleaner with something like pre_grad_custom_pass instead of global registration.

Collaborator

@leslie-fang-intel leslie-fang-intel left a comment


Hi @Xia-Weiwen, will you add the registration system in PyTorch first and then refine this PR?

@Xia-Weiwen
Collaborator Author

Hi @Xia-Weiwen, will you add the registration system in PyTorch first and then refine this PR?

No. I plan to keep the current implementation. When the new registration system is added to Inductor by the Meta Inductor team, I will switch to it in another PR.

)

torch._inductor.config.pre_grad_custom_pass = quant_lift_up
_register_quantization_weight_pack_pass()
Collaborator

@leslie-fang-intel leslie-fang-intel Apr 30, 2025


I'm a bit concerned about this. I'm not sure how we should handle it, but it seems that

  • The patterns from quantization.py in TorchAO will be registered here once
  • And inside torch.compile, when freezing is turned on, the same patterns from pytorch/torch/_inductor/fx_passes/quantization.py inside Torch Inductor will be registered again.

Collaborator Author


Thanks for the comments. As we discussed offline and as explained in the summary above, duplicate passes are effectively applied only once: after the first application, the pattern is gone from the graph and cannot be matched again.

quant_lift_up,
)

torch._inductor.config.pre_grad_custom_pass = quant_lift_up
Collaborator

@leslie-fang-intel leslie-fang-intel Apr 30, 2025


Be careful to check whether any other pre_grad_custom_pass has been registered before; see issue pytorch/pytorch#151876. cc @Valentine233, who is working on it.

Collaborator Author

@Xia-Weiwen Xia-Weiwen Apr 30, 2025


Thanks for the comment. It is potentially unsafe. I will modify this part after the pre_grad_custom_pass is refactored, probably in another PR.
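
A hedged sketch of one way to avoid clobbering a previously registered pass, along the lines of the concern above (the chaining helper is illustrative, not an existing Inductor API):

import torch._inductor.config as inductor_config

def _chain_pre_grad_pass(new_pass):
    """Compose the new pre-grad pass with any pass already registered,
    instead of silently overwriting it."""
    existing = inductor_config.pre_grad_custom_pass
    if existing is None:
        inductor_config.pre_grad_custom_pass = new_pass
        return

    def chained(graph):
        existing(graph)
        new_pass(graph)

    inductor_config.pre_grad_custom_pass = chained

# e.g. _chain_pre_grad_pass(quant_lift_up) instead of assigning it directly.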

@Xia-Weiwen
Collaborator Author

I also feel that hiding the compile API inside lower_pt2e_quantized_to_x86 is not a good idea, and the compile stack should allow registering fusion passes out of tree.

Thanks for the comments. I have moved the registration out of the lowering function. Please review again. Thanks.
