Enhancing Dali data loader support (#94)
* added support for multiprocessing context

* added file access option for MPI IO

* data loader optimizing

* changed profiling to false to disable iostat by default

* added copyright info to python files that did not have it

* fixed dali data loader execution.

* fixed an issue where batch_size was not defined

* Merge branch 'main' into bugfix/dali-dl

* fixes the calculation of image idx.

* fixes the calculation of image idx.

* set prefetch_factor to 2 if it was set to 0

* moved part of the initialization code in tf_reader

* fixed DataLoader args issues

* recover tf_reader code

---------

Co-authored-by: Hariharan Devarajan <[email protected]>
zhenghh04 and hariharan-devarajan authored Oct 4, 2023
1 parent 81b33ef commit d624764
Showing 16 changed files with 252 additions and 86 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package-conda.yml
@@ -160,4 +160,4 @@ jobs:
       run: |
         source ${VENV}/bin/activate
         mpirun -np 2 dlio_benchmark ++workload.workflow.generate_data=True ++workload.workflow.train=False
-        mpirun -np 2 dlio_benchmark ++workload.workflow.generate_data=False ++workload.workflow.train=True ++workload.dataset.num_files_train=8
\ No newline at end of file
+        mpirun -np 2 dlio_benchmark ++workload.workflow.generate_data=False ++workload.workflow.train=True ++workload.dataset.num_files_train=8
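The ++key=value tokens are Hydra force-overrides: they set a config entry whether or not it already exists in the workload YAML. A minimal sketch of the same two-phase run outside CI (the workload name is illustrative):

    # Phase 1: generate the synthetic dataset; Phase 2: replay the training I/O.
    mpirun -np 2 dlio_benchmark workload=resnet50 ++workload.workflow.generate_data=True ++workload.workflow.train=False
    mpirun -np 2 dlio_benchmark workload=resnet50 ++workload.workflow.generate_data=False ++workload.workflow.train=True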
18 changes: 13 additions & 5 deletions dlio_benchmark/common/enumerations.py
@@ -199,17 +199,25 @@ class FileAccess(Enum):
     SHARED = 'shared'
     # TO(HZ): I see currently, this collective mode is not used. It might be good to separate it out
     COLLECTIVE = 'collective'
+    MPIO = 'mpio'
+    POSIX = 'posix'

     def __str__(self):
         return self.value

     @staticmethod
     def get_enum(value):
-        if DatasetType.TRAIN.value == value:
-            return DatasetType.TRAIN
-        elif DatasetType.VALID.value == value:
-            return DatasetType.VALID
+        if FileAccess.MPIO.value == value:
+            return FileAccess.MPIO
+        elif FileAccess.POSIX.value == value:
+            return FileAccess.POSIX
+        elif FileAccess.MULTI.value == value:
+            return FileAccess.MULTI
+        elif FileAccess.SHARED.value == value:
+            return FileAccess.SHARED
+        elif FileAccess.COLLECTIVE.value == value:
+            return FileAccess.COLLECTIVE

 class Compression(Enum):
     """
     Different Compression Libraries.
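The fix matters because get_enum previously returned DatasetType members for FileAccess strings, a copy-paste bug. For reference, Python's Enum value lookup gives the same string-to-member mapping as the hand-written chain; a minimal sketch (MULTI's value is assumed to be 'multi', since it sits outside the diff context):

    from enum import Enum

    class FileAccess(Enum):
        MULTI = 'multi'   # assumed; not shown in the diff above
        SHARED = 'shared'
        COLLECTIVE = 'collective'
        MPIO = 'mpio'
        POSIX = 'posix'

    # Enum's constructor does a value lookup, matching get_enum's behavior:
    assert FileAccess('mpio') is FileAccess.MPIO
    assert FileAccess('posix') is FileAccess.POSIX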
8 changes: 5 additions & 3 deletions dlio_benchmark/configs/workload/resnet50.yaml
@@ -11,12 +11,14 @@ dataset:
   num_samples_per_file: 1
   record_length: 150528
   data_folder: data/resnet50
-  format: jpeg
+  format: png

-train:
-  computation_time: 0.317 # this is for A100
+train:
+  computation_time: 0.1

 reader:
   data_loader: pytorch
   read_threads: 8
   computation_threads: 8
   batch_size: 1
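(For scale: record_length 150528 is one uncompressed 224 x 224 x 3 ImageNet-style sample, since 224 * 224 * 3 = 150528 bytes at one byte per channel.)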
18 changes: 18 additions & 0 deletions dlio_benchmark/data_loader/base_data_loader.py
@@ -1,3 +1,19 @@
+"""
+Copyright (c) 2022, UChicago Argonne, LLC
+All Rights Reserved
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 import math
 import os
 from abc import ABC, abstractmethod
@@ -17,6 +33,8 @@ def __init__(self, format_type, dataset_type, epoch_number, data_loader_type):
         self.format_type = format_type
         self.epoch_number = epoch_number
         self.data_loader_type = data_loader_type
+        self.num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
+        self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval

     @abstractmethod
     def read(self):
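The two new attributes centralize the train-versus-eval selection that each concrete loader previously repeated. A standalone sketch of the pattern (args fields named as in the diff; the class layout is simplified):

    from abc import ABC, abstractmethod
    from enum import Enum
    from types import SimpleNamespace

    class DatasetType(Enum):
        TRAIN = 'train'
        VALID = 'valid'

    class BaseDataLoader(ABC):
        def __init__(self, args, dataset_type):
            self.dataset_type = dataset_type
            # Pick the per-split sample count and batch size once, up front,
            # so torch/tf/dali subclasses no longer repeat the ternary.
            self.num_samples = (args.total_samples_train
                                if dataset_type is DatasetType.TRAIN
                                else args.total_samples_eval)
            self.batch_size = (args.batch_size
                               if dataset_type is DatasetType.TRAIN
                               else args.batch_size_eval)

        @abstractmethod
        def read(self):
            ...

    # Usage sketch:
    class NullLoader(BaseDataLoader):
        def read(self):
            return None

    args = SimpleNamespace(total_samples_train=1024, total_samples_eval=128,
                           batch_size=8, batch_size_eval=4)
    loader = NullLoader(args, DatasetType.TRAIN)
    assert loader.num_samples == 1024 and loader.batch_size == 8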
101 changes: 64 additions & 37 deletions dlio_benchmark/data_loader/dali_data_loader.py
@@ -1,12 +1,26 @@
+"""
+Copyright (c) 2022, UChicago Argonne, LLC
+All Rights Reserved
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 from time import time
 import logging
 import math
 import numpy as np
 from nvidia.dali.pipeline import Pipeline
 import nvidia.dali.fn as fn
 import nvidia.dali.types as types
 import nvidia.dali as dali
 from nvidia.dali.plugin.pytorch import DALIGenericIterator

 from dlio_benchmark.common.constants import MODULE_DATA_LOADER
 from dlio_benchmark.common.enumerations import Shuffle, DataLoaderType, DatasetType
@@ -19,43 +33,44 @@

 class DaliDataset(object):

-    def __init__(self, format_type, dataset_type, epoch, num_samples, batch_size, thread_index):
+    def __init__(self, format_type, dataset_type, epoch, thread_index,
+                 total_num_workers, total_num_samples, samples_per_worker, batch_size):
         self.format_type = format_type
         self.dataset_type = dataset_type
         self.epoch = epoch
-        self.num_samples = num_samples
-        self.num_images_read = 0
+        self.total_num_workers = total_num_workers
+        self.total_num_samples = total_num_samples
+        self.samples_per_worker = samples_per_worker
         self.batch_size = batch_size
+        self.worker_index = thread_index
         self.reader = ReaderFactory.get_reader(type=self.format_type,
                                                dataset_type=self.dataset_type,
                                                thread_index=thread_index,
                                                epoch_number=self.epoch)
-        self.item = self.reader.next()
-        self.is_last = 0

     def __call__(self, sample_info):
-        self.num_images_read += 1
-        step = int(math.ceil(self.num_images_read / self.batch_size))
-        sample_idx = sample_info.idx_in_epoch
-        logging.debug(f"{utcnow()} Reading {sample_idx} {sample_info.iteration} {self.num_samples} {self.batch_size}")
-        if sample_info.iteration >= self.num_samples // self.batch_size:
+        logging.debug(
+            f"{utcnow()} Reading {sample_info.idx_in_epoch} out of {self.samples_per_worker} by worker {self.worker_index}")
+        sample_idx = sample_info.idx_in_epoch + self.total_num_workers * self.worker_index
+        logging.debug(
+            f"{utcnow()} Reading {sample_idx} on {sample_info.iteration} by worker {self.worker_index}")
+        if sample_info.iteration >= self.samples_per_worker or sample_idx >= self.total_num_samples:
             # Indicate end of the epoch
             raise StopIteration()
-        with Profile(MODULE_DATA_LOADER, epoch=self.epoch,image_idx=sample_idx, step=step):
+        step = int(math.ceil(sample_idx / self.batch_size))
+        with Profile(MODULE_DATA_LOADER, epoch=self.epoch, image_idx=sample_idx, step=step):
             image = self.reader.read_index(sample_idx, step)
         return image, np.uint8([sample_idx])


 class DaliDataLoader(BaseDataLoader):
     @dlp.log_init
     def __init__(self, format_type, dataset_type, epoch):
         super().__init__(format_type, dataset_type, epoch, DataLoaderType.DALI)
-        self.pipeline = None
+        self.pipelines = []

     @dlp.log
     def read(self):
-        num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
-        batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
         parallel = True if self._args.read_threads > 0 else False
         self.pipelines = []
         num_threads = 1
@@ -64,35 +79,47 @@ def read(self):
         prefetch_size = 2
         if self._args.prefetch_size > 0:
             prefetch_size = self._args.prefetch_size
-        # None executes pipeline on CPU and the reader does the batching
-        dataset = DaliDataset(self.format_type, self.dataset_type, self.epoch_number, num_samples, batch_size, 0)
-        self.pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads,
-                                 prefetch_queue_depth=prefetch_size, py_start_method='fork', exec_async=True)
-        with self.pipeline:
-            images, labels = fn.external_source(source=dataset, num_outputs=2, dtype=[types.UINT8, types.UINT8],
-                                                parallel=True, batch=False)
-            self.pipeline.set_outputs(images, labels)
-        self.pipeline.build()
+        num_pipelines = 1
+        samples_per_worker = self.num_samples // num_pipelines // self._args.comm_size
+        for worker_index in range(num_pipelines):
+            global_worker_index = self._args.my_rank * num_pipelines + worker_index
+            # None executes pipeline on CPU and the reader does the batching
+            dataset = DaliDataset(self.format_type, self.dataset_type, self.epoch_number, global_worker_index,
+                                  self._args.comm_size * num_pipelines, self.num_samples, samples_per_worker, self.batch_size)
+            pipeline = Pipeline(batch_size=self.batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads//num_pipelines,
+                                prefetch_queue_depth=prefetch_size, py_start_method=self._args.multiprocessing_context, exec_async=True)
+            with pipeline:
+                images, labels = fn.external_source(source=dataset, num_outputs=2, dtype=[types.UINT8, types.UINT8],
+                                                    parallel=True, batch=False)
+                pipeline.set_outputs(images, labels)
+            self.pipelines.append(pipeline)
+        for pipe in self.pipelines:
+            pipe.start_py_workers()
+        for pipe in self.pipelines:
+            pipe.build()
+        for pipe in self.pipelines:
+            pipe.schedule_run()
+        logging.debug(f"{utcnow()} Starting {num_threads} pipelines by {self._args.my_rank} rank ")

     @dlp.log
     def next(self):
         super().next()
-        num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
-        batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
-        #DALIGenericIterator(self.pipelines, ['data', 'label'])
+        # DALIGenericIterator(self.pipelines, ['data', 'label'])
         logging.debug(f"{utcnow()} Iterating pipelines by {self._args.my_rank} rank ")
-        for step in range(num_samples // batch_size):
-            outputs = self.pipeline.run()
-            logging.debug(f"{utcnow()} Output batch {step} {len(outputs)}")
-            for batch in outputs:
-                yield batch
+        step = 0
+        while step <= self.num_samples // self.batch_size:
+            for pipe in self.pipelines:
+                outputs = pipe.share_outputs()
+                logging.debug(f"{utcnow()} Output batch {step} {len(outputs)}")
+                yield outputs
+                step += 1
+                pipe.release_outputs()
+                pipe.schedule_run()

     @dlp.log
     def finalize(self):
         pass
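For context on the mechanics above: DALI's parallel external_source runs a picklable callable in py_num_workers worker processes, hands it a nvidia.dali.types.SampleInfo (carrying idx_in_epoch, idx_in_batch, iteration and epoch_idx), and treats a StopIteration from the callable as end of epoch; the share_outputs/release_outputs/schedule_run triple is the non-blocking way to keep batches in flight. A minimal self-contained sketch of one CPU pipeline over synthetic data (not the benchmark's reader; the block-sharding offset here, rank * samples_per_rank, is the textbook partitioning written out for clarity):

    import numpy as np
    from nvidia.dali import pipeline_def
    import nvidia.dali.fn as fn
    import nvidia.dali.types as types

    class SyntheticSource:
        """Per-rank callable for fn.external_source(parallel=True, batch=False)."""
        def __init__(self, rank, num_ranks, total_samples):
            self.rank = rank
            self.samples_per_rank = total_samples // num_ranks
        def __call__(self, sample_info):
            # sample_info.idx_in_epoch is local to this pipeline; raising
            # StopIteration tells DALI the epoch is exhausted.
            if sample_info.idx_in_epoch >= self.samples_per_rank:
                raise StopIteration
            # Block sharding: rank r owns [r * spr, (r + 1) * spr).
            global_idx = self.rank * self.samples_per_rank + sample_info.idx_in_epoch
            image = np.full((4, 4), global_idx % 256, dtype=np.uint8)
            return image, np.uint8([global_idx])

    @pipeline_def(batch_size=4, num_threads=1, device_id=None,
                  py_num_workers=2, py_start_method='spawn')
    def toy_pipe(source):
        images, labels = fn.external_source(source=source, num_outputs=2,
                                            dtype=[types.UINT8, types.UINT8],
                                            parallel=True, batch=False)
        return images, labels

    pipe = toy_pipe(SyntheticSource(rank=0, num_ranks=1, total_samples=16))
    pipe.build()
    pipe.schedule_run()
    images, labels = pipe.share_outputs()  # first batch, without blocking later runs
    pipe.release_outputs()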
19 changes: 17 additions & 2 deletions dlio_benchmark/data_loader/tf_data_loader.py
@@ -1,3 +1,19 @@
+"""
+Copyright (c) 2022, UChicago Argonne, LLC
+All Rights Reserved
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 from time import time
 import logging
 import math
@@ -64,12 +80,11 @@ def read(self):
         options.experimental_threading.private_threadpool_size = read_threads
         options.experimental_threading.max_intra_op_parallelism = read_threads

-        batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
         self._dataset = tf.data.Dataset.from_tensor_slices(np.arange(read_threads)).with_options(options)

         self._dataset = self._dataset.interleave(lambda x: TensorflowDataset(self.format_type, self.dataset_type,
                                                                              self.epoch_number, (
-                                                                                 batch_size,
+                                                                                 self.batch_size,
                                                                                  self._args.max_dimension,
                                                                                  self._args.max_dimension), x),
                                                  cycle_length=read_threads,
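For context, the reader builds one dataset shard per read thread and interleaves them; after the change the batch size comes from the base class. A standalone sketch of the same fan-out pattern (synthetic elements; names are illustrative stand-ins for the benchmark's args):

    import numpy as np
    import tensorflow as tf

    read_threads = 4   # stand-in for reader.read_threads
    batch_size = 8     # stand-in for the base class's self.batch_size

    options = tf.data.Options()
    # Newer TF spells this options.threading; the diff's TF release uses
    # the older options.experimental_threading alias.
    options.threading.private_threadpool_size = read_threads
    options.threading.max_intra_op_parallelism = read_threads

    # One source element per reader thread; interleave fans out to
    # read_threads sub-datasets that are read in parallel.
    ds = tf.data.Dataset.from_tensor_slices(np.arange(read_threads)).with_options(options)
    ds = ds.interleave(
        lambda shard: tf.data.Dataset.range(16).batch(batch_size),
        cycle_length=read_threads,
        num_parallel_calls=tf.data.AUTOTUNE)

    for batch in ds:
        pass  # each element is a batch from one of the interleaved shards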
59 changes: 41 additions & 18 deletions dlio_benchmark/data_loader/torch_data_loader.py
@@ -1,3 +1,19 @@
+"""
+Copyright (c) 2022, UChicago Argonne, LLC
+All Rights Reserved
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 from time import time
 import logging
 import math
@@ -57,14 +73,12 @@ def __init__(self, format_type, dataset_type, epoch_number):
     @dlp.log
     def read(self):
         do_shuffle = True if self._args.sample_shuffle != Shuffle.OFF else False
-        num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
-        batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
-        dataset = TorchDataset(self.format_type, self.dataset_type, self.epoch_number, num_samples, self._args.read_threads, batch_size)
+        dataset = TorchDataset(self.format_type, self.dataset_type, self.epoch_number, self.num_samples, self._args.read_threads, self.batch_size)
         if do_shuffle:
             sampler = RandomSampler(dataset)
         else:
             sampler = SequentialSampler(dataset)
-        if self._args.read_threads > 1:
+        if self._args.read_threads >= 1:
             prefetch_factor = math.ceil(self._args.prefetch_size / self._args.read_threads)
         else:
             prefetch_factor = self._args.prefetch_size
@@ -73,28 +87,37 @@ def read(self):
                 logging.debug(
                     f"{utcnow()} Prefetch size is {self._args.prefetch_size}; prefetch factor of {prefetch_factor} will be set to Torch DataLoader.")
         else:
+            prefetch_factor = 2
             if self._args.my_rank == 0:
                 logging.debug(
                     f"{utcnow()} Prefetch size is 0; a default prefetch factor of 2 will be set to Torch DataLoader.")
+        logging.debug(f"{utcnow()} Setup dataloader with {self._args.read_threads} workers {torch.__version__}")
+        if self._args.read_threads == 0:
+            kwargs = {}
+        else:
+            kwargs = {'multiprocessing_context': self._args.multiprocessing_context,
+                      'prefetch_factor': prefetch_factor,
+                      'persistent_workers': True}
         if torch.__version__ == '1.3.1':
+            if 'prefetch_factor' in kwargs:
+                del kwargs['prefetch_factor']
             self._dataset = DataLoader(dataset,
-                                       batch_size=batch_size,
-                                       sampler=sampler,
-                                       num_workers=self._args.read_threads,
-                                       pin_memory=True,
-                                       drop_last=True,
-                                       worker_init_fn=dataset.worker_init)
+                                       batch_size=self.batch_size,
+                                       sampler=sampler,
+                                       num_workers=self._args.read_threads,
+                                       pin_memory=True,
+                                       drop_last=True,
+                                       worker_init_fn=dataset.worker_init, **kwargs)
         else:
             self._dataset = DataLoader(dataset,
-                                       batch_size=batch_size,
-                                       sampler=sampler,
-                                       num_workers=self._args.read_threads,
-                                       pin_memory=True,
-                                       drop_last=True,
-                                       worker_init_fn=dataset.worker_init,
-                                       prefetch_factor=prefetch_factor if prefetch_factor > 0 else 2)  # 2 is the default value
+                                       batch_size=self.batch_size,
+                                       sampler=sampler,
+                                       num_workers=self._args.read_threads,
+                                       pin_memory=True,
+                                       drop_last=True,
+                                       worker_init_fn=dataset.worker_init,
+                                       **kwargs)
-        logging.debug(f"{utcnow()} Rank {self._args.my_rank} will read {len(self._dataset) * batch_size} files")
+        logging.debug(f"{utcnow()} Rank {self._args.my_rank} will read {len(self._dataset) * self.batch_size} files")

         # self._dataset.sampler.set_epoch(epoch_number)
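The conditional kwargs dict exists because PyTorch's DataLoader only accepts prefetch_factor, persistent_workers and multiprocessing_context when num_workers > 0 (and very old releases such as 1.3.1 lack prefetch_factor entirely). A minimal runnable sketch of the pattern (variable names are illustrative stand-ins for the benchmark's args):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        def __len__(self):
            return 64
        def __getitem__(self, idx):
            return torch.zeros(4), idx

    read_threads = 2        # stand-in for self._args.read_threads
    prefetch_size = 8       # stand-in for self._args.prefetch_size
    mp_context = 'spawn'    # stand-in for self._args.multiprocessing_context

    # These three options are only legal with worker processes, which is why
    # the loader builds kwargs conditionally instead of passing them always.
    if read_threads == 0:
        kwargs = {}
    else:
        kwargs = {'multiprocessing_context': mp_context,
                  'prefetch_factor': max(2, prefetch_size // read_threads),
                  'persistent_workers': True}

    loader = DataLoader(ToyDataset(), batch_size=4, num_workers=read_threads,
                        pin_memory=True, drop_last=True, **kwargs)
    for data, idx in loader:
        break  # one batch is enough for the demo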
8 changes: 4 additions & 4 deletions dlio_benchmark/main.py
@@ -205,10 +205,10 @@ def initialize(self):
                 file_list_train = fullpaths
             elif dataset_type is DatasetType.VALID:
                 file_list_eval = fullpaths
-        if not self.generate_only:
-            assert(self.num_files_train <= len(file_list_train))
-            if self.do_eval:
-                assert(self.num_files_eval <= len(file_list_eval))
+        if not self.generate_only and self.num_files_train > len(file_list_train):
+            raise Exception("Not enough training dataset is found; Please run the code with ++workload.workflow.generate_data=True")
+        if self.do_eval and self.num_files_eval > len(file_list_eval):
+            raise Exception("Not enough evaluation dataset is found; Please run the code with ++workload.workflow.generate_data=True")
         if (self.num_files_train < len(file_list_train)):
             logging.warning(f"Number of files for training in {os.path.join(self.args.data_folder, f'{DatasetType.TRAIN}')} ({len(file_list_train)}) is more than requested ({self.num_files_train}). A subset of files will be used ")
             file_list_train = file_list_train[:self.num_files_train]
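The change swaps bare asserts for exceptions with an actionable message: the check now survives python -O (which strips asserts) and tells the user how to recover. A minimal sketch of the failure path (values are illustrative):

    num_files_train = 8
    file_list_train = [f"img_{i}.npz" for i in range(4)]  # illustrative listing

    try:
        if num_files_train > len(file_list_train):
            raise Exception(
                "Not enough training dataset is found; "
                "Please run the code with ++workload.workflow.generate_data=True")
    except Exception as e:
        print(e)  # actionable guidance instead of a bare AssertionError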