diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py
index 0c41220e4..b665257a8 100644
--- a/torchrec/modules/embedding_configs.py
+++ b/torchrec/modules/embedding_configs.py
@@ -193,6 +193,9 @@ class BaseEmbeddingConfig:
     # enable this flag to support rw_sharding
     need_pos: bool = False
 
+    # handle the special case
+    input_dim: Optional[int] = None
+
     def get_weight_init_max(self) -> float:
         if self.weight_init_max is None:
             return sqrt(1 / self.num_embeddings)
diff --git a/torchrec/schema/api_tests/test_embedding_config_schema.py b/torchrec/schema/api_tests/test_embedding_config_schema.py
index 27511c8c9..c0ca41a5b 100644
--- a/torchrec/schema/api_tests/test_embedding_config_schema.py
+++ b/torchrec/schema/api_tests/test_embedding_config_schema.py
@@ -38,6 +38,7 @@ class StableEmbeddingBagConfig:
     # when the position_weighted feature is in this table config,
     # enable this flag to support rw_sharding
     need_pos: bool = False
+    input_dim: Optional[int] = None
 
     pooling: PoolingType = PoolingType.SUM
 
@@ -56,6 +57,7 @@ class StableEmbeddingConfig:
     # when the position_weighted feature is in this table config,
     # enable this flag to support rw_sharding
     need_pos: bool = False
+    input_dim: Optional[int] = None
 
 
 class TestEmbeddingConfigSchema(unittest.TestCase):
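
The diff itself carries no usage notes, so below is a minimal sketch of how the new field looks from the caller's side. It assumes the stock `EmbeddingBagConfig` dataclass (which inherits `input_dim` from `BaseEmbeddingConfig` via this change); the table names, dimensions, and the `input_dim=32` value are purely illustrative, and the semantics of the "special case" are left to whichever module or sharder consumes the field.

```python
from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType

# input_dim defaults to None, so existing configs are unaffected.
default_config = EmbeddingBagConfig(
    name="t1",
    embedding_dim=64,
    num_embeddings=1_000,
    feature_names=["f1"],
    pooling=PoolingType.SUM,
)
assert default_config.input_dim is None

# For the special case this diff targets, the field can be set explicitly.
special_config = EmbeddingBagConfig(
    name="t2",
    embedding_dim=64,
    num_embeddings=1_000,
    feature_names=["f2"],
    input_dim=32,  # illustrative value only
)
assert special_config.input_dim == 32
```

Because the field is optional with a `None` default, it is backward compatible: no existing construction site needs to change, and the schema tests only assert that the stable configs now carry the same attribute.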