
Commit 9b778c5

Author: pytorchbot
Commit message: 2024-12-31 nightly release (455de88)
1 parent 634a0a8, commit 9b778c5

8 files changed (+282, -218 lines)

examples/retrieval/tests/test_two_tower_retrieval.py  (+2, -2)
@@ -21,8 +21,8 @@ class InferTest(unittest.TestCase):
     @skip_if_asan
     # pyre-ignore[56]
     @unittest.skipIf(
-        not torch.cuda.is_available(),
-        "this test requires a GPU",
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
     )
     def test_infer_function(self) -> None:
         infer(
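
The hunk above replaces a plain CUDA-availability check with a device-count check, since the skip message now states the test needs at least two GPUs. As a minimal, self-contained illustration of the same guard pattern (not part of the commit; TwoGpuTest and test_needs_two_gpus are made-up names), a multi-GPU test can be gated like this:

import unittest

import torch


class TwoGpuTest(unittest.TestCase):
    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs, this test requires at least two GPUs",
    )
    def test_needs_two_gpus(self) -> None:
        # The guard above skips on CPU-only and single-GPU hosts, so both
        # devices are guaranteed to exist when this body runs.
        a = torch.ones(4, device="cuda:0")
        b = torch.ones(4, device="cuda:1")
        self.assertEqual(a.sum().item(), b.sum().item())


if __name__ == "__main__":
    unittest.main()

Note that the old condition, not torch.cuda.is_available(), only covered the CPU-only case; the new one also skips single-GPU machines.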

torchrec/distributed/batched_embedding_kernel.py  (+47, -10)
@@ -900,7 +900,7 @@ def __init__(
             pg,
         )
         self._param_per_table: Dict[str, nn.Parameter] = dict(
-            _gen_named_parameters_by_table_ssd(
+            _gen_named_parameters_by_table_ssd_pmt(
                 emb_module=self._emb_module,
                 table_name_to_count=self.table_name_to_count.copy(),
                 config=self._config,
@@ -933,11 +933,31 @@ def state_dict(
         destination: Optional[Dict[str, Any]] = None,
         prefix: str = "",
         keep_vars: bool = False,
+        no_snapshot: bool = True,
     ) -> Dict[str, Any]:
-        if destination is None:
-            destination = OrderedDict()
+        """
+        Args:
+            no_snapshot (bool): the tensors in the returned dict are
+                PartiallyMaterializedTensors. This argument controls whether the
+                PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
+                PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the
+                PartiallyMaterializedTensor has a RocksDB snapshot handle.
+        """
+        # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
+        # ShardedEmbeddingBagCollection._pre_state_dict_hook()

-        return destination
+        emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        ret = get_state_dict(
+            emb_table_config_copy,
+            emb_tables,
+            self._pg,
+            destination,
+            prefix,
+        )
+        return ret

     def named_parameters(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
@@ -950,14 +970,16 @@ def named_parameters(
         ):
             # hack before we support optimizer on sharded parameter level
             # can delete after PEA deprecation
+            # pyre-ignore [6]
            param = nn.Parameter(tensor)
            # pyre-ignore
            param._in_backward_optimizers = [EmptyFusedOptimizer()]
            yield name, param

+    # pyre-ignore [15]
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
-    ) -> Iterator[Tuple[str, torch.Tensor]]:
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
         assert (
             remove_duplicate
         ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
@@ -968,6 +990,21 @@ def named_split_embedding_weights(
             key = append_prefix(prefix, f"{config.name}.weight")
             yield key, tensor

+    def get_named_split_embedding_weights_snapshot(
+        self, prefix: str = ""
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
+        """
+        Return an iterator over embedding tables, yielding both the table name as well as the embedding
+        table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
+        RocksDB snapshot to support windowed access.
+        """
+        for config, tensor in zip(
+            self._config.embedding_tables,
+            self.split_embedding_weights(no_snapshot=False),
+        ):
+            key = append_prefix(prefix, f"{config.name}")
+            yield key, tensor
+
     def flush(self) -> None:
         """
         Flush the embeddings in cache back to SSD. Should be pretty expensive.
@@ -982,11 +1019,11 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_weights.zero_()
         self.emb_module.lxu_cache_state.fill_(-1)

-    def split_embedding_weights(self) -> List[torch.Tensor]:
-        """
-        Return fake tensors.
-        """
-        return [param.data for param in self._param_per_table.values()]
+    # pyre-ignore [15]
+    def split_embedding_weights(
+        self, no_snapshot: bool = True
+    ) -> List[PartiallyMaterializedTensor]:
+        return self.emb_module.split_embedding_weights(no_snapshot)


 class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
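
Usage sketch (not part of the commit; snapshot_checkpoint and ssd_kernel are hypothetical names): assuming ssd_kernel is an instance of the SSD-backed embedding kernel modified in this file, the new no_snapshot flag and get_named_split_embedding_weights_snapshot could be exercised roughly as follows. Per the comment added in the state_dict hunk, a snapshot-backed read should only happen after a flush; the state_dict path relies on ShardedEmbeddingBagCollection._pre_state_dict_hook() for that, so a direct caller flushes explicitly here.

from typing import Any, Dict


def snapshot_checkpoint(ssd_kernel: Any, prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper exercising the snapshot-aware APIs added above."""
    # Cheap path: the returned PartiallyMaterializedTensors hold no RocksDB
    # snapshot handle, so no flush is required beforehand.
    light_state = ssd_kernel.state_dict(prefix=prefix, no_snapshot=True)

    # Snapshot path: flush cached rows to SSD first so the RocksDB snapshot is
    # complete, then iterate tables that support windowed access.
    ssd_kernel.flush()
    for table_name, pmt in ssd_kernel.get_named_split_embedding_weights_snapshot(
        prefix=prefix
    ):
        # `pmt` is a PartiallyMaterializedTensor backed by a valid RocksDB snapshot.
        print(table_name, type(pmt).__name__)

    return light_state

The sketch only strings together calls that appear in the diff (state_dict with no_snapshot, flush, get_named_split_embedding_weights_snapshot); constructing the kernel itself requires the usual torchrec sharding setup and is out of scope here.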
