Commit 35b14b0

TroyGarden authored and facebook-github-bot committed
reland D70126859 (#2787)
Summary:
Pull Request resolved: #2787

# context
* The previous diff triggered S495021.
* The error message looks like:
  ```
  ModelGenerationPlatformError("AttributeError: '_EmbeddingBagProxy' object has no attribute 'weight'")
  ```
* This happens because in some flows the EBC module is fx-traced, so there is no actual EBC but a Proxy. Without the full context it is risky to push that change.
* As a workaround, we simply convert the unsharded EBC back to float32 so it is compatible with the input KJT.weights of float32.

NOTE: this hacky change (unsharded EBC float16 ==> float32) is only needed in the tests, where we want to compare results against the sharded EBC.

WARNING: we make a strong assumption here that in any unsharded EBC (dtype=float16) use case, the input KJT.weights should never be float32.

Reviewed By: basilwong

Differential Revision: D70712348

fbshipit-source-id: f2abaa601adf3052ea322cf326363da8bfef96c3
1 parent df419e9 · commit 35b14b0
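The dtype clash described in the summary can be sketched with a plain `torch.nn.EmbeddingBag`, which is what the unsharded `EmbeddingBagCollection` wraps per table. This is only an illustration, not TorchRec code: the module `toy_bag`, the sizes, and the inputs are made up, and the mixed-dtype call is assumed to fail the way the tests did before the workaround.

```python
import torch

# One bag of a hypothetical unsharded, weighted EBC whose table is stored in
# float16 (DataType.FP16); sizes here are arbitrary.
toy_bag = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum").half()

indices = torch.tensor([1, 2, 4, 5], dtype=torch.long)
offsets = torch.tensor([0, 2], dtype=torch.long)
per_sample_weights = torch.rand(4)  # float32, like the KJT.weights in the tests

try:
    # float16 table + float32 per-sample weights: assumed to be rejected,
    # which is the incompatibility the workaround targets.
    toy_bag(indices, offsets, per_sample_weights=per_sample_weights)
except RuntimeError as err:
    print(f"dtype mismatch: {err}")

# The workaround applied to the toy module: cast the unsharded table back to
# float32 so it matches the float32 per-sample weights.
with torch.no_grad():
    toy_bag.weight = torch.nn.Parameter(toy_bag.weight.float())

print(toy_bag(indices, offsets, per_sample_weights=per_sample_weights).dtype)  # torch.float32
```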

File tree: 3 files changed, +42 -22 lines

torchrec/distributed/test_utils/test_model.py (+1 -2)
@@ -243,8 +243,7 @@ def _validate_pooling_factor(
             global_idlist_indices.append(indices)
             global_idlist_offsets.append(offsets)
 
-        for idx in range(len(idscore_ind_ranges)):
-            ind_range = idscore_ind_ranges[idx]
+        for idx, ind_range in enumerate(idscore_ind_ranges):
             lengths_ = torch.abs(
                 torch.randn(batch_size * world_size, device=device)
                 + (

torchrec/distributed/test_utils/test_model_parallel.py (+4 -4)
@@ -290,7 +290,7 @@ def test_sharding_rw(
         data_type=st.sampled_from([DataType.FP32, DataType.FP16]),
         # TODO - need to enable optimizer overlapped behavior for data_parallel tables
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
     def test_sharding_dp(
         self,
         sharder_type: str,
@@ -429,7 +429,7 @@ def test_sharding_cw(
         variable_batch_size=st.booleans(),
         data_type=st.sampled_from([DataType.FP32, DataType.FP16]),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
     def test_sharding_twcw(
         self,
         sharder_type: str,
@@ -510,7 +510,7 @@ def test_sharding_twcw(
         variable_batch_size=st.booleans(),
         data_type=st.sampled_from([DataType.FP32, DataType.FP16]),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
     def test_sharding_tw(
         self,
         sharder_type: str,
@@ -592,7 +592,7 @@ def test_sharding_tw(
         pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]),
         data_type=st.sampled_from([DataType.FP32, DataType.FP16]),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
     def test_sharding_twrw(
         self,
         sharder_type: str,

torchrec/distributed/test_utils/test_sharding.py (+37 -16)
@@ -9,18 +9,7 @@
 
 import random
 from enum import Enum
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    List,
-    Optional,
-    Protocol,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Any, cast, Dict, List, Optional, Protocol, Tuple, Type, Union
 
 import torch
 import torch.distributed as dist
@@ -59,7 +48,12 @@
     ShardingPlan,
     ShardingType,
 )
-from torchrec.modules.embedding_configs import BaseEmbeddingConfig, EmbeddingBagConfig
+from torchrec.modules.embedding_configs import (
+    BaseEmbeddingConfig,
+    DataType,
+    EmbeddingBagConfig,
+)
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
 from torchrec.optim.optimizers import in_backward_optimizer_filter
 
@@ -329,6 +323,15 @@ def copy_state_dict(
         tensor.copy_(global_tensor)
 
 
+# alter the EBC dtype to float32 in-place.
+def alter_global_ebc_dtype(model: nn.Module) -> None:
+    for _name, ebc in model.named_modules():
+        if isinstance(ebc, EmbeddingBagCollection) and ebc._is_weighted:
+            with torch.no_grad():
+                for bag in ebc.embedding_bags.values():
+                    bag.weight = torch.nn.Parameter(bag.weight.float())
+
+
 def sharding_single_rank_test(
     rank: int,
     world_size: int,
@@ -527,6 +530,7 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
         global_model.state_dict(),
         exclude_predfix="sparse.pooled_embedding_arch.embedding_modules._itp_iter",
     )
+    alter_global_ebc_dtype(global_model)
 
     # Run a single training step of the sharded model.
     local_pred = gen_full_pred_after_one_step(
@@ -554,9 +558,7 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
         )
 
     # Compare predictions of sharded vs unsharded models.
-    if qcomms_config is None:
-        torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
-    else:
+    if qcomms_config is not None:
         # With quantized comms, we can relax constraints a bit
         rtol = 0.003
         if CommType.FP8 in [
@@ -568,6 +570,25 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
         torch.testing.assert_close(
             global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol
         )
+    elif (
+        weighted_tables is not None
+        and weighted_tables[0].data_type == DataType.FP16
+    ):
+        # We relax this accuracy test because when the embedding table weights are FP16,
+        # the sharded EBC upscales the returned embeddings to FP32:
+        #   KJT.weights (FP32) + sharded_EBC (FP16) ==> embeddings (FP32)
+        # The test uses the unsharded EBC as the reference, but the unsharded EBC
+        # uses EmbeddingBags, which can only handle matching precision, i.e.,
+        #   KJT.weights (FP32) + unsharded_EBC (FP32) ==> embeddings (FP32)
+        # This discrepancy leads to a relaxed tolerance level.
+        torch.testing.assert_close(
+            global_pred,
+            torch.cat(all_local_pred),
+            atol=1e-4,  # relaxed atol due to FP16 in weights
+            rtol=1e-4,  # relaxed rtol due to FP16 in weights
+        )
+    else:
+        torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
 
 
 def create_device_mesh_for_2D(
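For reference, a small usage sketch of the new helper on a standalone weighted `EmbeddingBagCollection`. The table config below is invented for illustration, and the `torch.float16` expectation assumes the unsharded EBC honors `DataType.FP16` when allocating its `EmbeddingBag` weights, which is the situation the tests in this commit exercise.

```python
import torch
from torchrec.distributed.test_utils.test_sharding import alter_global_ebc_dtype
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# A single made-up FP16 table; is_weighted=True so the helper's
# `ebc._is_weighted` check matches, as in the weighted-table tests.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t0",
            feature_names=["f0"],
            num_embeddings=100,
            embedding_dim=8,
            data_type=DataType.FP16,
        )
    ],
    is_weighted=True,
)

print({name: p.dtype for name, p in ebc.named_parameters()})
# expected: {'embedding_bags.t0.weight': torch.float16}

# In-place cast performed on the unsharded (global) model before comparing
# against the sharded model's predictions.
alter_global_ebc_dtype(ebc)

print({name: p.dtype for name, p in ebc.named_parameters()})
# expected: {'embedding_bags.t0.weight': torch.float32}
```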
