
Update to pytorch 2.6 #8944

Open · wants to merge 1 commit into master

Conversation

@qihqi (Collaborator) commented Apr 5, 2025

PyTorch 2.6 introduced the use of decompositions via the CompositeImplicitAutograd dispatch key: https://github.com/pytorch/pytorch/blob/60a45eb862d5e8b4ba2dd435d34ef04ae231e885/torch/_export/utils.py#L1249

These decompositions currently do not play nicely with torch_dispatch-based decomposition, resulting in infinite recursion.
We are likely using them wrong; I made a post at https://dev-discuss.pytorch.org/t/what-is-the-right-way-to-use-decompositions-in-dispatch-mode/2888 soliciting advice on the right way to use decompositions in dispatch mode.
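
For context, the dispatch works roughly like the following sketch (hypothetical; not the actual torchax code). Ops on the wrapper subclass funnel into `__torch_dispatch__`, and decompositions run on wrapped tensors so that the ops they emit are intercepted too; a CompositeImplicitAutograd decomposition whose body bottoms out in the same op therefore never terminates:

```python
import torch
from torch.utils._pytree import tree_map

DECOMP_TABLE = {}  # OpOverload -> python decomposition, populated elsewhere

class DecompTensor(torch.Tensor):
    """Minimal wrapper subclass illustrating the recursion (hypothetical)."""

    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device)

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        decomp = DECOMP_TABLE.get(func)
        if decomp is not None:
            # The decomposition runs on DecompTensor arguments, so every op
            # it emits re-enters __torch_dispatch__. A decomposition that
            # bottoms out in a call to `func` itself recurses forever.
            return decomp(*args, **kwargs)
        # Otherwise, unwrap, run the real op, and rewrap the results.
        unwrap = lambda t: t.elem if isinstance(t, DecompTensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        wrap = lambda t: DecompTensor(t) if isinstance(t, torch.Tensor) else t
        return tree_map(wrap, out)
```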

In the meantime, the workaround is to explicitly list the decompositions we want, excluding the CompositeImplicitAutograd ones. The explicit list was produced from the keys returned by a core_aten_decompositions() call under torch 2.5.1.
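
Roughly, the workaround amounts to the sketch below (the allow-list is abbreviated to two entries here; the PR spells out the full set generated under torch 2.5.1):

```python
import torch
from torch._decomp import core_aten_decompositions

# Ops whose decompositions we keep. In the PR this set is the full list of
# keys that core_aten_decompositions() returned under torch 2.5.1;
# abbreviated here for illustration.
ALLOWED_OPS = {
    torch.ops.aten.addcdiv.default,
    torch.ops.aten.nll_loss_forward.default,
}

# Filter the current table down to the 2.5.1-era entries, dropping anything
# the 2.6 CompositeImplicitAutograd registrations added.
DECOMP_TABLE = {
    op: fn
    for op, fn in core_aten_decompositions().items()
    if op in ALLOWED_OPS
}
```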

@qihqi requested review from tengyifei and ManfeiBai on April 5, 2025 17:50
```diff
-[tool.pytest.ini_options]
-addopts="-n auto"
+#[tool.pytest.ini_options]
+#addopts="-n auto"
```
Collaborator

Is this an intended change?

```diff
@@ -129,6 +129,7 @@ def run_export_and_compare(testcase,
                            equal_nan=True,
                            ignore_indices=False):
   atol, rtol = (1e-3, 1e-5)
+  #breakpoint()
```
Collaborator

Is this an intended change?

@tengyifei (Collaborator) commented:

I wonder if it's possible to eliminate all autograd-related dispatch keys when calling these torch ops, since we only need decompositions in the forward direction and can rely on JAX autograd, I think.

It may run into other PyTorch problems, though.
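
For what it's worth, the idea might look like this untested sketch; `torch._C._AutoDispatchBelowAutograd` is an internal guard, and whether it is sufficient (or safe) here is an assumption on my part:

```python
import torch

def call_forward_only(op, *args, **kwargs):
    # Exclude the autograd dispatch keys for the duration of the call, so
    # only the forward implementation runs (backward would come from JAX).
    with torch._C._AutoDispatchBelowAutograd():
        return op(*args, **kwargs)

# e.g. call_forward_only(torch.ops.aten.mm.default, a, b)
```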
