Skip to content

Commit bcfac7b

Browse files
che-shfacebook-github-bot
authored andcommitted
Refactor get_node_args and friends into a class (#2741)
Summary: Pull Request resolved: #2741 Torchrec rewriting logic got a bit hairy over the years, this sequence of changes aims to refactor the rewrite logic to be less convoluted and more maintainable in the future. This change: _get_node_args and related functions pass around lot of "context" (train_pipeline_context, streams, etc.) that rarely or never changes + some "state" (model, pipelined_preprocs) that is accumulated during the run. Refactoring `_get_node_args` (and friends) into a class allows initializing/passing those into class constructor, and simplifies the call signatures a lot Internal Diff stack navigation: 1. D69292525 and below - before refactoring 2. D69438143 - Refactor get_node_args and friends into a class (**you are here**) 3. D69461227 - refactor "joint lists" in ArgInfo into a list of ArgInfoStep 4. D69461226 - refactor `_build_args_kwargs` into instance methods on ArgInfo and ArgInfoStep 5. D69461228 - split monolithic `ArgInfoStep` into a class hierarchy Reviewed By: sarckk Differential Revision: D69438143
1 parent 919bbcb commit bcfac7b

File tree

2 files changed

+332
-355
lines changed

2 files changed

+332
-355
lines changed

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
)
2525
from torchrec.distributed.train_pipeline.utils import (
2626
_build_args_kwargs,
27-
_get_node_args,
2827
_rewrite_model,
2928
ArgInfo,
29+
NodeArgsHelper,
3030
PipelinedForward,
3131
PipelinedPostproc,
3232
TrainPipelineContext,
@@ -367,10 +367,9 @@ def test_get_node_args_helper_call_module_kjt(self) -> None:
367367
{},
368368
)
369369

370-
num_found = 0
371-
_, num_found = _get_node_args(
372-
MagicMock(), kjt_node, set(), TrainPipelineContext(), False
373-
)
370+
node_args_helper = NodeArgsHelper(MagicMock(), TrainPipelineContext(), False)
371+
372+
_, num_found = node_args_helper.get_node_args(kjt_node)
374373

375374
# Weights is call_module node, so we should only find 2 args unmodified
376375
self.assertEqual(num_found, len(kjt_args) - 1)

0 commit comments

Comments
 (0)