
torch.fx in torch==2.5.1 seems to have a limitation that ThunderFX cannot work around #1457

Closed
crcrpar opened this issue Nov 20, 2024 · 4 comments

@crcrpar
Collaborator

crcrpar commented Nov 20, 2024

Observed in #1456.
This may fall into the "won't fix" category because the nightly build looks fine.

🐛 Bug

To Reproduce

  1. Start pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.5.2-py3.10-pt_2.5.1-dev
  2. Clone crpa/inplace-test-real-adamw branch, more specifically, f94892c.
  3. Run thunder/tests/test_inplace_functionalization.py::test_adamw_with_pythia14m_torchcompile

Error:

______ test_adamw_with_pythia14m_torchcompile_cuda_thunder.dtypes.float32 ______
[gw5] linux -- Python 3.10.12 /usr/bin/python3.10

self = <torch._dynamo.output_graph.OutputGraph object at 0x7f358c1d9330>
gm = GraphModule()

    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
        assert self.compiler_fn is not None
        tot = 0
        placeholders = []
        for node in gm.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                tot += 1
            if node.op == "placeholder":
                placeholders.append(node)
        increment_op_count(tot)
        for pl in placeholders:
            arg = pl.meta["grapharg"]
            # TODO: Why isn't this stored in meta :think:
            pl._dynamo_source = arg.source
    
        gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
        gm._source_to_user_stacks = self.source_to_user_stacks  # type: ignore[assignment]
    
        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiler_fn = self.compiler_fn
            if config.verify_correctness:
                compiler_fn = WrapperBackend(compiler_fn)
>           compiled_fn = compiler_fn(gm, self.example_inputs())

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py:1446: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/repro/after_dynamo.py:129: in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
/usr/local/lib/python3.10/dist-packages/torch/__init__.py:2279: in __call__

Expected behavior

Environment

  • PyTorch Version (e.g., 1.0): 2.5.1+cu121
  • OS (e.g., Linux): Ubuntu 22.04
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source): N/A
  • Python version: 3.10.12
  • CUDA/cuDNN version: N/A
  • GPU models and configuration: H100
  • Any other relevant information: N/A

Additional context

torch.fx.Graph.__deepcopy__ doesn't seem to have been updated for a while, according to https://github.com/pytorch/pytorch/blob/main/torch/fx/graph.py#L1090-L1108.

@IvanYashchuk
Collaborator

If GraphModule.__deepcopy__ misses something, we can of course patch it temporarily in Thunder. But, as you say, the nightly build is working, so we should wait for the next stable release.
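The kind of temporary patch mentioned here would replace a stale `__deepcopy__` with one that copies every attribute. A minimal stdlib-only sketch with a toy class (pure illustration; `ToyGraph` and `owning_hook` are hypothetical stand-ins, not torch.fx internals):

```python
import copy

class ToyGraph:
    """Toy stand-in for a class whose __deepcopy__ predates a newer field."""
    def __init__(self):
        self.nodes = [1, 2, 3]
        self.owning_hook = "hook"  # hypothetical attribute added later

    def __deepcopy__(self, memo):
        new = ToyGraph.__new__(ToyGraph)
        new.nodes = copy.deepcopy(self.nodes, memo)
        # Bug: owning_hook is never copied, so the copy lacks it.
        return new

g = ToyGraph()
g2 = copy.deepcopy(g)
print(hasattr(g2, "owning_hook"))  # False: the stale __deepcopy__ dropped it

# Temporary monkeypatch: copy the entire instance dict instead.
def fixed_deepcopy(self, memo):
    new = ToyGraph.__new__(ToyGraph)
    new.__dict__ = copy.deepcopy(self.__dict__, memo)
    return new

ToyGraph.__deepcopy__ = fixed_deepcopy
g3 = copy.deepcopy(g)
print(hasattr(g3, "owning_hook"))  # True: all attributes survive the copy
```

Monkeypatching the dunder on the class (not the instance) is required because `copy.deepcopy` looks the method up through the type.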

@kshitij12345, could this problem be the same as pytorch/pytorch#138207?

@kshitij12345
Collaborator

Yes, it looks to be the same as pytorch/pytorch#138207. The fix will be available in the next release.

@kshitij12345
Collaborator

Also, we currently have the skip below for the ThunderFX optim test, which fails for the same reason:

    pytest.mark.skipif(
        LooseVersion(torch.__version__) < LooseVersion("2.6.0"),
        reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275",
    )
    ...
    @requiresCUDA
    def test_thundercompiler_optim_step(executor, device, dtype, optim):
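The guard above skips the test whenever the installed torch predates the release carrying the fix. The same comparison can be sketched with a stdlib-only helper (hypothetical names, not Thunder's actual code) that also strips local suffixes like `+cu121`:

```python
def version_tuple(v: str) -> tuple:
    """Keep only the leading numeric components, e.g. '2.5.1+cu121' -> (2, 5, 1)."""
    core = v.split("+")[0]  # drop the local-version suffix
    parts = []
    for p in core.split("."):
        if p.isdigit():
            parts.append(int(p))
        else:
            break  # stop at non-numeric parts such as 'dev20241120'
    return tuple(parts)

def needs_skip(torch_version: str, fixed_in: str = "2.6.0") -> bool:
    """True when the installed torch predates the release with the fix."""
    return version_tuple(torch_version) < version_tuple(fixed_in)

print(needs_skip("2.5.1+cu121"))          # True: the affected stable release
print(needs_skip("2.6.0"))                # False: fixed release
print(needs_skip("2.6.0.dev20241120"))    # False: nightly, matching the report
```

Tuple comparison of integer components avoids the lexicographic pitfalls of comparing version strings directly (e.g. "2.10" < "2.6").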

@crcrpar
Collaborator Author

crcrpar commented Nov 20, 2024

Ah, I forgot about the test case. Thank you both.

@crcrpar crcrpar closed this as completed Nov 20, 2024