Skip to content

Commit

Permalink
Revert torchrec pipeline refactor (#2791)
Browse files (browse the repository at this point in the history)
Summary:
Pull Request resolved: #2791

Reverts the torchrec pipeline refactoring from the stack at #2741.

Reviewed By: dstaay-fb

Differential Revision: D70911851

fbshipit-source-id: a8cdd51e61c9ead0916c2cc6661dd1dd636eb14e
  • Loading branch information
sarckk authored and facebook-github-bot committed Mar 11, 2025
1 parent c5a4ff1 commit 315539b
Show file tree
Hide file tree
Showing 4 changed files with 577 additions and 595 deletions.
2 changes: 0 additions & 2 deletions torchrec/distributed/train_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
_to_device, # noqa
_wait_for_batch, # noqa
ArgInfo, # noqa
ArgInfoStepFactory, # noqa
CallArgs, # noqa
DataLoadingThread, # noqa
In, # noqa
Out, # noqa
Expand Down
151 changes: 71 additions & 80 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,9 @@
from torchrec.distributed.train_pipeline.utils import (
DataLoadingThread,
get_h2d_func,
GetAttrArgInfoStep,
GetItemArgInfoStep,
NoopArgInfoStep,
PipelinedForward,
PipelinedPostproc,
PipelineStage,
PostprocArgInfoStep,
SparseDataDistUtil,
StageOut,
TrainPipelineContext,
Expand Down Expand Up @@ -1025,56 +1022,44 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
pipelined_weighted_ebc = pipeline._pipelined_modules[1]

# Check pipelined args
self.assertEqual(len(pipelined_ebc.forward._args.args), 1)
self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0)
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
self.assertEqual(len(ebc.forward._args), 1)
self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
self.assertIsInstance(
ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
)
self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)

self.assertEqual(
pipelined_ebc.forward._args.args[0].steps,
[
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted),
GetItemArgInfoStep(0),
],
pipelined_ebc.forward._args[0].postproc_modules[0],
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
pipelined_model.module.postproc_nonweighted,
)
self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1)
self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0)
self.assertEqual(
pipelined_weighted_ebc.forward._args.args[0].steps,
[
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
PostprocArgInfoStep(pipelined_model.module.postproc_weighted),
GetItemArgInfoStep(0),
],
pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
pipelined_model.module.postproc_weighted,
)

# postproc args
self.assertEqual(len(pipeline._pipelined_postprocs), 2)
# postprocs can be added in any order, so we can't assert on exact steps structures
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args), 1)
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.kwargs), 0)
self.assertEqual(len(pipeline._pipelined_postprocs[0]._args.args[0].steps), 2)
self.assertEqual(
pipeline._pipelined_postprocs[0]._args.args[0].steps[0], NoopArgInfoStep()
)
self.assertIsInstance(
pipeline._pipelined_postprocs[0]._args.args[0].steps[1], GetAttrArgInfoStep
)
input_attr_names = {"idlist_features", "idscore_features"}
for i in range(len(pipeline._pipelined_postprocs)):
postproc_mod = pipeline._pipelined_postprocs[i]
self.assertEqual(len(postproc_mod._args), 1)

self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args), 1)
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.kwargs), 0)
self.assertEqual(len(pipeline._pipelined_postprocs[1]._args.args[0].steps), 2)
self.assertEqual(
pipeline._pipelined_postprocs[1]._args.args[0].steps[0], NoopArgInfoStep()
)
self.assertIsInstance(
pipeline._pipelined_postprocs[1]._args.args[0].steps[1], GetAttrArgInfoStep
)
input_attr_name = postproc_mod._args[0].input_attrs[1]
self.assertTrue(input_attr_name in input_attr_names)
self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name])
input_attr_names.remove(input_attr_name)

