From 81b33efb6147659d6679c119ade1cc6e92976e3f Mon Sep 17 00:00:00 2001
From: Hariharan Devarajan
Date: Mon, 2 Oct 2023 19:32:21 -0700
Subject: [PATCH] fixed dali data loader execution. (#91)

---
 .../data_loader/dali_data_loader.py | 35 ++++++++++++++++++++++++-----------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/dlio_benchmark/data_loader/dali_data_loader.py b/dlio_benchmark/data_loader/dali_data_loader.py
index dc36b9f8..73e5eba5 100644
--- a/dlio_benchmark/data_loader/dali_data_loader.py
+++ b/dlio_benchmark/data_loader/dali_data_loader.py
@@ -37,6 +37,7 @@ def __call__(self, sample_info):
         self.num_images_read += 1
         step = int(math.ceil(self.num_images_read / self.batch_size))
         sample_idx = sample_info.idx_in_epoch
+        logging.debug(f"{utcnow()} Reading {sample_idx} {sample_info.iteration} {self.num_samples} {self.batch_size}")
         if sample_info.iteration >= self.num_samples // self.batch_size:
             # Indicate end of the epoch
             raise StopIteration()
@@ -49,7 +50,7 @@ class DaliDataLoader(BaseDataLoader):
     @dlp.log_init
     def __init__(self, format_type, dataset_type, epoch):
         super().__init__(format_type, dataset_type, epoch, DataLoaderType.DALI)
-        self.pipelines = []
+        self.pipeline = None
 
     @dlp.log
     def read(self):
@@ -60,25 +61,37 @@ def read(self):
         num_threads = 1
         if self._args.read_threads > 0:
             num_threads = self._args.read_threads
-        dataset = DaliDataset(self.format_type, self.dataset_type, self.epoch_number, num_samples, batch_size, 0)
+        prefetch_size = 2
+        if self._args.prefetch_size > 0:
+            prefetch_size = self._args.prefetch_size
         # None executes pipeline on CPU and the reader does the batching
-        pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads)
-        with pipeline:
+        dataset = DaliDataset(self.format_type, self.dataset_type, self.epoch_number, num_samples, batch_size, 0)
+        self.pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads,
+                                 prefetch_queue_depth=prefetch_size, py_start_method='fork', exec_async=True)
+        with self.pipeline:
             images, labels = fn.external_source(source=dataset, num_outputs=2, dtype=[types.UINT8, types.UINT8],
-                                                parallel=parallel, batch=False)
-        pipeline.set_outputs(images, labels)
-        self.pipelines.append(pipeline)
-        logging.info(f"{utcnow()} Creating {num_threads} pipelines by {self._args.my_rank} rank ")
+                                                parallel=True, batch=False)
+            self.pipeline.set_outputs(images, labels)
+
+        self.pipeline.build()
+        logging.debug(f"{utcnow()} Starting {num_threads} pipelines by {self._args.my_rank} rank ")
+
 
     @dlp.log
     def next(self):
         super().next()
         num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
         batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
+        #DALIGenericIterator(self.pipelines, ['data', 'label'])
+
+        logging.debug(f"{utcnow()} Iterating pipelines by {self._args.my_rank} rank ")
         for step in range(num_samples // batch_size):
-            _dataset = DALIGenericIterator(self.pipelines, ['data', 'label'], size=1)
-            for batch in _dataset:
-                yield batch
+            outputs = self.pipeline.run()
+            logging.debug(f"{utcnow()} Output batch {step} {len(outputs)}")
+            for batch in outputs:
+                yield batch
+
+
 
     @dlp.log
     def finalize(self):
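
Note for reviewers: the sketch below (not part of the patch) shows the single-pipeline pattern the loader moves to, i.e. a parallel external_source callable driven by manual pipeline.run() calls instead of DALIGenericIterator. It assumes nvidia.dali is installed and runs entirely on CPU; ToySource and the constants are illustrative stand-ins for DaliDataset and the benchmark's self._args, not code from this repository.

    import numpy as np
    import nvidia.dali.fn as fn
    import nvidia.dali.types as types
    from nvidia.dali.pipeline import Pipeline

    BATCH_SIZE = 4     # stand-in for self._args.batch_size
    NUM_SAMPLES = 64   # stand-in for self._args.total_samples_train

    class ToySource:
        """Per-sample callable; DALI invokes it with a SampleInfo object."""
        def __init__(self, num_samples, batch_size):
            self.num_samples = num_samples
            self.batch_size = batch_size

        def __call__(self, sample_info):
            # End the epoch after the last full batch, mirroring the
            # StopIteration logic in DaliDataset.__call__.
            if sample_info.iteration >= self.num_samples // self.batch_size:
                raise StopIteration()
            image = np.full((8, 8), sample_info.idx_in_epoch % 256, dtype=np.uint8)
            label = np.zeros(1, dtype=np.uint8)
            return image, label

    # device_id=None keeps execution on the CPU, as in the patch.
    pipeline = Pipeline(batch_size=BATCH_SIZE, num_threads=2, device_id=None,
                        py_num_workers=2, py_start_method='fork',
                        prefetch_queue_depth=2, exec_async=True)
    with pipeline:
        images, labels = fn.external_source(source=ToySource(NUM_SAMPLES, BATCH_SIZE),
                                            num_outputs=2,
                                            dtype=[types.UINT8, types.UINT8],
                                            parallel=True, batch=False)
        pipeline.set_outputs(images, labels)
    pipeline.build()

    # Drive the pipeline one batch per step, as DaliDataLoader.next() now does.
    # run() returns one TensorList per declared output.
    for step in range(NUM_SAMPLES // BATCH_SIZE):
        images_out, labels_out = pipeline.run()
        print(step, len(images_out))

Each run() call consumes one prefetched batch, so the loop produces NUM_SAMPLES // BATCH_SIZE batches per epoch, matching the cutoff at which the source callable raises StopIteration.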