Skip to content

Commit 5fd2811

Browse files
author
pytorchbot
committed
2025-01-03 nightly release (00d8ed2)
1 parent d0e11e7 commit 5fd2811

8 files changed

+377
-319
lines changed

torchrec/distributed/embeddingbag.py

+12
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
from torch.distributed._tensor import DTensor
3333
from torch.nn.modules.module import _IncompatibleKeys
3434
from torch.nn.parallel import DistributedDataParallel
35+
from torchrec.distributed.comm import get_local_size
3536
from torchrec.distributed.embedding_sharding import (
3637
EmbeddingSharding,
3738
EmbeddingShardingContext,
@@ -73,6 +74,7 @@
7374
add_params_from_parameter_sharding,
7475
append_prefix,
7576
convert_to_fbgemm_types,
77+
create_global_tensor_shape_stride_from_metadata,
7678
maybe_annotate_embedding_event,
7779
merge_fused_params,
7880
none_throws,
@@ -918,6 +920,14 @@ def _initialize_torch_state(self) -> None: # noqa
918920
)
919921
)
920922
else:
923+
shape, stride = create_global_tensor_shape_stride_from_metadata(
924+
none_throws(self.module_sharding_plan[table_name]),
925+
(
926+
self._env.node_group_size
927+
if isinstance(self._env, ShardingEnv2D)
928+
else get_local_size(self._env.world_size)
929+
),
930+
)
921931
# empty shard case
922932
self._model_parallel_name_to_dtensor[table_name] = (
923933
DTensor.from_local(
@@ -927,6 +937,8 @@ def _initialize_torch_state(self) -> None: # noqa
927937
),
928938
device_mesh=self._env.device_mesh,
929939
run_check=False,
940+
shape=shape,
941+
stride=stride,
930942
)
931943
)
932944
else:

torchrec/distributed/test_utils/test_model.py

