Skip to content

Commit 42fa9d0

Browse files
author
pytorchbot
committed
2025-01-14 nightly release (542b0b2)
1 parent 6f54403 commit 42fa9d0

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

torchrec/distributed/batched_embedding_kernel.py

+3
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,9 @@ def step(self, closure: Any = None) -> None:
624624
def set_optimizer_step(self, step: int) -> None:
625625
self._emb_module.set_optimizer_step(step)
626626

627+
def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
628+
self._emb_module.update_hyper_parameters(params_dict)
629+
627630

628631
def _gen_named_parameters_by_table_ssd(
629632
emb_module: SSDTableBatchedEmbeddingBags,

torchrec/distributed/global_settings.py

+4
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,7 @@ def construct_sharded_tensor_from_metadata_enabled() -> bool:
3030
return (
3131
os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1"
3232
)
33+
34+
35+
def enable_construct_sharded_tensor_from_metadata() -> None:
36+
os.environ[TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV] = "1"

0 commit comments

Comments
 (0)