Commit 26b1c83

spcyppt authored and facebook-github-bot committed
Backout

Summary:
X-link: pytorch/FBGEMM#3803
X-link: facebookresearch/FBGEMM#887

Backout D68055168 as it seems to break pyper and causes S498612.

Differential Revision: D70996903
1 parent b418a44 commit 26b1c83
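
For context, the change being backed out had switched the optimizer param group's "lr" entry from the plain learning rate to a tensor-valued one. Below is a minimal sketch of the difference, assuming a hypothetical stand-in for the module's optimizer args; only the attribute names `learning_rate` and `learning_rate_tensor` come from the diff itself.

```python
import torch

# Hypothetical stand-in for emb_module.optimizer_args; not the real class.
class OptimizerArgs:
    def __init__(self, learning_rate: float) -> None:
        self.learning_rate = learning_rate  # plain Python float
        # The reverted change (D68055168) had exposed the rate as a tensor:
        self.learning_rate_tensor = torch.tensor(learning_rate)

optimizer_args = OptimizerArgs(0.01)

# After this backout, the param group is built from the float again,
# mirroring the "+" lines in the hunks below.
param_group = {"params": [], "lr": optimizer_args.learning_rate}
print(type(param_group["lr"]))  # <class 'float'>
```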

File tree

2 files changed (+3, -3)


torchrec/distributed/batched_embedding_kernel.py (+2, -2)
@@ -190,7 +190,7 @@ def __init__(
         state: Dict[Any, Any] = {}
         param_group: Dict[str, Any] = {
             "params": [],
-            "lr": emb_module.optimizer_args.learning_rate_tensor,
+            "lr": emb_module.optimizer_args.learning_rate,
         }

         params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
@@ -383,7 +383,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
         state: Dict[Any, Any] = {}
         param_group: Dict[str, Any] = {
             "params": [],
-            "lr": emb_module.optimizer_args.learning_rate_tensor,
+            "lr": emb_module.optimizer_args.learning_rate,
         }

         params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}

torchrec/modules/fused_embedding_modules.py (+1, -1)
@@ -68,7 +68,7 @@ def __init__( # noqa C901
         state: Dict[Any, Any] = {}
         param_group: Dict[str, Any] = {
             "params": [],
-            "lr": emb_module.optimizer_args.learning_rate_tensor,
+            "lr": emb_module.optimizer_args.learning_rate,
         }

         params: Dict[str, torch.Tensor] = {}
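
As a usage note (an illustration, not code from this commit): a param group shaped like the ones above can be handed straight to a stock torch.optim optimizer, which is one place a float-valued "lr" round-trips cleanly through state_dict().

```python
import torch

param = torch.nn.Parameter(torch.zeros(4))
# Same shape as the param_group built in the diffs, with a real parameter.
opt = torch.optim.SGD([{"params": [param], "lr": 0.01}])
print(opt.state_dict()["param_groups"][0]["lr"])  # 0.01, a plain float
```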
