
Commit 6716937

che-sh authored and facebook-github-bot committed
Support model buffers as pipeline postproc inputs (#2769)
Summary: TorchRec rewriting logic got a bit hairy over the years; this sequence of changes refactors the rewrite logic to be less convoluted and more maintainable in the future. This change enables model buffers to be used as inputs to pipelined postproc modules (the last step of the stack below).

Internal Diff stack navigation:

1. D69292525 and below - before refactoring
2. D69438143 - refactor get_node_args and friends into a class
3. D69461227 - refactor "joint lists" in ArgInfo into a list of ArgInfoStep
4. D69461226 - refactor `_build_args_kwargs` into instance methods on ArgInfo and ArgInfoStep
5. D69461228 - split monolithic `ArgInfoStep` into a class hierarchy
6. D69764721 - enable buffers as preproc arguments (**you are here**)

Differential Revision: D69764721
1 parent 0037319 commit 6716937
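For orientation, the shape the refactor converges on is a chain of small step objects: each pipelined argument is described by an ArgInfo holding a list of ArgInfoStep instances, and resolving the argument means folding the pipeline input through each step's process method in order. The following is a minimal, self-contained sketch of that design, not torchrec's actual code; the class and method names echo the diff below, everything else is illustrative.

    from typing import Any, List

    class BaseArgInfoStep:
        # Each subclass handles exactly one potential operation.
        def process(self, arg: Any) -> Any:
            raise NotImplementedError

    class NoopArgInfoStep(BaseArgInfoStep):
        def process(self, arg: Any) -> Any:
            return arg

    class GetAttrArgInfoStep(BaseArgInfoStep):
        def __init__(self, attr_name: str) -> None:
            self.attr_name = attr_name

        def process(self, arg: Any) -> Any:
            return getattr(arg, self.attr_name)

    class ArgInfo:
        def __init__(self, steps: List[BaseArgInfoStep]) -> None:
            self.steps = steps

        def process(self, arg: Any) -> Any:
            # Fold the input through each step, left to right.
            for step in self.steps:
                arg = step.process(arg)
            return arg

With this shape, `_build_args_kwargs` reduces to calling process on each ArgInfo, and supporting a new argument source (such as a module buffer) is just one more step class, which is what this commit adds.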

File tree

torchrec/distributed/test_utils/test_model.py
torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
torchrec/distributed/train_pipeline/utils.py

3 files changed: +194 -2 lines changed

torchrec/distributed/test_utils/test_model.py (+120)
@@ -1989,6 +1989,126 @@ def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
         ]
 
 
+class TestPreprocForModelWithBuffer(nn.Module):
+    """
+    Basic module for testing
+
+    Args: None
+    Examples:
+        >>> TestPreprocForModelWithBuffer()
+    Returns:
+        List[KeyedJaggedTensor]
+    """
+
+    def forward(
+        self, kjt: KeyedJaggedTensor, buffer: torch.Tensor
+    ) -> List[KeyedJaggedTensor]:
+        """
+        Selects 3 features from a specific KJT and concatenates
+        them with a KJT derived from a given buffer
+        """
+        # split
+        jt_0 = kjt["feature_0"]
+        jt_1 = kjt["feature_1"]
+        jt_2 = kjt["feature_2"]
+
+        kjt_from_buffer = KeyedJaggedTensor.from_lengths_sync(
+            ["feature_0"],
+            buffer,
+            torch.ones(buffer.size(), dtype=torch.int32, device=buffer.device),
+        )
+
+        # merge only features 0, 1, 2, removing feature 3
+        kjt_projection = KeyedJaggedTensor.from_jt_dict(
+            {
+                "feature_0": jt_0,
+                "feature_1": jt_1,
+                "feature_2": jt_2,
+            }
+        )
+
+        return [
+            KeyedJaggedTensor.concat(
+                [
+                    kjt_projection,
+                    kjt_from_buffer,
+                ]
+            )
+        ]
+
+
+class TestModelWithBuffer(nn.Module):
+    """
+    Basic module that has a postproc that takes a buffer as input
+
+    Args:
+        tables,
+        weighted_tables,
+        device,
+        buffer_size,
+        num_float_features,
+
+    Example:
+        >>> TestModelWithBuffer(tables, weighted_tables, device, 100)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]
+    """
+
+    def __init__(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        device: torch.device,
+        buffer_size: int,
+        num_float_features: int = 10,
+    ) -> None:
+        super().__init__()
+        self.dense = TestDenseArch(num_float_features, device)
+
+        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
+            tables=tables,
+            device=device,
+        )
+        self.weighted_ebc = EmbeddingBagCollection(
+            tables=weighted_tables,
+            is_weighted=True,
+            device=device,
+        )
+        max_index = tables[0].num_embeddings
+        self._postproc_module = TestPreprocForModelWithBuffer()
+        self.register_buffer(
+            "_buffer",
+            torch.randint(0, max_index, (buffer_size,), device=device),
+            persistent=False,
+        )
+
+    def forward(
+        self,
+        input: ModelInput,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Runs preproc for EBC and weighted EBC, optionally runs postproc for input
+
+        Args:
+            input
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]
+        """
+        modified_input = input
+
+        modified_input.idlist_features = self._postproc_module(
+            modified_input.idlist_features, self._buffer
+        )
+
+        ebc_out = self.ebc(modified_input.idlist_features[0])
+        weighted_ebc_out = self.weighted_ebc(modified_input.idscore_features)
+
+        pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
+        return pred.sum(), pred
+
+
 class TestModelWithPreproc(nn.Module):
     """
     Basic module with up to 3 postproc modules:
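The core trick in TestPreprocForModelWithBuffer is wrapping a flat tensor of ids into a KeyedJaggedTensor by pairing it with all-ones lengths, so each position contributes exactly one id to "feature_0". Below is a standalone sketch of that construction; the values are illustrative and it assumes CPU tensors.

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    buffer = torch.tensor([3, 1, 4, 1, 5])  # flat ids, e.g. a model buffer

    # One length per value: every "row" holds exactly one id.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["feature_0"],
        values=buffer,
        lengths=torch.ones(buffer.size(), dtype=torch.int32),
    )
    print(kjt["feature_0"].values())  # tensor([3, 1, 4, 1, 5])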

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py (+25 -1)
@@ -36,6 +36,7 @@
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestEBCSharder,
+    TestModelWithBuffer,
     TestModelWithPreproc,
     TestModelWithPreprocCollectionArgs,
     TestNegSamplingModule,
@@ -1459,6 +1460,30 @@ def forward(
         self.assertEqual(len(pipeline._pipelined_modules), 2)
         self.assertEqual(len(pipeline._pipelined_postprocs), 1)
 
+    # pyre-ignore
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_postproc_with_buffer_arg(self) -> None:
+        """
+        A postproc module that takes a model buffer as an argument should still be pipelined
+        """
+        model = TestModelWithBuffer(
+            tables=self.tables[:-1],  # ignore last table as postproc will remove it
+            weighted_tables=self.weighted_tables[:-1],  # ignore last table
+            device=self.device,
+            buffer_size=self.batch_size,
+        )
+        pipelined_model, pipeline = self._check_output_equal(
+            model,
+            self.sharding_type,
+        )
+
+        # Check that both EBCs were pipelined
+        self.assertEqual(len(pipeline._pipelined_modules), 2)
+        self.assertEqual(len(pipeline._pipelined_postprocs), 1)
+
     # pyre-ignore
     @unittest.skipIf(
         not torch.cuda.is_available(),
@@ -1469,7 +1494,6 @@ def test_pipeline_postproc_with_collection_args(self) -> None:
         Exercises scenario when postproc module has an argument that is a list or dict
         with some elements being:
         * static scalars
-        * static tensors (e.g. torch.ones())
         * tensors derived from input batch (e.g. input.idlist_features["feature_0"])
         * tensors derived from input batch and other postproc module (e.g. other_postproc(input.idlist_features["feature_0"]))
         """

torchrec/distributed/train_pipeline/utils.py (+49 -1)
@@ -270,6 +270,30 @@ def process(self, arg) -> Any:
         }
 
 
+class ModuleAttributeArgInfoStep(BaseArgInfoStep):
+    def __init__(self, module: torch.nn.Module, fqn: str) -> None:
+        super().__init__()
+        self.module = module
+        self.fqn = fqn
+
+    @classmethod
+    def validate(cls, module: torch.nn.Module, fqn: str) -> None:
+        fqn_parts = fqn.split(".")
+        current = module
+        for step in fqn_parts:
+            if not hasattr(current, step):
+                raise ValueError(f"Module {module} does not have attribute {fqn}")
+            current = getattr(current, step)
+
+    # pyre-ignore
+    def process(self, _arg) -> Any:
+        fqn_parts = self.fqn.split(".")
+        current = self.module
+        for step in fqn_parts:
+            current = getattr(current, step)
+        return current
+
+
 class ArgInfoStepFactory:
     """
     Convenience class to reduce the amount of imports the external users will have.
@@ -306,6 +330,13 @@ def from_list(cls, value: List[object]) -> ListArgInfoStep:
     def from_dict(cls, value: Dict[str, object]) -> DictArgInfoStep:
         return DictArgInfoStep(value)
 
+    @classmethod
+    def from_module_attr(
+        cls, module: torch.nn.Module, fqn: str
+    ) -> ModuleAttributeArgInfoStep:
+        ModuleAttributeArgInfoStep.validate(module, fqn)
+        return ModuleAttributeArgInfoStep(module, fqn)
+
 
 @dataclass
 class ArgInfo:
@@ -1134,6 +1165,19 @@ def _handle_placeholder(
         arg_info.add_step(ArgInfoStepFactory.noop())
         return arg_info
 
+    def _handle_module_get_attr(
+        self,
+        fqn: str,
+        arg_info: ArgInfo,
+    ) -> ArgInfo:
+        # get_attr calls always carry an FQN from the model root
+        # NOTE: the first argument essentially creates a "closure" over the model,
+        # so things might get hairy if the model kept in self._model is
+        # later discarded; however, so far no training pipeline does that.
+        step = ArgInfoStepFactory.from_module_attr(self._model, fqn)
+        arg_info.add_step(step)
+        return arg_info
+
     def _handle_module(
         self, child_node: torch.fx.Node, arg_info: ArgInfo
     ) -> Optional[ArgInfo]:
@@ -1228,7 +1272,11 @@ def _get_node_args_helper_inner(
         elif child_node.op == "call_module":
             return self._handle_module(arg, arg_info)
-        elif (
-            child_node.op == "call_function"
+        elif child_node.op == "get_attr":
+            # pyre-fixme[9]: arg.target is a fqn string for get_attr op
+            module_fqn: str = arg.target
+            return self._handle_module_get_attr(module_fqn, arg_info)
+        elif (
+            child_node.op == "call_function"
             and child_node.target.__module__ == "builtins"
             # pyre-fixme[16]
             and child_node.target.__name__ == "getattr"
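ModuleAttributeArgInfoStep resolves a dotted FQN by walking getattr one segment at a time from the model root, which is what allows a registered buffer such as "_buffer" to flow into a pipelined postproc. Below is a standalone sketch of the same traversal; the resolve_fqn helper and the toy module are illustrative, not torchrec code.

    import torch
    from torch import nn

    def resolve_fqn(module: nn.Module, fqn: str) -> object:
        # Walk "a.b.c" one attribute at a time from the root module,
        # mirroring ModuleAttributeArgInfoStep.process above.
        current = module
        for step in fqn.split("."):
            current = getattr(current, step)
        return current

    class Toy(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.inner = nn.Module()
            self.inner.register_buffer("_buffer", torch.arange(4), persistent=False)

    print(resolve_fqn(Toy(), "inner._buffer"))  # tensor([0, 1, 2, 3])

Because the step closes over the live model (self._model above), it returns the current buffer value at call time rather than a snapshot, which is the caveat the NOTE in the diff points out.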
