Add MixedVLE (#2738)
Summary:
Pull Request resolved: #2738

Support Mixed VLE for LoopFM by adding an optional input_dim field to BaseEmbeddingConfig, mirrored in the stable schema test configs.

Reviewed By: tao-jia

Differential Revision: D68706536

fbshipit-source-id: e57e5d004bf0a43b9ba3f452278b8d0a60ded7be
Huayu Li authored and facebook-github-bot committed Mar 5, 2025
1 parent 919bbcb commit a2ade72
Showing 2 changed files with 5 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torchrec/modules/embedding_configs.py
@@ -193,6 +193,9 @@ class BaseEmbeddingConfig:
     # enable this flag to support rw_sharding
     need_pos: bool = False
 
+    # handle the special case
+    input_dim: Optional[int] = None
+
     def get_weight_init_max(self) -> float:
         if self.weight_init_max is None:
             return sqrt(1 / self.num_embeddings)
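The new field lives on BaseEmbeddingConfig, so concrete table configs such as EmbeddingBagConfig inherit it. A minimal sketch of setting it, assuming the standard torchrec constructor fields; the table name, sizes, and the input_dim value are illustrative, and the MixedVLE-specific consumption of the field is not part of this commit:

from torchrec.modules.embedding_configs import EmbeddingBagConfig

table = EmbeddingBagConfig(
    name="t1",                 # illustrative table name
    num_embeddings=1_000_000,  # illustrative size
    embedding_dim=128,
    feature_names=["f1"],
    input_dim=64,              # the new optional field; defaults to None
)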
2 changes: 2 additions & 0 deletions torchrec/schema/api_tests/test_embedding_config_schema.py
@@ -38,6 +38,7 @@ class StableEmbeddingBagConfig:
     # when the position_weighted feature is in this table config,
     # enable this flag to support rw_sharding
     need_pos: bool = False
+    input_dim: Optional[int] = None
     pooling: PoolingType = PoolingType.SUM
 
 
@@ -56,6 +57,7 @@ class StableEmbeddingConfig:
     # when the position_weighted feature is in this table config,
     # enable this flag to support rw_sharding
     need_pos: bool = False
+    input_dim: Optional[int] = None
 
 
 class TestEmbeddingConfigSchema(unittest.TestCase):
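The commit only declares the field; no consumer logic is included. A hypothetical sketch of how a caller might interpret it, assuming None means "fall back to embedding_dim" (the helper name and fallback rule are assumptions, not torchrec API):

from torchrec.modules.embedding_configs import BaseEmbeddingConfig

def resolve_input_dim(config: BaseEmbeddingConfig) -> int:
    # Hypothetical: treat an unset input_dim as "same as embedding_dim".
    return config.input_dim if config.input_dim is not None else config.embedding_dim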
