
Commit fdc60ae

Author: pytorchbot
Commit message: 2025-01-07 nightly release (6f4bfe2)
Parent: be0f3db

8 files changed, +207 -44 lines

torchrec/distributed/embedding_types.py (+8)

@@ -20,6 +20,9 @@
 from torch.distributed._tensor.placement_types import Placement
 from torch.nn.modules.module import _addindent
 from torch.nn.parallel import DistributedDataParallel
+from torchrec.distributed.global_settings import (
+    construct_sharded_tensor_from_metadata_enabled,
+)
 from torchrec.distributed.types import (
     get_tensor_size_bytes,
     ModuleSharder,

@@ -343,6 +346,11 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
 
+        # option to construct ShardedTensor from metadata avoiding expensive all-gather
+        self._construct_sharded_tensor_from_metadata: bool = (
+            construct_sharded_tensor_from_metadata_enabled()
+        )
+
     def prefetch(
         self,
         dist_input: KJTList,
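The new flag is read once here, at module construction, and cached on the instance, so flipping the environment variable afterwards does not affect an already-built module. A minimal sketch of the pattern, with a hypothetical class standing in for the real sharded module:

    from torchrec.distributed.global_settings import (
        construct_sharded_tensor_from_metadata_enabled,
    )


    class SomeShardedModule:  # hypothetical stand-in for the sharded embedding module
        def __init__(self) -> None:
            # resolve the env-driven setting once and cache it; later changes to
            # TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA are not picked up
            self._construct_sharded_tensor_from_metadata: bool = (
                construct_sharded_tensor_from_metadata_enabled()
            )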

torchrec/distributed/embeddingbag.py (+42 -9)

@@ -29,6 +29,7 @@
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
+from torch.distributed._shard.sharded_tensor import TensorProperties
 from torch.distributed._tensor import DTensor
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel

@@ -81,6 +82,7 @@
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
+    data_type_to_dtype,
     EmbeddingBagConfig,
     EmbeddingTableConfig,
     PoolingType,

@@ -945,17 +947,48 @@ def _initialize_torch_state(self) -> None: # noqa
             # created ShardedTensors once in init, use in post_state_dict_hook
             # note: at this point kvstore backed tensors don't own valid snapshots, so no read
             # access is allowed on them.
-            self._model_parallel_name_to_sharded_tensor[table_name] = (
-                ShardedTensor._init_from_local_shards(
-                    local_shards,
-                    self._name_to_table_size[table_name],
-                    process_group=(
-                        self._env.sharding_pg
-                        if isinstance(self._env, ShardingEnv2D)
-                        else self._env.process_group
+
+            # create ShardedTensor from local shards and metadata, avoiding the all_gather collective
+            if self._construct_sharded_tensor_from_metadata:
+                sharding_spec = none_throws(
+                    self.module_sharding_plan[table_name].sharding_spec
+                )
+
+                tensor_properties = TensorProperties(
+                    dtype=(
+                        data_type_to_dtype(
+                            self._table_name_to_config[table_name].data_type
+                        )
                     ),
                 )
-            )
+
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards_and_global_metadata(
+                        local_shards=local_shards,
+                        sharded_tensor_metadata=sharding_spec.build_metadata(
+                            tensor_sizes=self._name_to_table_size[table_name],
+                            tensor_properties=tensor_properties,
+                        ),
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
+                    )
+                )
+            else:
+                # create ShardedTensor from local shards using the all_gather collective
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards(
+                        local_shards,
+                        self._name_to_table_size[table_name],
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
+                    )
+                )
 
         def extract_sharded_kvtensors(
             module: ShardedEmbeddingBagCollection,
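The hunk above is the core of the change, so a minimal, self-contained sketch (not TorchRec code) may help contrast the two ShardedTensor constructors it switches between. It runs as a single process over a gloo group; the 4 x 8 table and single-shard layout are illustrative assumptions:

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed._shard.metadata import ShardMetadata
    from torch.distributed._shard.sharded_tensor import (
        Shard,
        ShardedTensor,
        ShardedTensorMetadata,
        TensorProperties,
    )

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # one local shard covering the whole 4 x 8 table on rank 0
    shard_md = ShardMetadata(
        shard_offsets=[0, 0], shard_sizes=[4, 8], placement="rank:0/cpu"
    )
    local_shards = [Shard(torch.randn(4, 8), shard_md)]

    # slow path: shard metadata is exchanged across ranks via an all_gather collective
    st_gathered = ShardedTensor._init_from_local_shards(local_shards, [4, 8])

    # fast path: global metadata is assembled locally, so no collective is needed
    global_md = ShardedTensorMetadata(
        shards_metadata=[shard_md],
        size=torch.Size([4, 8]),
        tensor_properties=TensorProperties(dtype=torch.float32),
    )
    st_from_md = ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards,
        sharded_tensor_metadata=global_md,
    )

    print(st_gathered.size(), st_from_md.size())  # torch.Size([4, 8]) twice
    dist.destroy_process_group()

In the commit itself the global metadata is not hand-built as here: it comes from sharding_spec.build_metadata(...), which derives the per-shard layout from the module's sharding plan.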

torchrec/distributed/global_settings.py (+12)

@@ -7,8 +7,14 @@
 
 # pyre-strict
 
+import os
+
 PROPOGATE_DEVICE: bool = False
 
+TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV = (
+    "TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"
+)
+
 
 def set_propogate_device(val: bool) -> None:
     global PROPOGATE_DEVICE

@@ -18,3 +24,9 @@ def set_propogate_device(val: bool) -> None:
 def get_propogate_device() -> bool:
     global PROPOGATE_DEVICE
     return PROPOGATE_DEVICE
+
+
+def construct_sharded_tensor_from_metadata_enabled() -> bool:
+    return (
+        os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1"
+    )
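A short usage sketch of the new toggle (assuming the variable starts out unset). Because the sharded modules cache the flag at construction, it must be set before the model is sharded:

    import os

    from torchrec.distributed.global_settings import (
        construct_sharded_tensor_from_metadata_enabled,
    )

    print(construct_sharded_tensor_from_metadata_enabled())  # False while unset

    # opt in to metadata-based ShardedTensor construction; must happen before
    # the sharded module is built, since the flag is cached in __init__
    os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "1"
    print(construct_sharded_tensor_from_metadata_enabled())  # True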