get_arg_infos = {
# pyre-fixme[16]: assertions above ensure that steps[1] is a GetAttrArgInfoStep
postproc._args.args[0].steps[1].attr_name
for postproc in pipeline._pipelined_postprocs
}
self.assertEqual(get_arg_infos, {"idlist_features", "idscore_features"})
self.assertEqual(postproc_mod._args[0].is_getitems, [False, False])
# no parent postproc module in FX graph
self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None])

# pyre-ignore
@unittest.skipIf(
Expand Down Expand Up @@ -1120,63 +1105,69 @@ def test_pipeline_postproc_recursive(self) -> None:
pipelined_weighted_ebc = pipeline._pipelined_modules[1]

# Check pipelined args
self.assertEqual(len(pipelined_ebc.forward._args.args), 1)
self.assertEqual(len(pipelined_ebc.forward._args.kwargs), 0)
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
self.assertEqual(len(ebc.forward._args), 1)
self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
self.assertIsInstance(
ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
)
self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)

self.assertEqual(
pipelined_ebc.forward._args.args[0].steps,
[
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
PostprocArgInfoStep(pipelined_model.module.postproc_nonweighted),
GetItemArgInfoStep(0),
],
pipelined_ebc.forward._args[0].postproc_modules[0],
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
pipelined_model.module.postproc_nonweighted,
)
self.assertEqual(len(pipelined_weighted_ebc.forward._args.args), 1)
self.assertEqual(len(pipelined_weighted_ebc.forward._args.kwargs), 0)
self.assertEqual(
pipelined_weighted_ebc.forward._args.args[0].steps,
[
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
PostprocArgInfoStep(pipelined_model.module.postproc_weighted),
GetItemArgInfoStep(0),
],
pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
pipelined_model.module.postproc_weighted,
)

# postproc args
self.assertEqual(len(pipeline._pipelined_postprocs), 3)

# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `_postproc_module`.
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `_postproc_module`.
parent_postproc_mod = pipelined_model.module._postproc_module

for postproc_mod in pipeline._pipelined_postprocs:
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
if postproc_mod == pipelined_model.module.postproc_nonweighted:
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(args.input_attrs, ["", "idlist_features"])
self.assertEqual(args.is_getitems, [False, False])
self.assertEqual(len(args.postproc_modules), 2)
self.assertEqual(
postproc_mod._args.args[0].steps,
[
PostprocArgInfoStep(parent_postproc_mod),
GetAttrArgInfoStep("idlist_features"),
],
args.postproc_modules[0],
parent_postproc_mod,
)

self.assertEqual(args.postproc_modules[1], None)
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
elif postproc_mod == pipelined_model.module.postproc_weighted:
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(args.input_attrs, ["", "idscore_features"])
self.assertEqual(args.is_getitems, [False, False])
self.assertEqual(len(args.postproc_modules), 2)
self.assertEqual(
postproc_mod._args.args[0].steps,
[
PostprocArgInfoStep(parent_postproc_mod),
GetAttrArgInfoStep("idscore_features"),
],
args.postproc_modules[0],
parent_postproc_mod,
)
self.assertEqual(args.postproc_modules[1], None)
elif postproc_mod == parent_postproc_mod:
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
self.assertEqual(postproc_mod._args.args[0].steps, [NoopArgInfoStep()])
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(args.input_attrs, [""])
self.assertEqual(args.is_getitems, [False])
self.assertEqual(args.postproc_modules, [None])

