@@ -900,7 +900,7 @@ def __init__(
             pg,
         )
         self._param_per_table: Dict[str, nn.Parameter] = dict(
-            _gen_named_parameters_by_table_ssd(
+            _gen_named_parameters_by_table_ssd_pmt(
                 emb_module=self._emb_module,
                 table_name_to_count=self.table_name_to_count.copy(),
                 config=self._config,
@@ -933,11 +933,31 @@ def state_dict(
         destination: Optional[Dict[str, Any]] = None,
         prefix: str = "",
         keep_vars: bool = False,
+        no_snapshot: bool = True,
     ) -> Dict[str, Any]:
-        if destination is None:
-            destination = OrderedDict()
+        """
+        Args:
+            no_snapshot (bool): the tensors in the returned dict are
+                PartiallyMaterializedTensors. This argument controls whether the
+                PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
+                PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle; False means
+                the PartiallyMaterializedTensor has a RocksDB snapshot handle.
+        """
+        # in the case no_snapshot=False, a flush is required. We rely on the flush operation in
+        # ShardedEmbeddingBagCollection._pre_state_dict_hook()

-        return destination
+        emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        ret = get_state_dict(
+            emb_table_config_copy,
+            emb_tables,
+            self._pg,
+            destination,
+            prefix,
+        )
+        return ret

     def named_parameters(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
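For context on how the new `no_snapshot` flag is meant to be used, here is a minimal sketch (not part of the patch). It assumes `kernel` is an already constructed instance of this SSD-backed embedding kernel; the flush-before-snapshot step mirrors what the comment above delegates to `ShardedEmbeddingBagCollection._pre_state_dict_hook()`.

```python
# Sketch only; `kernel` is a stand-in for an instance of this SSD-backed kernel.

# Default path: the returned dict holds PartiallyMaterializedTensors without a
# RocksDB snapshot handle (cheap, but no consistent point-in-time view).
sd = kernel.state_dict()

# Snapshot path: flush cached rows to SSD first, then ask for snapshot-backed
# tensors so reads reflect a consistent view of the RocksDB store.
kernel.flush()
sd_snapshot = kernel.state_dict(no_snapshot=False)
```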
@@ -950,14 +970,16 @@ def named_parameters(
         ):
             # hack before we support optimizer on sharded parameter level
             # can delete after PEA deprecation
+            # pyre-ignore [6]
             param = nn.Parameter(tensor)
             # pyre-ignore
             param._in_backward_optimizers = [EmptyFusedOptimizer()]
             yield name, param

+    # pyre-ignore [15]
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
-    ) -> Iterator[Tuple[str, torch.Tensor]]:
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
         assert (
             remove_duplicate
         ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
@@ -968,6 +990,21 @@ def named_split_embedding_weights(
             key = append_prefix(prefix, f"{config.name}.weight")
             yield key, tensor

+    def get_named_split_embedding_weights_snapshot(
+        self, prefix: str = ""
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
+        """
+        Return an iterator over embedding tables, yielding both the table name and the
+        embedding table itself. The embedding table is in the form of a
+        PartiallyMaterializedTensor with a valid RocksDB snapshot to support windowed access.
+        """
+        for config, tensor in zip(
+            self._config.embedding_tables,
+            self.split_embedding_weights(no_snapshot=False),
+        ):
+            key = append_prefix(prefix, f"{config.name}")
+            yield key, tensor
+
     def flush(self) -> None:
         """
         Flush the embeddings in cache back to SSD. Should be pretty expensive.
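A small, hedged usage sketch of the new iterator (again, `kernel` is an assumed instance and not part of the diff): each yielded key is the bare table name, unlike `named_split_embedding_weights`, which appends a `.weight` suffix.

```python
# Sketch: walk the snapshot-backed tables. Each value is a
# PartiallyMaterializedTensor holding a valid RocksDB snapshot handle,
# which is what enables windowed (chunked) reads of large tables.
for table_name, pmt in kernel.get_named_split_embedding_weights_snapshot():
    print(table_name, type(pmt).__name__)
```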
@@ -982,11 +1019,11 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_weights.zero_()
         self.emb_module.lxu_cache_state.fill_(-1)

-    def split_embedding_weights(self) -> List[torch.Tensor]:
-        """
-        Return fake tensors.
-        """
-        return [param.data for param in self._param_per_table.values()]
+    # pyre-ignore [15]
+    def split_embedding_weights(
+        self, no_snapshot: bool = True
+    ) -> List[PartiallyMaterializedTensor]:
+        return self.emb_module.split_embedding_weights(no_snapshot)


 class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
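And a last sketch contrasting the reworked `split_embedding_weights` with the old behavior (the `kernel` variable is again hypothetical): instead of fabricating plain tensors from `self._param_per_table`, the call is now forwarded to the fbgemm SSD `emb_module` and returns PartiallyMaterializedTensors.

```python
# Sketch: both calls return List[PartiallyMaterializedTensor].
pmts = kernel.split_embedding_weights()                             # no snapshot handle
pmts_snapshot = kernel.split_embedding_weights(no_snapshot=False)   # snapshot-backed

# Table order matches the kernel's embedding table configs.
for cfg, pmt in zip(kernel._config.embedding_tables, pmts):
    print(cfg.name, type(pmt).__name__)
```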