+27 -27
Original file line number | Diff line number | Diff line change
@@ -1192,7 +1192,7 @@ def __init__(
11921192
max_feature_lengths: Optional[Dict[str, int]] = None,
11931193
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
11941194
over_arch_clazz: Type[nn.Module] = TestOverArch,
1195-
preproc_module: Optional[nn.Module] = None,
1195+
postproc_module: Optional[nn.Module] = None,
11961196
) -> None:
11971197
super().__init__(
11981198
tables=cast(List[BaseEmbeddingConfig], tables),
@@ -1229,7 +1229,7 @@ def __init__(
12291229
"dummy_ones",
12301230
torch.ones(1, device=dense_device),
12311231
)
1232-
self.preproc_module = preproc_module
1232+
self.postproc_module = postproc_module
12331233

12341234
def sparse_forward(self, input: ModelInput) -> KeyedTensor:
12351235
return self.sparse(
@@ -1256,8 +1256,8 @@ def forward(
12561256
self,
12571257
input: ModelInput,
12581258
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
1259-
if self.preproc_module:
1260-
input = self.preproc_module(input)
1259+
if self.postproc_module:
1260+
input = self.postproc_module(input)
12611261
return self.dense_forward(input, self.sparse_forward(input))
12621262

12631263

@@ -1749,18 +1749,18 @@ def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
17491749

17501750
class TestModelWithPreproc(nn.Module):
17511751
"""
1752-
Basic module with up to 3 preproc modules:
1753-
- preproc on idlist_features for non-weighted EBC
1754-
- preproc on idscore_features for weighted EBC
1755-
- optional preproc on model input shared by both EBCs
1752+
Basic module with up to 3 postproc modules:
1753+
- postproc on idlist_features for non-weighted EBC
1754+
- postproc on idscore_features for weighted EBC
1755+
- optional postproc on model input shared by both EBCs
17561756
17571757
Args:
17581758
tables,
17591759
weighted_tables,
17601760
device,
1761-
preproc_module,
1761+
postproc_module,
17621762
num_float_features,
1763-
run_preproc_inline,
1763+
run_postproc_inline,
17641764
17651765
Example:
17661766
>>> TestModelWithPreproc(tables, weighted_tables, device)
@@ -1774,9 +1774,9 @@ def __init__(
17741774
tables: List[EmbeddingBagConfig],
17751775
weighted_tables: List[EmbeddingBagConfig],
17761776
device: torch.device,
1777-
preproc_module: Optional[nn.Module] = None,
1777+
postproc_module: Optional[nn.Module] = None,
17781778
num_float_features: int = 10,
1779-
run_preproc_inline: bool = False,
1779+
run_postproc_inline: bool = False,
17801780
) -> None:
17811781
super().__init__()
17821782
self.dense = TestDenseArch(num_float_features, device)
@@ -1790,17 +1790,17 @@ def __init__(
17901790
is_weighted=True,
17911791
device=device,
17921792
)
1793-
self.preproc_nonweighted = TestPreprocNonWeighted()
1794-
self.preproc_weighted = TestPreprocWeighted()
1795-
self._preproc_module = preproc_module
1796-
self._run_preproc_inline = run_preproc_inline
1793+
self.postproc_nonweighted = TestPreprocNonWeighted()
1794+
self.postproc_weighted = TestPreprocWeighted()
1795+
self._postproc_module = postproc_module
1796+
self._run_postproc_inline = run_postproc_inline
17971797

17981798
def forward(
17991799
self,
18001800
input: ModelInput,
18011801
) -> Tuple[torch.Tensor, torch.Tensor]:
18021802
"""
1803-
Runs preprco for EBC and weighted EBC, optionally runs preproc for input
1803+
Runs preproc for EBC and weighted EBC, optionally runs postproc for input
18041804
18051805
Args:
18061806
input
@@ -1809,20 +1809,20 @@ def forward(
18091809
"""
18101810
modified_input = input
18111811

1812-
if self._preproc_module is not None:
1813-
modified_input = self._preproc_module(modified_input)
1814-
elif self._run_preproc_inline:
1812+
if self._postproc_module is not None:
1813+
modified_input = self._postproc_module(modified_input)
1814+
elif self._run_postproc_inline:
18151815
idlist_features = modified_input.idlist_features
18161816
modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
18171817
idlist_features.keys(), # pyre-ignore [6]
18181818
idlist_features.values(), # pyre-ignore [6]
18191819
idlist_features.lengths(), # pyre-ignore [16]
18201820
)
18211821

1822-
modified_idlist_features = self.preproc_nonweighted(
1822+
modified_idlist_features = self.postproc_nonweighted(
18231823
modified_input.idlist_features
18241824
)
1825-
modified_idscore_features = self.preproc_weighted(
1825+
modified_idscore_features = self.postproc_weighted(
18261826
modified_input.idscore_features
18271827
)
18281828
ebc_out = self.ebc(modified_idlist_features[0])
@@ -1834,15 +1834,15 @@ def forward(
18341834

18351835
class TestNegSamplingModule(torch.nn.Module):
18361836
"""
1837-
Basic module to simulate feature augmentation preproc (e.g. neg sampling) for testing
1837+
Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing
18381838
18391839
Args:
18401840
extra_input
18411841
has_params
18421842
18431843
Example:
1844-
>>> preproc = TestNegSamplingModule(extra_input)
1845-
>>> out = preproc(in)
1844+
>>> postproc = TestNegSamplingModule(extra_input)
1845+
>>> out = postproc(in)
18461846
18471847
Returns:
18481848
ModelInput
@@ -1906,8 +1906,8 @@ class TestPositionWeightedPreprocModule(torch.nn.Module):
19061906
19071907
Args: None
19081908
Example:
1909-
>>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
1910-
>>> out = preproc(in)
1909+
>>> postproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
1910+
>>> out = postproc(in)
19111911
Returns:
19121912
ModelInput
19131913
"""

0 commit comments

Comments
 (0)