
support graph-by-graph benchmarking for PyTorch native checkpointing #1437

Merged · 4 commits merged into main on Nov 19, 2024

Conversation

@kiya00 (Collaborator) commented Nov 14, 2024

Before submitting
  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

The converter replaces the Torch operators in the checkpoint function with Thunder operators in place, and the compiled Thunder/Inductor module is likewise swapped in in place; however, ThunderCompilerGraphBenchmarking and the reproducer-script saving need the original GraphModule to compile and save.

Previously, deepcopying the GraphModule was blocked by pytorch/pytorch#139275. Thanks to @kshitij12345 for helping to fix it, we can now deepcopy the module and support graph-by-graph benchmarking for PyTorch native checkpointing starting from Torch 2.6.
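
A minimal sketch of the resulting flow, assuming hypothetical helper names (replace_ops_in_place and prepare_for_benchmark are illustrative, not the actual Thunder API):

import copy
import torch

def replace_ops_in_place(gm: torch.fx.GraphModule) -> None:
    # Hypothetical stand-in for the converter, which swaps Torch
    # operators for Thunder operators by mutating `gm` in place.
    ...

def prepare_for_benchmark(gm: torch.fx.GraphModule):
    # Keep a pristine copy for ThunderCompilerGraphBenchmarking and for
    # saving the reproducer script; copy.deepcopy on such a GraphModule
    # only works once pytorch/pytorch#139275 is fixed (Torch 2.6+).
    original_gm = copy.deepcopy(gm)
    replace_ops_in_place(gm)  # the live module is converted in place
    return gm, original_gm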

Note:

  • The code is also useful for saving the reproducer script (ThunderFX: Save the reproducer script into files, #1380)
  • The above-mentioned deepcopy bug affects the test case test_thundercompiler_optim_step, so it is skipped
  • Torch 2.6 changes the GraphModule structure: the checkpoint function becomes a submodule of the module containing tag_activation_checkpoint:
# Torch 2.6
GraphModule(
  (submod_1): GraphModule(
    (wrap_body_0): GraphModule()
  )
)
# before 2.6
GraphModule(
  (wrap_body_0): GraphModule()
  (submod_1): GraphModule()
)
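
For reference, a graph of this shape comes from compiling a function that uses PyTorch-native activation checkpointing; a minimal sketch:

import torch
from torch.utils.checkpoint import checkpoint

def fn(x):
    def body(y):
        return torch.sin(torch.cos(y))
    # use_reentrant=False selects the PyTorch-native checkpointing path,
    # which Dynamo traces into a tag_activation_checkpoint submodule
    return checkpoint(body, x, use_reentrant=False)

compiled = torch.compile(fn)
compiled(torch.randn(4, requires_grad=True))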

Before 2.6, to get the input tensors of submod_1 we would have to compute wrap_body_0 ourselves (wrap_body_0 is a placeholder node in the submod_1 module, and there is no example_value in node.meta); in 2.6, wrap_body_0 is a get_attr node in the submod_1 module, not an input. Since the handling in the latest Torch is much cleaner, we don't currently support checkpoint benchmarking on Torch < 2.6.
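
To illustrate the 2.6 layout, the checkpointed body can be fetched directly from the get_attr node (find_checkpoint_body is a hypothetical helper, not code from this PR):

import torch

def find_checkpoint_body(submod: torch.fx.GraphModule):
    # In Torch >= 2.6 the checkpointed function (e.g. wrap_body_0) is a
    # get_attr node inside the submodule, so it can be looked up directly
    # instead of being reconstructed from a placeholder input.
    for node in submod.graph.nodes:
        if node.op == "get_attr" and str(node.target).startswith("wrap_body"):
            return getattr(submod, node.target)
    return None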

Fixes #1381.

@kiya00 (Collaborator, Author) commented Nov 19, 2024

Hi @IvanYashchuk, @kshitij12345, could you review it?

@kshitij12345 (Collaborator) left a comment

LGTM, thanks @kiya00

@t-vi (Collaborator) left a comment

Stamping. Thank you @kiya00 @kshitij12345

@t-vi enabled auto-merge (squash) on November 19, 2024, 11:10
@t-vi merged commit 60f3ee1 into main on Nov 19, 2024 (41 checks passed)
@t-vi deleted the bench_checkpoint branch on November 19, 2024, 12:44
Development

Successfully merging this pull request may close these issues:

ThunderFX: support graph-by-graph benchmarking for PyTorch native checkpointing