2025-03-07 nightly release (592ed93)
pytorchbot committed Mar 7, 2025
1 parent bae327d commit dd8636c
Showing 7 changed files with 70 additions and 57 deletions.
33 changes: 23 additions & 10 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -190,7 +190,7 @@ def __init__(
state: Dict[Any, Any] = {}
param_group: Dict[str, Any] = {
"params": [],
"lr": emb_module.optimizer_args.learning_rate,
"lr": emb_module.optimizer_args.learning_rate_tensor,
}

params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
@@ -383,7 +383,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
state: Dict[Any, Any] = {}
param_group: Dict[str, Any] = {
"params": [],
"lr": emb_module.optimizer_args.learning_rate,
"lr": emb_module.optimizer_args.learning_rate_tensor,
}

params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
@@ -808,10 +808,16 @@ def init_parameters(self) -> None:
)

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
)
indices, offsets = features.values(), features.offsets()

# If the data types of indices and offsets differ, cast the smaller type to the larger type
if indices.dtype != offsets.dtype:
if indices.element_size() < offsets.element_size():
indices = indices.to(offsets.dtype)
else:
offsets = offsets.to(indices.dtype)

return self.emb_module(indices=indices, offsets=offsets)

# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
def state_dict(
@@ -1265,6 +1271,13 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
weights = features.weights_or_none()
if weights is not None and not torch.is_floating_point(weights):
weights = None
indices, offsets = features.values(), features.offsets()
# If the data types of indices and offsets differ, cast the smaller type to the larger type
if indices.dtype != offsets.dtype:
if indices.element_size() < offsets.element_size():
indices = indices.to(offsets.dtype)
else:
offsets = offsets.to(indices.dtype)
if features.variable_stride_per_key() and isinstance(
self.emb_module,
(
@@ -1274,15 +1287,15 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
),
):
return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
indices=indices,
offsets=offsets,
per_sample_weights=weights,
batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
)
else:
return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
indices=indices,
offsets=offsets,
per_sample_weights=weights,
)

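The index/offset handling above replaces the unconditional .long() casts with a conditional promotion: when the two dtypes differ, the tensor with the smaller element size is cast to the wider dtype, so only a mismatched pair incurs a copy. A minimal, self-contained sketch of the same rule, using a hypothetical promote_index_dtypes helper:

import torch
from typing import Tuple

def promote_index_dtypes(
    indices: torch.Tensor, offsets: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Mirror the rule above: if dtypes differ, cast the narrower integer
    # tensor to the wider dtype so both reach the kernel with one type.
    if indices.dtype != offsets.dtype:
        if indices.element_size() < offsets.element_size():
            indices = indices.to(offsets.dtype)
        else:
            offsets = offsets.to(indices.dtype)
    return indices, offsets

indices = torch.tensor([1, 4, 7], dtype=torch.int32)
offsets = torch.tensor([0, 2, 3], dtype=torch.int64)
indices, offsets = promote_index_dtypes(indices, offsets)
assert indices.dtype == offsets.dtype == torch.int64

Matching inputs (for example int32/int32) now pass through untouched instead of both being forced to int64.
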
49 changes: 32 additions & 17 deletions torchrec/distributed/model_parallel.py
@@ -10,7 +10,7 @@
import abc
import copy
import logging as logger
from collections import OrderedDict
from collections import defaultdict, OrderedDict
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Set, Tuple, Type

import torch
@@ -754,42 +754,57 @@ def sync(self, include_optimizer_state: bool = True) -> None:
include_optimizer_state (bool): Flag to include optimizer state syncing upon call
"""
assert self._replica_pg is not None, "replica_pg is not initialized!"
all_weights: List[torch.Tensor] = [
w
for emb_kernel in self._modules_to_sync
all_weights_by_dtype: dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)

for emb_kernel in self._modules_to_sync:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for w in emb_kernel.split_embedding_weights()
]
for w in emb_kernel.split_embedding_weights():
all_weights_by_dtype[w.dtype].append(w)

opts = None
if self._custom_all_reduce is None:
opts = dist.AllreduceCoalescedOptions()
opts.reduceOp = dist.ReduceOp.AVG
self._allreduce_tensors(all_weights, opts)
self._allreduce_tensors(all_weights_by_dtype, opts)

if include_optimizer_state:
optimizer_tensors = []
optimizer_tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = (
defaultdict(list)
)
for emb_kernel in self._modules_to_sync:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
optimizer_states = emb_kernel.get_optimizer_state()
optimizer_tensors.extend([state["sum"] for state in optimizer_states])
if optimizer_tensors:
self._allreduce_tensors(optimizer_tensors, opts)
for state in optimizer_states:
opt_tensor = state["sum"]
optimizer_tensors_by_dtype[opt_tensor.dtype].append(opt_tensor)
if optimizer_tensors_by_dtype:
self._allreduce_tensors(optimizer_tensors_by_dtype, opts)

def _allreduce_tensors(
self,
tensors: List[torch.Tensor],
tensors_dict: Dict[torch.dtype, List[torch.Tensor]],
opts: Optional[dist.AllreduceCoalescedOptions] = None,
) -> None:
"""
Helper to all-reduce the given tensors, using the custom all-reduce function if one is provided.
All-reduce is performed per tensor dtype to satisfy collective constraints.
"""
if self._custom_all_reduce is not None:
# pyre-ignore[6]

def custom_all_reduce(tensors: List[torch.Tensor]) -> None:
# pyre-ignore[29]
self._custom_all_reduce(tensors)
else:
handle = self._replica_pg.allreduce_coalesced(tensors, opts=opts)
handle.wait()

def default_allreduce(tensor_list: List[torch.Tensor]) -> None:
self._replica_pg.allreduce_coalesced(tensor_list, opts=opts).wait()

allreduce = (
custom_all_reduce
if self._custom_all_reduce is not None
else default_allreduce
)

for tensor_list in tensors_dict.values():
allreduce(tensor_list)

def set_all_reduce_hook(
self,
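
sync() now buckets embedding weights and optimizer state by dtype before issuing the coalesced all-reduce, since coalesced collectives generally expect a uniform dtype (the "collective constraints" the new docstring refers to). A standalone sketch of that bucketing pattern, assuming pg is an already-initialized process group:

from collections import defaultdict
from typing import Dict, List

import torch
import torch.distributed as dist

def allreduce_by_dtype(tensors: List[torch.Tensor], pg: dist.ProcessGroup) -> None:
    # Group tensors by dtype so each coalesced all-reduce sees a uniform dtype.
    buckets: Dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
    for t in tensors:
        buckets[t.dtype].append(t)

    opts = dist.AllreduceCoalescedOptions()
    opts.reduceOp = dist.ReduceOp.AVG
    for bucket in buckets.values():
        # One coalesced call per dtype bucket; wait() keeps this synchronous,
        # matching the default_allreduce closure in the diff above.
        pg.allreduce_coalesced(bucket, opts=opts).wait()
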
3 changes: 2 additions & 1 deletion torchrec/distributed/test_utils/test_model.py
@@ -243,7 +243,8 @@ def _validate_pooling_factor(
global_idlist_indices.append(indices)
global_idlist_offsets.append(offsets)

for idx, ind_range in enumerate(idscore_ind_ranges):
for idx in range(len(idscore_ind_ranges)):
ind_range = idscore_ind_ranges[idx]
lengths_ = torch.abs(
torch.randn(batch_size * world_size, device=device)
+ (
12 changes: 7 additions & 5 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -916,7 +916,8 @@ def test_sharding_grid_8gpu(
)
# pyre-fixme[56]
@given(
dtype=st.sampled_from([torch.int32, torch.int64]),
index_dtype=st.sampled_from([torch.int32, torch.int64]),
offsets_dtype=st.sampled_from([torch.int32, torch.int64]),
use_offsets=st.booleans(),
sharder_type=st.sampled_from(
[
@@ -932,7 +933,8 @@ def test_sharding_grid_8gpu(
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
def test_sharding_diff_table_index_type(
self,
dtype: torch.dtype,
index_dtype: torch.dtype,
offsets_dtype: torch.dtype,
use_offsets: bool,
sharder_type: str,
kernel_type: str,
@@ -960,7 +962,7 @@ def test_sharding_diff_table_index_type(
variable_batch_size=False,
pooling=PoolingType.SUM,
use_offsets=use_offsets,
indices_dtype=dtype,
offsets_dtype=dtype,
lengths_dtype=dtype,
indices_dtype=index_dtype,
offsets_dtype=offsets_dtype,
lengths_dtype=index_dtype,
)
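
Splitting the single dtype parameter into index_dtype and offsets_dtype lets Hypothesis draw the two independently, so mismatched int32/int64 combinations are also exercised. A minimal sketch of that sampling pattern (the body is reduced to a dtype check here; the real test drives the full sharding pipeline):

import torch
from hypothesis import given, settings, strategies as st

@given(
    index_dtype=st.sampled_from([torch.int32, torch.int64]),
    offsets_dtype=st.sampled_from([torch.int32, torch.int64]),
)
@settings(max_examples=4, deadline=None)
def test_mixed_index_offset_dtypes(index_dtype: torch.dtype, offsets_dtype: torch.dtype) -> None:
    # Independent draws produce mismatched pairs such as (int32, int64),
    # which the kernel-side promotion above must reconcile.
    indices = torch.tensor([1, 4, 7], dtype=index_dtype)
    offsets = torch.tensor([0, 2, 3], dtype=offsets_dtype)
    assert indices.dtype == index_dtype and offsets.dtype == offsets_dtype
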
22 changes: 4 additions & 18 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -59,11 +59,7 @@
ShardingPlan,
ShardingType,
)
from torchrec.modules.embedding_configs import (
BaseEmbeddingConfig,
DataType,
EmbeddingBagConfig,
)
from torchrec.modules.embedding_configs import BaseEmbeddingConfig, EmbeddingBagConfig
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter

@@ -558,7 +554,9 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
)

# Compare predictions of sharded vs unsharded models.
if qcomms_config is not None:
if qcomms_config is None:
torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
else:
# With quantized comms, we can relax constraints a bit
rtol = 0.003
if CommType.FP8 in [
@@ -570,18 +568,6 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
torch.testing.assert_close(
global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol
)
elif (
weighted_tables is not None
and weighted_tables[0].data_type == DataType.FP16
): # https://www.internalfb.com/intern/diffing/?paste_number=1740410921
torch.testing.assert_close(
global_pred,
torch.cat(all_local_pred),
atol=1e-4, # relaxed atol due to FP16 in weights
rtol=1e-4, # relaxed rtol due to FP16 in weights
)
else:
torch.testing.assert_close(global_pred, torch.cat(all_local_pred))


def create_device_mesh_for_2D(
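
The prediction comparison above now uses default torch.testing.assert_close tolerances when quantized comms are disabled and relaxed rtol/atol otherwise (the FP16-weighted-table special case is dropped). A small illustration of how the tolerance arguments behave; the numbers are made up for the example, and only the rtol value comes from the diff:

import torch

global_pred = torch.tensor([1.000, 2.000, 3.000])
local_pred = torch.tensor([1.001, 2.002, 3.003])

# Default float32 tolerances (rtol=1.3e-6, atol=1e-5) reject a ~1e-3 error...
try:
    torch.testing.assert_close(global_pred, local_pred)
except AssertionError:
    pass

# ...while relaxed tolerances of the kind used for quantized comms accept it.
torch.testing.assert_close(global_pred, local_pred, rtol=0.003, atol=0.003)
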
6 changes: 1 addition & 5 deletions torchrec/modules/embedding_modules.py
@@ -248,11 +248,7 @@ def forward(
res = embedding_bag(
input=f.values(),
offsets=f.offsets(),
per_sample_weights=(
f.weights().to(embedding_bag.weight.dtype)
if self._is_weighted
else None
),
per_sample_weights=f.weights() if self._is_weighted else None,
).float()
pooled_embeddings.append(res)
return KeyedTensor(
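
The change above drops the cast of per_sample_weights to the embedding weight dtype and passes the feature weights through unchanged. For reference, the call follows the standard nn.EmbeddingBag contract, which accepts per_sample_weights only in "sum" mode; a small self-contained example, unrelated to TorchRec internals:

import torch
import torch.nn as nn

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")
values = torch.tensor([1, 2, 4, 5], dtype=torch.long)
offsets = torch.tensor([0, 2], dtype=torch.long)
weights = torch.tensor([0.5, 0.5, 1.0, 2.0])  # one weight per looked-up index

pooled = bag(input=values, offsets=offsets, per_sample_weights=weights)
print(pooled.shape)  # torch.Size([2, 4]) -- one pooled row per bag
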
2 changes: 1 addition & 1 deletion torchrec/modules/fused_embedding_modules.py
@@ -68,7 +68,7 @@ def __init__( # noqa C901
state: Dict[Any, Any] = {}
param_group: Dict[str, Any] = {
"params": [],
"lr": emb_module.optimizer_args.learning_rate,
"lr": emb_module.optimizer_args.learning_rate_tensor,
}

params: Dict[str, torch.Tensor] = {}
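
The learning-rate hunks above (in batched_embedding_kernel.py and fused_embedding_modules.py) switch the param group's "lr" entry from optimizer_args.learning_rate to optimizer_args.learning_rate_tensor, i.e. the learning rate is now carried as a tensor rather than a Python float. A hedged sketch of what a tensor-valued "lr" allows; the in-place-update motivation is an assumption here, not stated in the diff:

import torch

# A learning rate held as a tensor can be mutated in place, and every
# param group that references it observes the new value immediately.
lr_tensor = torch.tensor(0.01, dtype=torch.float64)
param_group = {"params": [], "lr": lr_tensor}

lr_tensor.mul_(0.5)  # hypothetical scheduler step
print(param_group["lr"].item())  # 0.005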
