Commit d7bced9

joshuadeng authored and facebook-github-bot committed
prevent calling next on exhausted dataloader in train pipeline (#1778)
Summary:
Pull Request resolved: #1778

Calling `next` on an already exhausted dataloader can cause the dataloader to hang. This diff prevents that from occurring while respecting the train pipeline API, which allows the user to pass in a different pipeline.

Reviewed By: sarckk, lequytra

Differential Revision: D54753344

fbshipit-source-id: 64a5ec3b5fa39cbfe3206b7993608c42c81039ee
1 parent 2edb86c commit d7bced9

File tree

1 file changed: +20 -1 lines changed

torchrec/distributed/train_pipeline/train_pipeline.py (+20 -1)
@@ -178,6 +178,8 @@ def __init__(
         self._batch_ip2: Optional[In] = None
         self._context = TrainPipelineContext()
         self._pipelined_modules: List[ShardedModule] = []
+        self._dataloader_iter: Optional[Iterator[In]] = None
+        self._dataloader_exhausted: bool = False

     def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         # pipeline is already filled
@@ -262,13 +264,30 @@ def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         """
         with record_function("## copy_batch_to_gpu ##"):
             with torch.cuda.stream(self._memcpy_stream):
-                batch = next(dataloader_iter, None)
+                batch = self._next_batch(dataloader_iter)
                 if batch is not None:
                     batch = _to_device(batch, self._device, non_blocking=True)
                 elif not self._execute_all_batches:
                     raise StopIteration
                 return batch

+    def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
+        """
+        Retrieves next batch from dataloader and prevents calling `next` on an already
+        exhausted dataloader, which can cause hanging.
+        """
+        if dataloader_iter is not self._dataloader_iter:
+            self._dataloader_iter = dataloader_iter
+            self._dataloader_exhausted = False
+
+        if self._dataloader_exhausted:
+            batch = None
+        else:
+            batch = next(dataloader_iter, None)
+            if batch is None:
+                self._dataloader_exhausted = True
+        return batch
+
     def _start_sparse_data_dist(self, batch: Optional[In]) -> None:
         """
         Waits for batch to finish getting copied to GPU, then starts the input dist.
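For illustration, here is a minimal, self-contained sketch of the exhaustion guard this commit introduces, extracted from the pipeline so it can be run against a plain Python iterator. The ExhaustionGuard class and next_item method are hypothetical names used only for this example; the actual change is the _next_batch helper shown in the diff above.

from typing import Generic, Iterator, Optional, TypeVar

T = TypeVar("T")


class ExhaustionGuard(Generic[T]):
    """Illustrative wrapper: remembers when an iterator has been exhausted so
    that `next` is never called on it again (mirrors _next_batch above)."""

    def __init__(self) -> None:
        self._iter: Optional[Iterator[T]] = None
        self._exhausted: bool = False

    def next_item(self, it: Iterator[T]) -> Optional[T]:
        # A new iterator resets the exhaustion flag, just as _next_batch does
        # when a different dataloader_iter is passed in.
        if it is not self._iter:
            self._iter = it
            self._exhausted = False

        if self._exhausted:
            # Short-circuit: never touch an iterator that was already drained.
            return None

        item = next(it, None)
        if item is None:
            self._exhausted = True
        return item


if __name__ == "__main__":
    guard: ExhaustionGuard[int] = ExhaustionGuard()
    data = iter([1, 2, 3])
    # Ask for more items than exist: after the third call the guard returns
    # None without calling `next` on the exhausted iterator again.
    for _ in range(5):
        print(guard.next_item(data))  # 1, 2, 3, None, None

As in the original _next_batch, the sentinel is None, so a genuinely None element would be indistinguishable from exhaustion; the pipeline treats None as end-of-data.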
