Commit a347e07

sarckk authored and facebook-github-bot committed
Fail early if no sharding option found for table (#1657)
Summary:
Pull Request resolved: #1657

Currently we raise an exception if no sharding options are found for the first table. If a sharding option is found for the first table, but not for the second table, no exception is raised. This causes the error to be [raised later when sharding the model](https://github.com/pytorch/torchrec/blob/77974f229ce7e229664fbe199e1308cc37a91d7f/torchrec/distributed/embeddingbag.py#L217-L218), which is harder to debug.

Reviewed By: henrylhtsang

Differential Revision: D53044797

fbshipit-source-id: bf0ec2307ef9bfea9db965f3a11464e238e5d6ac
1 parent b72f9fe · commit a347e07
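As a standalone illustration of the fail-early behavior the summary describes, here is a minimal sketch of the pattern the diff below introduces: buffer candidate options per table and raise as soon as any single table ends up with none. This is not torchrec code; every name in it (enumerate_options, allowed, the table strings) is invented for the example.

from typing import Dict, List


def enumerate_options(tables: List[str], allowed: Dict[str, List[str]]) -> List[str]:
    # Global result list, analogous to sharding_options in the diff below.
    all_options: List[str] = []
    for table in tables:
        # Per-table buffer, analogous to sharding_options_per_table.
        options_per_table: List[str] = []
        for sharding_type in allowed.get(table, []):
            options_per_table.append(f"{table}:{sharding_type}")
        if not options_per_table:
            # Fail early: checking only the global list would miss a later
            # table with zero options whenever an earlier table had some.
            raise RuntimeError(f"No available sharding option for {table}")
        all_options.extend(options_per_table)
    return all_options


if __name__ == "__main__":
    try:
        # table_1 has no allowed sharding types, so this raises immediately.
        enumerate_options(
            ["table_0", "table_1"],
            {"table_0": ["column_wise"], "table_1": []},
        )
    except RuntimeError as e:
        print(e)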

File tree

2 files changed: +52 -2 lines


torchrec/distributed/planner/enumerators.py

+6 -2

@@ -127,6 +127,8 @@ def enumerate(
                     bounds_check_mode,
                 ) = _extract_constraints_for_param(self._constraints, name)
 
+                sharding_options_per_table: List[ShardingOption] = []
+
                 for sharding_type in self._filter_sharding_types(
                     name, sharder.sharding_types(self._compute_device)
                 ):
@@ -150,7 +152,7 @@ def enumerate(
                         elif isinstance(child_module, EmbeddingTowerCollection):
                             tower_index = _get_tower_index(name, child_module)
                             dependency = child_path + ".tower_" + str(tower_index)
-                        sharding_options.append(
+                        sharding_options_per_table.append(
                             ShardingOption(
                                 name=name,
                                 tensor=param,
@@ -172,12 +174,14 @@ def enumerate(
                                 is_pooled=is_pooled,
                             )
                         )
-                if not sharding_options:
+                if not sharding_options_per_table:
                     raise RuntimeError(
                         "No available sharding type and compute kernel combination "
                         f"after applying user provided constraints for {name}"
                     )
 
+                sharding_options.extend(sharding_options_per_table)
+
         self.populate_estimates(sharding_options)
 
         return sharding_options

torchrec/distributed/planner/tests/test_enumerators.py

+46
@@ -858,3 +858,49 @@ def test_tower_collection_sharding(self) -> None:
     def test_empty(self) -> None:
         sharding_options = self.enumerator.enumerate(self.model, sharders=[])
         self.assertFalse(sharding_options)
+
+    def test_throw_ex_no_sharding_option_for_table(self) -> None:
+        cw_constraint = ParameterConstraints(
+            sharding_types=[
+                ShardingType.COLUMN_WISE.value,
+            ],
+            compute_kernels=[
+                EmbeddingComputeKernel.FUSED.value,
+            ],
+        )
+
+        rw_constraint = ParameterConstraints(
+            sharding_types=[
+                ShardingType.TABLE_ROW_WISE.value,
+            ],
+            compute_kernels=[
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+            ],
+        )
+
+        constraints = {
+            "table_0": cw_constraint,
+            "table_1": rw_constraint,
+            "table_2": cw_constraint,
+            "table_3": cw_constraint,
+        }
+
+        enumerator = EmbeddingEnumerator(
+            topology=Topology(
+                world_size=self.world_size,
+                compute_device=self.compute_device,
+                local_world_size=self.local_world_size,
+            ),
+            batch_size=self.batch_size,
+            constraints=constraints,
+        )
+
+        sharder = cast(ModuleSharder[torch.nn.Module], CWSharder())
+
+        with self.assertRaises(Exception) as context:
+            _ = enumerator.enumerate(self.model, [sharder])
+
+        self.assertTrue(
+            "No available sharding type and compute kernel combination after applying user provided constraints for table_1"
+            in str(context.exception)
+        )
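Why the test expects the failure for table_1 specifically: the test shards with CWSharder, which, judging by its name and its use elsewhere in this file, presumably supports only column-wise sharding, while table_1's constraint allows only table-row-wise, so no sharding type survives for that table. Below is a small hedged illustration of that intersection; the literal strings are invented for the example, the real values come from ShardingType and the sharder under test.

# Illustration only, not torchrec code.
sharder_sharding_types = {"column_wise"}  # assumption: what CWSharder offers
table_1_allowed_types = {"table_row_wise"}  # from table_1's ParameterConstraints

# The intersection is empty, so no (sharding type, compute kernel) combination
# remains for table_1 and enumerate() now raises for it right away.
assert not (sharder_sharding_types & table_1_allowed_types)
print("table_1 has no valid sharding option")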
