diff --git a/dlio_benchmark/main.py b/dlio_benchmark/main.py index 9b3d2d9a..f16e488a 100644 --- a/dlio_benchmark/main.py +++ b/dlio_benchmark/main.py @@ -257,7 +257,7 @@ def _train(self, epoch): loader = self.framework.get_loader(dataset_type=DatasetType.TRAIN) t0 = time() for batch in dlp.iter(loader.next()): - if overall_step > max_steps or overall_step > self.total_training_steps: + if overall_step > max_steps or ((self.total_training_steps > 0) and (overall_step > self.total_training_steps)): if self.args.my_rank == 0: logging.info(f"{utcnow()} Maximum number of steps reached") if (block_step != 1 and self.do_checkpoint) or (not self.do_checkpoint):