Commit a90d3ed

sarckk authored and facebook-github-bot committed
Add missing event wait for last stage in StagedTrainPipeline (#1770)
Summary:
Pull Request resolved: #1770

StagedTrainPipeline expects model forward() to happen outside of the pipeline, which means that we need to wait for the last pre-forward stage to finish before progressing in the main compute stream. Also changes `wait_sparse_data_dist` to happen in the SDD stream instead of the main stream.

Reviewed By: dracifer, joshuadeng

Differential Revision: D54685704

fbshipit-source-id: cad14e1a67fb06bf56be359ef4face6877ee794b
1 parent 5856c4d commit a90d3ed
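
For reference, the calling pattern this fix targets: `progress()` hands back a pre-forward batch and the caller runs model forward() on the compute stream, so that stream must first wait on the event recorded by the last pipeline stage. Below is a minimal driver-loop sketch, not code from this commit: `pipeline_stages`, `model`, `optimizer`, and `dataloader` are placeholders, and only the `compute_stream` argument mirrors the change.

import torch
from torchrec.distributed.train_pipeline.train_pipeline import StagedTrainPipeline

# `pipeline_stages`, `model`, `optimizer`, and `dataloader` are placeholders.
pipeline = StagedTrainPipeline(
    pipeline_stages=pipeline_stages,
    compute_stream=torch.cuda.current_stream(),  # stream that runs model forward()
)

dataloader_iter = iter(dataloader)
while True:
    # progress() now makes compute_stream wait on the event recorded by the
    # last pre-forward stage before returning the batch to the caller.
    batch = pipeline.progress(dataloader_iter)
    if batch is None:
        break  # dataloader exhausted
    loss = model(batch)  # forward() happens outside the pipeline
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()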

3 files changed, +55 −13 lines changed

torchrec/distributed/train_pipeline/tests/test_train_pipeline.py

+3 −1

@@ -875,7 +875,9 @@ def gpu_preproc(x: StageOut) -> StageOut:
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
-        pipeline = StagedTrainPipeline(pipeline_stages=pipeline_stages)
+        pipeline = StagedTrainPipeline(
+            pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream()
+        )
         dataloader = iter(data)

         pipelined_out = []

torchrec/distributed/train_pipeline/train_pipeline.py

+43 −7

@@ -591,6 +591,9 @@ class StagedTrainPipeline(TrainPipeline[In, Optional[StageOut]]):
     Args:
         pipeline_stages (List[PipelineStage]): A list of stages to execute.
         debug_mode (bool): Whether to enable debug mode.
+        compute_stream (Optional[torch.cuda.Stream]): The main compute stream in which
+            model forward is run, usually torch.cuda.default_stream(). Defaults to the
+            current cuda stream.

     Example::
         train_pipeline = StagedTrainPipeline(

@@ -619,6 +622,7 @@ def __init__(
         self,
         pipeline_stages: List[PipelineStage],
         debug_mode: bool = False,
+        compute_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         self._pipeline_stages = pipeline_stages
         self._debug_mode = debug_mode

@@ -627,20 +631,23 @@ def __init__(
         )
         self._initialized = False
         self._num_steps = 0
+        self._dataloader_iter: Optional[Iterator[In]] = None
+        self._dataloader_exhausted: bool = False
+        self._compute_stream: torch.cuda.streams.Stream = (
+            compute_stream or torch.cuda.current_stream()
+        )

     @property
     def num_stages(self) -> int:
         return len(self._pipeline_stages)

-    def _advance(self) -> Optional[StageOut]:
+    def _advance(self) -> Optional[StageOutputWithEvent]:
         # left shifts all batch results.
         out = self._stage_outputs[0]
         for idx in range(self.num_stages - 1):
             self._stage_outputs[idx] = self._stage_outputs[idx + 1]
         self._stage_outputs[-1] = None
-        if out is None:
-            return out
-        return out[0]
+        return out

     def _run_with_event(
         self,

@@ -662,6 +669,23 @@ def _run_with_event(
             new_event.record(stream)
         return (output, new_event)

+    def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
+        """
+        Retrieves next batch from dataloader and prevents calling `next` on an already
+        exhausted dataloader, which can cause hanging.
+        """
+        if dataloader_iter is not self._dataloader_iter:
+            self._dataloader_iter = dataloader_iter
+            self._dataloader_exhausted = False
+
+        if self._dataloader_exhausted:
+            batch = None
+        else:
+            batch = next(dataloader_iter, None)
+            if batch is None:
+                self._dataloader_exhausted = True
+        return batch
+
     def _run_stage(
         self,
         batch_offset: int,

@@ -680,7 +704,7 @@ def _run_stage(
             f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##"
         ):
             if stage_idx == 0:
-                batch_to_wait = next(dataloader_iter, None)
+                batch_to_wait = self._next_batch(dataloader_iter)
                 event = None
             else:
                 batch_to_wait_with_event = self._stage_outputs[batch_offset]

@@ -765,7 +789,12 @@ def progress(
         if not self._initialized:
             self._fill_pipeline(dataloader_iter)

-        output = self._advance()
+        output_with_event = self._advance()
+
+        if output_with_event is None:
+            # All data consumed, exit early
+            return None
+
         self._num_steps += 1

         for stage_idx in range(self.num_stages):

@@ -776,4 +805,11 @@ def progress(
                 dataloader_iter=dataloader_iter,
             )

-        return output
+        out, event = output_with_event
+        if event is not None:
+            # Since model forward() is expected to run outside the pipeline,
+            # we need to explicitly wait for the last stage to finish
+            event.wait(self._compute_stream)
+            out.record_stream(self._compute_stream)
+
+        return out
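
The wait added at the end of `progress()` is the standard CUDA stream-handshake pattern: producer work runs on a side stream, an event is recorded there, the consuming stream waits on that event, and `record_stream` tells the caching allocator that the output is also used on the consuming stream. A self-contained sketch of the same pattern follows; it is not torchrec code, and the tensors are made up.

import torch

# placeholder host batch; in the pipeline this would be a Pipelineable batch
batch_cpu = torch.randn(1024, 16, pin_memory=True)

side_stream = torch.cuda.Stream()            # e.g. the copy or SDD stream
compute_stream = torch.cuda.current_stream()

with torch.cuda.stream(side_stream):
    out = batch_cpu.to("cuda", non_blocking=True)  # work enqueued on side_stream
    event = torch.cuda.Event()
    event.record(side_stream)

# order compute_stream after the side-stream work (does not block the host)
event.wait(compute_stream)
# mark `out` as used on compute_stream so its memory is not recycled while
# compute_stream may still be reading it
out.record_stream(compute_stream)

result = out.sum()  # safe: compute_stream is ordered after the copy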

torchrec/distributed/train_pipeline/utils.py

+9 −5

@@ -802,8 +802,12 @@ def start_sparse_data_dist(self, batch: In) -> In:
         return batch

     def wait_sparse_data_dist(self) -> None:
-        self.context.module_contexts = self.context.module_contexts_next_batch.copy()
-        self.context.input_dist_tensors_requests.clear()
-        for names, awaitable in self.context.fused_splits_awaitables:
-            for name, request in zip(names, awaitable.wait()):
-                self.context.input_dist_tensors_requests[name] = request
+        with record_function("## wait_sparse_data_dist ##"):
+            with torch.cuda.stream(self.stream):
+                self.context.module_contexts = (
+                    self.context.module_contexts_next_batch.copy()
+                )
+                self.context.input_dist_tensors_requests.clear()
+                for names, awaitable in self.context.fused_splits_awaitables:
+                    for name, request in zip(names, awaitable.wait()):
+                        self.context.input_dist_tensors_requests[name] = request
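
Scoping the waits under `torch.cuda.stream(self.stream)` means any GPU work issued while the awaitables complete is enqueued on the sparse data dist stream rather than the main stream, and `record_function` labels that span in profiler traces. A generic sketch of the same composition, with a made-up `awaitables` list standing in for `fused_splits_awaitables`:

import torch
from torch.profiler import record_function

sdd_stream = torch.cuda.Stream()

def wait_on_sdd_stream(awaitables):
    # the range shows up as "## wait_sparse_data_dist ##" in profiler traces
    with record_function("## wait_sparse_data_dist ##"):
        # GPU work issued by .wait() lands on sdd_stream, so the main
        # (compute) stream is not serialized behind it
        with torch.cuda.stream(sdd_stream):
            return [a.wait() for a in awaitables]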