# pyre-ignore
@unittest.skipIf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@
TrainPipelineSparseDistTestBase,
)
from torchrec.distributed.train_pipeline.utils import (
_build_args_kwargs,
_get_node_args,
_rewrite_model,
ArgInfo,
ArgInfoStepFactory,
CallArgs,
NodeArgsHelper,
PipelinedForward,
PipelinedPostproc,
TrainPipelineContext,
Expand Down Expand Up @@ -111,19 +110,17 @@ def test_rewrite_model(self) -> None:
self.assertEqual(
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `sparse`.
sharded_model.module.sparse.ebc.forward._args.args[0]
.steps[0]
.postproc_module,
sharded_model.module.sparse.ebc.forward._args[0].postproc_modules[0],
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_module`.
sharded_model.module.postproc_module,
)
self.assertEqual(
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `sparse`.
sharded_model.module.sparse.weighted_ebc.forward._args.args[0]
.steps[0]
.postproc_module,
sharded_model.module.sparse.weighted_ebc.forward._args[0].postproc_modules[
0
],
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_module`.
sharded_model.module.postproc_module,
Expand Down Expand Up @@ -157,7 +154,7 @@ def forward(self, x):
rewritten_model.test_module = PipelinedPostproc(
postproc_module=rewritten_model.test_module,
fqn="test_module",
args=CallArgs(args=[], kwargs={}),
args=[],
context=TrainPipelineContext(),
default_stream=MagicMock(),
dist_stream=MagicMock(),
Expand Down Expand Up @@ -263,53 +260,83 @@ def test_restore_from_snapshot(self) -> None:
@parameterized.expand(
[
(
CallArgs(
args=[],
kwargs={
"id_list_features": ArgInfo(steps=[ArgInfoStepFactory.noop()]),
# Empty attrs to ignore any attr based logic.
"id_score_list_features": ArgInfo(
steps=[ArgInfoStepFactory.noop()]
),
},
),
[
# Empty attrs to ignore any attr based logic.
ArgInfo(
input_attrs=[
"",
],
is_getitems=[False],
postproc_modules=[None],
constants=[None],
name="id_list_features",
),
ArgInfo(
input_attrs=[],
is_getitems=[],
postproc_modules=[],
constants=[],
name="id_score_list_features",
),
],
0,
["id_list_features", "id_score_list_features"],
),
(
CallArgs(
args=[
# Empty attrs to ignore any attr based logic.
ArgInfo(steps=[ArgInfoStepFactory.noop()]),
ArgInfo(steps=[]),
],
kwargs={},
),
[
# Empty attrs to ignore any attr based logic.
ArgInfo(
input_attrs=[
"",
],
is_getitems=[False],
postproc_modules=[None],
constants=[None],
name=None,
),
ArgInfo(
input_attrs=[],
is_getitems=[],
postproc_modules=[],
constants=[],
name=None,
),
],
2,
[],
),
(
CallArgs(
args=[
# Empty attrs to ignore any attr based logic.
ArgInfo(
steps=[ArgInfoStepFactory.noop()],
)
],
kwargs={"id_score_list_features": ArgInfo(steps=[])},
),
[
# Empty attrs to ignore any attr based logic.
ArgInfo(
input_attrs=[
"",
],
is_getitems=[False],
postproc_modules=[None],
constants=[None],
name=None,
),
ArgInfo(
input_attrs=[],
is_getitems=[],
postproc_modules=[],
constants=[],
name="id_score_list_features",
),
],
1,
["id_score_list_features"],
),
]
)
def test_build_args_kwargs(
self,
fwd_args: CallArgs,
fwd_args: List[ArgInfo],
args_len: int,
kwarges_keys: List[str],
) -> None:
args, kwargs = fwd_args.build_args_kwargs("initial_input")
args, kwargs = _build_args_kwargs("initial_input", fwd_args)
self.assertEqual(len(args), args_len)
self.assertEqual(list(kwargs.keys()), kwarges_keys)

Expand Down Expand Up @@ -340,9 +367,10 @@ def test_get_node_args_helper_call_module_kjt(self) -> None:
{},
)

node_args_helper = NodeArgsHelper(MagicMock(), TrainPipelineContext(), False)

_, num_found = node_args_helper.get_node_args(kjt_node)
num_found = 0
_, num_found = _get_node_args(
MagicMock(), kjt_node, set(), TrainPipelineContext(), False
)

# Weights is a call_module node, so we should only find 2 args unmodified
self.assertEqual(num_found, len(kjt_args) - 1)
Loading

0 comments on commit 315539b

Please sign in to comment.