@@ -29,6 +29,7 @@
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
+from torch.distributed._shard.sharded_tensor import TensorProperties
 from torch.distributed._tensor import DTensor
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
@@ -81,6 +82,7 @@
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
+    data_type_to_dtype,
     EmbeddingBagConfig,
     EmbeddingTableConfig,
     PoolingType,
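The two new imports feed the metadata-construction path in the hunk below. As a quick illustration (not part of the patch, and assuming a hypothetical FP32 table), `data_type_to_dtype` maps TorchRec's `DataType` enum to the `torch.dtype` that `TensorProperties` records for the ShardedTensor's global metadata:

import torch
from torchrec.modules.embedding_configs import data_type_to_dtype, DataType
from torch.distributed._shard.sharded_tensor import TensorProperties

# DataType.FP32 -> torch.float32; the resulting TensorProperties are baked
# into the global ShardedTensorMetadata built from the table's sharding spec.
props = TensorProperties(dtype=data_type_to_dtype(DataType.FP32))
assert props.dtype == torch.float32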
@@ -945,17 +947,48 @@ def _initialize_torch_state(self) -> None: # noqa
             # created ShardedTensors once in init, use in post_state_dict_hook
             # note: at this point kvstore backed tensors don't own valid snapshots, so no read
             # access is allowed on them.
-            self._model_parallel_name_to_sharded_tensor[table_name] = (
-                ShardedTensor._init_from_local_shards(
-                    local_shards,
-                    self._name_to_table_size[table_name],
-                    process_group=(
-                        self._env.sharding_pg
-                        if isinstance(self._env, ShardingEnv2D)
-                        else self._env.process_group
+
+            # create ShardedTensor from local shards and metadata, avoiding the all_gather collective
+            if self._construct_sharded_tensor_from_metadata:
+                sharding_spec = none_throws(
+                    self.module_sharding_plan[table_name].sharding_spec
+                )
+
+                tensor_properties = TensorProperties(
+                    dtype=(
+                        data_type_to_dtype(
+                            self._table_name_to_config[table_name].data_type
+                        )
                     ),
                 )
-            )
+
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards_and_global_metadata(
+                        local_shards=local_shards,
+                        sharded_tensor_metadata=sharding_spec.build_metadata(
+                            tensor_sizes=self._name_to_table_size[table_name],
+                            tensor_properties=tensor_properties,
+                        ),
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
+                    )
+                )
+            else:
+                # create ShardedTensor from local shards using the all_gather collective
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards(
+                        local_shards,
+                        self._name_to_table_size[table_name],
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
+                    )
+                )
 
         def extract_sharded_kvtensors(
             module: ShardedEmbeddingBagCollection,
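Taken together, the hunk gates ShardedTensor construction on `self._construct_sharded_tensor_from_metadata`: when set, the global metadata is computed locally from the table's sharding spec via `build_metadata`, so `_init_from_local_shards_and_global_metadata` needs no communication; otherwise the original `_init_from_local_shards` path all_gathers shard metadata across the process group. Below is a minimal single-rank sketch of the two paths (illustrative only, not TorchRec code; the 4x8 table, gloo group, and port are made up):

import torch
import torch.distributed as dist
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Single-process group so the example runs standalone (address/port are arbitrary).
dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

# One rank owns the whole hypothetical 4x8 table.
local_shards = [
    Shard(
        tensor=torch.zeros(4, 8),
        metadata=ShardMetadata(
            shard_offsets=[0, 0], shard_sizes=[4, 8], placement="rank:0/cpu"
        ),
    )
]

# Path 1: collective construction -- ranks all_gather their shard metadata
# to assemble the global view.
st_gathered = ShardedTensor._init_from_local_shards(local_shards, 4, 8)

# Path 2: metadata construction -- derive the same global view locally from
# the sharding spec, with no communication at build time.
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])
st_local = ShardedTensor._init_from_local_shards_and_global_metadata(
    local_shards=local_shards,
    sharded_tensor_metadata=spec.build_metadata(
        tensor_sizes=torch.Size([4, 8]),
        tensor_properties=TensorProperties(dtype=torch.float32),
    ),
)

assert st_gathered.size() == st_local.size() == torch.Size([4, 8])
dist.destroy_process_group()

Both paths describe the same global tensor; the flag simply trades a blocking collective at initialization time for trust in the sharding plan that every rank already holds, which is why the metadata branch asserts the plan's `sharding_spec` is present via `none_throws`.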