Commit b40c5c1

spcyppt authored and facebook-github-bot committed
Unifying TBE API using List (Frontend) (#2751)
Summary:
X-link: pytorch/FBGEMM#3711
X-link: facebookresearch/FBGEMM#793

**Backend**: D68054868

---

As the number of arguments in TBE keeps growing, some of the optimizers run into the argument-count limit (64) during PyTorch operator registration. **For long-term growth and maintainability, we therefore redesign the TBE API by packing some of the arguments into lists. Note that not all arguments are packed.**

We pack the arguments into one list per type. For **common** arguments, we pack
- weights and arguments of type `Momentum` into a TensorList,
- other tensors and optional tensors into a list of optional tensors `aux_tensor`,
- `int` arguments into `aux_int`,
- `float` arguments into `aux_float`,
- `bool` arguments into `aux_bool`.

Similarly, for **optimizer-specific** arguments, we pack
- arguments of type `Momentum` that are *__not__ optional* into a TensorList,
- *optional* tensors into a list of optional tensors `optim_tensor`,
- `int` arguments into `optim_int`,
- `float` arguments into `optim_float`,
- `bool` arguments into `optim_bool`.

Packing SymInt arguments runs into PyTorch registration issues across the Python/C++ boundary, so SymInt arguments are left unrolled and passed individually.

**This significantly reduces the number of arguments.** For example, `split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function`, which currently has 61 arguments, has only 26 with this API design. Please refer to the design doc for which arguments are packed and for the signatures.

Design doc: https://docs.google.com/document/d/1dCBg7dcf7Yq9FHVrvXsAmFtBxkDi9o6u0r-Ptd4UDPE/edit?tab=t.0#heading=h.6bip5pwqq8xb

Full signatures for each optimizer lookup function will be provided shortly.

Reviewed By: sryap

Differential Revision: D68055168
1 parent 919bbcb commit b40c5c1
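As an illustration of the packing scheme described in the commit message, here is a minimal, hypothetical sketch of what a packed lookup signature could look like. The packed-list names (`aux_tensor`, `aux_int`, `aux_float`, `aux_bool`, `optim_tensor`, `optim_int`, `optim_float`, `optim_bool`) come from the message above; the function name, parameter choices, and ordering are placeholders, not the actual generated TBE signatures (see the linked design doc for those).

```python
# Hypothetical sketch only: illustrates the per-type packing described above.
# The packed-list names come from the commit message; everything else
# (function name, parameter order, SymInt arguments) is a placeholder.
from typing import List, Optional

import torch


def packed_tbe_lookup_sketch(
    # Common weights / Momentum-typed tensors, packed as a TensorList.
    weights: List[torch.Tensor],
    # Other common tensors, including optional ones.
    aux_tensor: List[Optional[torch.Tensor]],
    # Common scalar arguments, grouped by type.
    aux_int: List[int],
    aux_float: List[float],
    aux_bool: List[bool],
    # Optimizer-specific packing follows the same pattern.
    momentum: List[torch.Tensor],
    optim_tensor: List[Optional[torch.Tensor]],
    optim_int: List[int],
    optim_float: List[float],
    optim_bool: List[bool],
    # SymInt arguments stay unrolled (passed individually) because of the
    # registration issue mentioned in the commit message.
    max_B: int,
    total_D: int,
) -> torch.Tensor:
    raise NotImplementedError("illustrative signature only")
```

With this grouping, a new scalar optimizer argument extends one of the existing lists instead of adding another operator argument, which is how the rowwise-adagrad-with-counter lookup drops from 61 arguments to 26.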

File tree

2 files changed: +3 -3 lines changed


torchrec/distributed/batched_embedding_kernel.py

+2 -2
@@ -190,7 +190,7 @@ def __init__(
     state: Dict[Any, Any] = {}
     param_group: Dict[str, Any] = {
         "params": [],
-        "lr": emb_module.optimizer_args.learning_rate,
+        "lr": emb_module.optimizer_args.learning_rate_tensor,
     }

     params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
@@ -383,7 +383,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
     state: Dict[Any, Any] = {}
     param_group: Dict[str, Any] = {
         "params": [],
-        "lr": emb_module.optimizer_args.learning_rate,
+        "lr": emb_module.optimizer_args.learning_rate_tensor,
     }

     params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}

torchrec/modules/fused_embedding_modules.py

+1 -1
@@ -68,7 +68,7 @@ def __init__( # noqa C901
     state: Dict[Any, Any] = {}
    param_group: Dict[str, Any] = {
         "params": [],
-        "lr": emb_module.optimizer_args.learning_rate,
+        "lr": emb_module.optimizer_args.learning_rate_tensor,
     }

     params: Dict[str, torch.Tensor] = {}
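Both files receive the same one-line change: the optimizer param group's "lr" entry now comes from `optimizer_args.learning_rate_tensor`, i.e. it is a tensor rather than a Python float. Below is a hedged sketch of what that means for code consuming such a param group; the surrounding code is illustrative only, not TorchRec's actual implementation.

```python
import torch

# Illustrative only: a param group whose "lr" entry is a tensor, as in the
# hunks above, instead of a plain float.
param_group = {"params": [], "lr": torch.tensor(0.01)}

# Callers that previously assumed a float may need to unwrap the tensor.
lr_value = param_group["lr"].item()
print(f"current lr: {lr_value:.4f}")

# The tensor can also be updated in place, so any module holding a reference
# to the same tensor observes the new value (one plausible reason for storing
# the learning rate as a tensor).
param_group["lr"].fill_(0.005)
```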
