
Avoid FX Constant Folding in rw_sharding #2777

Closed · wants to merge 1 commit

Conversation

@cp2923 cp2923 commented Mar 4, 2025

Summary:
Tensor_constant nodes are known to have issues with delta update, so we need to ensure the model graph does not contain any tensor_constant nodes. A sketch of how such a node is produced is shown below.

| Constant tensor name | code file | code |
| --- | --- | --- |
| Input_dist._tensor_constant0, ..1, ..2, etc | rw_sharding.py | torch.tensor(block_sizes, ...) |

Differential Revision: D70577218
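
For context, here is a minimal sketch (a hypothetical module, not the actual rw_sharding.py code) of how an inline `torch.tensor(...)` call inside `forward` is folded by FX symbolic tracing into a `_tensor_constant0` `get_attr` node:

```python
import torch
from torch import nn
from torch.fx import symbolic_trace


class InlineConstant(nn.Module):
    """Hypothetical stand-in for the torch.tensor(block_sizes, ...) pattern."""

    def __init__(self, block_sizes):
        super().__init__()
        self._block_sizes = block_sizes  # plain Python list, not a buffer

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # The tensor is built from a Python list during tracing, so FX folds it
        # into the graph as a constant attribute named _tensor_constant0.
        block_sizes = torch.tensor(self._block_sizes)
        return torch.bucketize(ids, block_sizes)


gm = symbolic_trace(InlineConstant([4, 8, 16]))
print([n.target for n in gm.graph.nodes if n.op == "get_attr"])
# ['_tensor_constant0']
```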

@facebook-github-bot added the CLA Signed label Mar 4, 2025
@cp2923 changed the title from "avoid Constant Folding in rw_sharding" to "avoid FX Constant Folding in rw_sharding" Mar 4, 2025
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D70577218

@cp2923 changed the title from "avoid FX Constant Folding in rw_sharding" to "Avoid FX Constant Folding in rw_sharding" Mar 4, 2025
@cp2923 force-pushed the export-D70577218 branch from 58ae69c to 71fe2aa on March 5, 2025 22:49
cp2923 pushed a commit to cp2923/torchrec that referenced this pull request Mar 5, 2025
Summary:

Tensor_constant nodes are known to have issues with delta update: the predictor cannot find the tensors in the delta weights. To make the delta publish work for blue_reels_vdd/v5, we need to ensure the model graph does not contain any tensor_constant nodes.

This is broken down into two diffs; this one contains the file change that requires an OSS pull request.

| Constant tensor name | fx paste | code file | code |
| --- | --- | --- | --- |
| merge._tensor_constant0 | https://www.internalfb.com/intern/everpaste/?color=0&handle=GLJIlhayl6fBeToDACvUvTwjbsMkbr0LAAAz | model_hstu_standalone_cint.py | torch.tensor(self.tasks_with_uid_emb_indices).to(r_over.device) |
| merge.main_module.module.sim_module.module_impl._tensor_constant0 | https://www.internalfb.com/intern/everpaste/?handle=GG71xRydBrkmAncFAPQ_36FF6X8XbsIXAAAB | sim.py | self.action_weights[None, :] |
| Input_dist._tensor_constant0, ..1, ..2, etc | https://www.internalfb.com/intern/everpaste/?color=0&handle=GG7hnhaueOUr4MIFAAViKnOmlYozbr0LAAAz | rw_sharding.py | torch.tensor(block_sizes, ...) |

Differential Revision: D70577218

cp2923 pushed a commit to cp2923/torchrec that referenced this pull request Mar 5, 2025
@cp2923 force-pushed the export-D70577218 branch from 71fe2aa to 936e497 on March 5, 2025 22:54
cp2923 pushed a commit to cp2923/torchrec that referenced this pull request Mar 6, 2025
@cp2923 force-pushed the export-D70577218 branch from 936e497 to 59df76a on March 6, 2025 00:28

cp2923 pushed a commit to cp2923/torchrec that referenced this pull request Mar 6, 2025
@cp2923 force-pushed the export-D70577218 branch from 59df76a to 19e17fd on March 6, 2025 00:34
Summary:

Tensor_constant nodes are known to have issues with delta update: the predictor cannot find the tensors in the delta weights.
The `tensor_cache` touched in this diff introduces a tensor constant in FX.

Solution:
Replace `tensor_cache` with `register_buffer`, and move the buffer to the target device and data type on the first forward pass (so as to keep performance parity). A minimal sketch of this pattern is shown below.
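
A minimal sketch of the register_buffer + first-forward move pattern, using a hypothetical module and names (this is not the actual rw_sharding.py / tensor_cache code):

```python
import torch
from torch import nn


class BlockBucketizer(nn.Module):
    """Hypothetical sketch: keep block_sizes as a registered buffer instead of
    building it with torch.tensor(...) inside forward, which FX would fold
    into a _tensor_constant node."""

    def __init__(self, block_sizes) -> None:
        super().__init__()
        # A buffer is module state, so a traced graph reads it through a
        # get_attr on the named buffer rather than a folded tensor constant.
        self.register_buffer(
            "_block_sizes",
            torch.tensor(block_sizes, dtype=torch.int64),
            persistent=False,
        )
        self._moved = False

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # Move device/dtype once, on the first forward, so steady-state calls
        # pay no per-iteration conversion (performance parity with the cache).
        if not self._moved:
            self._block_sizes = self._block_sizes.to(
                device=ids.device, dtype=ids.dtype
            )
            self._moved = True
        return torch.bucketize(ids, self._block_sizes)
```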

Note that this solution breaks a unit test:
```
assertFalse(
    hasattr(
        local.ro_ec,
        "_root_mc_embedding_collection",
    )
)
```
but it still won't introduce big tables like TBEs.
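
As a sanity check (a sketch, not part of this diff), the traced graph can be scanned for the folded constants that delta update cannot handle:

```python
from typing import List

import torch.fx


def find_tensor_constants(gm: torch.fx.GraphModule) -> List[str]:
    """Return the targets of get_attr nodes that FX created for folded tensor constants."""
    return [
        str(node.target)
        for node in gm.graph.nodes
        if node.op == "get_attr" and "_tensor_constant" in str(node.target)
    ]


# After this change, tracing the sharded module and asserting that this list is
# empty confirms that no tensor_constant nodes remain in the graph.
```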

Reviewed By: jingsh

Differential Revision: D70577218
@cp2923 force-pushed the export-D70577218 branch from 19e17fd to 2f83768 on March 10, 2025 23:23
