
Commit f1c716a

Ivan Kobzarev authored and facebook-github-bot committed
Small syntactic changes for dynamo compatibility
Summary: Dynamo has some gaps in its support of generators, list comprehensions, etc. For now, avoid them with syntactic changes.

The previous diff was reverted because recat was created on the target device from the start; the per-item manipulations then wrote directly to the device, which broke freya training (freya appears not to support per-item changes). In this diff, recat is created on "cpu", the same as the List[int] in the original version.

Reviewed By: MatthewWEdwards

Differential Revision: D54192498

fbshipit-source-id: bdfffc5f207fa200c9b207969225c2bcbdf94e7a
1 parent fdfc4e0 commit f1c716a
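
The device point in the Summary can be illustrated with a small sketch: the permutation tensor is allocated on the CPU (torch.empty defaults to CPU) and only moved to the target device in one bulk copy afterwards, so no per-item writes ever touch device memory. This is an illustration only; build_recat_cpu and the trailing .to("cuda") call are assumptions and do not appear in this commit (the real code also applies a stagger ordering).

import torch

def build_recat_cpu(local_split: int, num_splits: int) -> torch.Tensor:
    # Allocated without a device argument, so it lives on the CPU,
    # mirroring the List[int] of the original version.
    recat = torch.empty(local_split * num_splits, dtype=torch.int32)
    idx = 0
    for i in range(local_split):        # explicit loops instead of a
        for j in range(num_splits):     # list comprehension, for Dynamo
            recat[idx] = i + j * local_split
            idx += 1
    return recat

# recat = build_recat_cpu(2, 4).to("cuda")  # one bulk host-to-device copy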

File tree: 3 files changed, +33 −15 lines

torchrec/distributed/dist_data.py

+10 −7
@@ -81,17 +81,20 @@ def _get_recat(
     if local_split == 0:
         return None
 
-    recat: List[int] = []
+    feature_order: List[int] = []
+    for x in range(num_splits // stagger):
+        for y in range(stagger):
+            feature_order.append(x + num_splits // stagger * y)
 
-    feature_order: List[int] = [
-        x + num_splits // stagger * y
-        for x in range(num_splits // stagger)
-        for y in range(stagger)
-    ]
+    recat: torch.Tensor = torch.empty(
+        local_split * len(feature_order), dtype=torch.int32
+    )
 
+    _i = 0
     for i in range(local_split):
         for j in feature_order:  # range(num_splits):
-            recat.append(i + j * local_split)
+            recat[_i] = i + j * local_split
+            _i += 1
 
     # variable batch size
     if batch_size_per_rank is not None and any(
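
For intuition about the new construction, here is a standalone reproduction of the hunk above with small example values (num_splits=4, stagger=2, local_split=2). The concrete values and the expected-result comments are illustrations, not code from the repository.

import torch
from typing import List

num_splits, stagger, local_split = 4, 2, 2

# Staggered feature order, built with explicit loops instead of a comprehension.
feature_order: List[int] = []
for x in range(num_splits // stagger):
    for y in range(stagger):
        feature_order.append(x + num_splits // stagger * y)
# feature_order == [0, 2, 1, 3]

# recat is preallocated on the CPU and filled item by item.
recat = torch.empty(local_split * len(feature_order), dtype=torch.int32)
_i = 0
for i in range(local_split):
    for j in feature_order:
        recat[_i] = i + j * local_split
        _i += 1
# recat == tensor([0, 4, 2, 6, 1, 5, 3, 7], dtype=torch.int32)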

torchrec/distributed/embedding_sharding.py

+16 −2
@@ -87,6 +87,15 @@ def _fx_wrap_stride_per_key_per_rank(
     )
 
 
+@torch.fx.wrap
+def _fx_wrap_gen_list_n_times(ls: List[str], n: int) -> List[str]:
+    # Syntax for dynamo (instead of generator kjt.keys() * num_buckets)
+    ret: List[str] = []
+    for _ in range(n):
+        ret.extend(ls)
+    return ret
+
+
 def bucketize_kjt_before_all2all(
     kjt: KeyedJaggedTensor,
     num_buckets: int,
@@ -143,7 +152,7 @@ def bucketize_kjt_before_all2all(
     return (
         KeyedJaggedTensor(
             # duplicate keys will be resolved by AllToAll
-            keys=kjt.keys() * num_buckets,
+            keys=_fx_wrap_gen_list_n_times(kjt.keys(), num_buckets),
             values=bucketized_indices,
             weights=pos if bucketize_pos else bucketized_weights,
             lengths=bucketized_lengths.view(-1),
@@ -371,7 +380,12 @@ def _wait_impl(self) -> KJTList:
         Returns:
            KJTList: synced `KJTList`.
        """
-        kjts = [w.wait() for w in self.awaitables]
+
+        # Syntax: no list comprehension usage for dynamo
+        kjts = []
+        for w in self.awaitables:
+            kjts.append(w.wait())
+
         _set_sharding_context_post_a2a(kjts, self.ctx)
         return KJTList(kjts)
 
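
The new helper is behaviorally identical to list repetition, while @torch.fx.wrap keeps the call as a single node during symbolic tracing instead of tracing into the Python loop. A self-contained sanity check (restating the helper from the hunk above; the example key names are made up):

from typing import List
import torch.fx

@torch.fx.wrap
def _fx_wrap_gen_list_n_times(ls: List[str], n: int) -> List[str]:
    # Explicit loop instead of `ls * n` so Dynamo can handle it;
    # torch.fx.wrap treats the call as a leaf during FX tracing.
    ret: List[str] = []
    for _ in range(n):
        ret.extend(ls)
    return ret

keys = ["feature_a", "feature_b"]
assert _fx_wrap_gen_list_n_times(keys, 3) == keys * 3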

torchrec/distributed/embeddingbag.py

+7 −6
@@ -856,12 +856,13 @@ def compute_and_output_dist(
     ) -> LazyAwaitable[KeyedTensor]:
         batch_size_per_feature_pre_a2a = []
         awaitables = []
-        for lookup, dist, sharding_context, features in zip(
-            self._lookups,
-            self._output_dists,
-            ctx.sharding_contexts,
-            input,
-        ):
+
+        # No usage of zip for dynamo
+        for i in range(len(self._lookups)):
+            lookup = self._lookups[i]
+            dist = self._output_dists[i]
+            sharding_context = ctx.sharding_contexts[i]
+            features = input[i]
             awaitables.append(dist(lookup(features), sharding_context))
             if sharding_context:
                 batch_size_per_feature_pre_a2a.extend(
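
The index-based loop above is a drop-in replacement for the zip-based iteration as long as the four sequences have equal length, which the zip version already assumed. A minimal sketch of the equivalence, with placeholder lists standing in for self._lookups, self._output_dists, ctx.sharding_contexts, and input:

# Placeholder data; real values would be lookup modules, output dists,
# sharding contexts, and KJT inputs of equal length.
lookups = ["lookup0", "lookup1"]
output_dists = ["dist0", "dist1"]
sharding_contexts = ["ctx0", "ctx1"]
inputs = ["kjt0", "kjt1"]

# Dynamo-friendly iteration: index into each list instead of zipping them.
pairs = []
for i in range(len(lookups)):
    lookup = lookups[i]
    dist = output_dists[i]
    sharding_context = sharding_contexts[i]
    features = inputs[i]
    pairs.append((lookup, dist, sharding_context, features))

assert pairs == list(zip(lookups, output_dists, sharding_contexts, inputs))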

0 commit comments