@@ -591,6 +591,9 @@ class StagedTrainPipeline(TrainPipeline[In, Optional[StageOut]]):
591
591
Args:
592
592
pipeline_stages (List[PipelineStage]): A list of stages to execute.
593
593
debug_mode (bool): Whether to enable debug mode.
594
+ compute_stream (Optional[torch.cuda.Stream]): The main compute stream in which
595
+ model forward is run, usually torch.cuda.default_stream(). Defaults to the
596
+ current cuda stream.
594
597
595
598
Example::
596
599
train_pipeline = StagedTrainPipeline(
@@ -619,6 +622,7 @@ def __init__(
619
622
self ,
620
623
pipeline_stages : List [PipelineStage ],
621
624
debug_mode : bool = False ,
625
+ compute_stream : Optional [torch .cuda .Stream ] = None ,
622
626
) -> None :
623
627
self ._pipeline_stages = pipeline_stages
624
628
self ._debug_mode = debug_mode
@@ -627,20 +631,23 @@ def __init__(
627
631
)
628
632
self ._initialized = False
629
633
self ._num_steps = 0
634
+ self ._dataloader_iter : Optional [Iterator [In ]] = None
635
+ self ._dataloader_exhausted : bool = False
636
+ self ._compute_stream : torch .cuda .streams .Stream = (
637
+ compute_stream or torch .cuda .current_stream ()
638
+ )
630
639
631
640
@property
def num_stages(self) -> int:
    """Return the number of stages, i.e. ``len(self._pipeline_stages)``."""
    return len(self._pipeline_stages)
634
643
635
def _advance(self) -> Optional[StageOutputWithEvent]:
    """
    Take the head entry of ``self._stage_outputs`` and shift the rest left.

    Every entry at index ``idx + 1`` is moved to ``idx`` and the last slot
    is reset to ``None``.

    Returns:
        The entry previously stored at ``self._stage_outputs[0]``, returned
        as-is (it is not unpacked here), or ``None`` if that slot held
        ``None``.
    """
    # left shifts all batch results.
    out = self._stage_outputs[0]
    for idx in range(self.num_stages - 1):
        self._stage_outputs[idx] = self._stage_outputs[idx + 1]
    # The tail slot has been shifted out; clear it so it no longer aliases
    # the entry that now lives one position earlier.
    self._stage_outputs[-1] = None
    return out
644
651
645
652
def _run_with_event (
646
653
self ,
@@ -662,6 +669,23 @@ def _run_with_event(
662
669
new_event .record (stream )
663
670
return (output , new_event )
664
671
672
def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
    """
    Retrieves next batch from dataloader and prevents calling `next` on an already
    exhausted dataloader, which can cause hanging.

    Args:
        dataloader_iter: Iterator to pull the next batch from. Passing an
            iterator object different from the one seen on the previous
            call (compared by identity) resets the exhausted flag.

    Returns:
        The next batch, or ``None`` if the iterator is exhausted. Once
        ``None`` has been returned for a given iterator, later calls with
        that same iterator return ``None`` without calling ``next`` again.
    """
    if dataloader_iter is not self._dataloader_iter:
        # A new iterator was supplied: track it and start with a clean flag.
        self._dataloader_iter = dataloader_iter
        self._dataloader_exhausted = False

    if self._dataloader_exhausted:
        return None

    batch = next(dataloader_iter, None)
    if batch is None:
        # Remember exhaustion so `next` is never invoked on this iterator again.
        self._dataloader_exhausted = True
    return batch
688
+
665
689
def _run_stage (
666
690
self ,
667
691
batch_offset : int ,
@@ -680,7 +704,7 @@ def _run_stage(
680
704
f"## Pipeline Stage { stage_idx } : { stage .name } for batch { batch_offset + self ._num_steps } ##"
681
705
):
682
706
if stage_idx == 0 :
683
- batch_to_wait = next (dataloader_iter , None )
707
+ batch_to_wait = self . _next_batch (dataloader_iter )
684
708
event = None
685
709
else :
686
710
batch_to_wait_with_event = self ._stage_outputs [batch_offset ]
@@ -765,7 +789,12 @@ def progress(
765
789
if not self ._initialized :
766
790
self ._fill_pipeline (dataloader_iter )
767
791
768
- output = self ._advance ()
792
+ output_with_event = self ._advance ()
793
+
794
+ if output_with_event is None :
795
+ # All data consumed, exit early
796
+ return None
797
+
769
798
self ._num_steps += 1
770
799
771
800
for stage_idx in range (self .num_stages ):
@@ -776,4 +805,11 @@ def progress(
776
805
dataloader_iter = dataloader_iter ,
777
806
)
778
807
779
- return output
808
+ out , event = output_with_event
809
+ if event is not None :
810
+ # Since model forward() is expected to run outside the pipeline,
811
+ # we need to explicitly wait for the last stage to finish
812
+ event .wait (self ._compute_stream )
813
+ out .record_stream (self ._compute_stream )
814
+
815
+ return out
0 commit comments