Skip to content

Commit 06c86ea

Browse files
Back out "fix flaky test due to input_jkt.weight dtype" (#2784)
Summary: Original commit changeset: 52fc46ced5a3 Original Phabricator Diff: D70126859 To fix failures like this one - f703883945 Differential Revision: D70706946
1 parent 959ede5 commit 06c86ea

File tree

3 files changed

+7
-24
lines changed

3 files changed

+7
-24
lines changed

torchrec/distributed/test_utils/test_model.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,8 @@ def _validate_pooling_factor(
243243
global_idlist_indices.append(indices)
244244
global_idlist_offsets.append(offsets)
245245

246-
for idx, ind_range in enumerate(idscore_ind_ranges):
246+
for idx in range(len(idscore_ind_ranges)):
247+
ind_range = idscore_ind_ranges[idx]
247248
lengths_ = torch.abs(
248249
torch.randn(batch_size * world_size, device=device)
249250
+ (

torchrec/distributed/test_utils/test_sharding.py

+4-18
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,7 @@
5959
ShardingPlan,
6060
ShardingType,
6161
)
62-
from torchrec.modules.embedding_configs import (
63-
BaseEmbeddingConfig,
64-
DataType,
65-
EmbeddingBagConfig,
66-
)
62+
from torchrec.modules.embedding_configs import BaseEmbeddingConfig, EmbeddingBagConfig
6763
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
6864
from torchrec.optim.optimizers import in_backward_optimizer_filter
6965

@@ -558,7 +554,9 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
558554
)
559555

560556
# Compare predictions of sharded vs unsharded models.
561-
if qcomms_config is not None:
557+
if qcomms_config is None:
558+
torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
559+
else:
562560
# With quantized comms, we can relax constraints a bit
563561
rtol = 0.003
564562
if CommType.FP8 in [
@@ -570,18 +568,6 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
570568
torch.testing.assert_close(
571569
global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol
572570
)
573-
elif (
574-
weighted_tables is not None
575-
and weighted_tables[0].data_type == DataType.FP16
576-
): # https://www.internalfb.com/intern/diffing/?paste_number=1740410921
577-
torch.testing.assert_close(
578-
global_pred,
579-
torch.cat(all_local_pred),
580-
atol=1e-4, # relaxed atol due to FP16 in weights
581-
rtol=1e-4, # relaxed rtol due to FP16 in weights
582-
)
583-
else:
584-
torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
585571

586572

587573
def create_device_mesh_for_2D(

torchrec/modules/embedding_modules.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,7 @@ def forward(
248248
res = embedding_bag(
249249
input=f.values(),
250250
offsets=f.offsets(),
251-
per_sample_weights=(
252-
f.weights().to(embedding_bag.weight.dtype)
253-
if self._is_weighted
254-
else None
255-
),
251+
per_sample_weights=f.weights() if self._is_weighted else None,
256252
).float()
257253
pooled_embeddings.append(res)
258254
return KeyedTensor(

0 commit comments

Comments (0)