diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index 76e36287..3ce2b846 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -160,4 +160,4 @@ jobs: run: | source ${VENV}/bin/activate mpirun -np 2 dlio_benchmark ++workload.workflow.generate_data=True ++workload.workflow.train=False - mpirun -np 2 dlio_benchmark ++workload.workflow.generate_data=False ++workload.workflow.train=True ++workload.dataset.num_files_train=8 \ No newline at end of file + mpirun -np 2 dlio_benchmark ++workload.workflow.generate_data=False ++workload.workflow.train=True ++workload.dataset.num_files_train=8 diff --git a/dlio_benchmark/common/enumerations.py b/dlio_benchmark/common/enumerations.py index 86332902..64227772 100644 --- a/dlio_benchmark/common/enumerations.py +++ b/dlio_benchmark/common/enumerations.py @@ -199,17 +199,25 @@ class FileAccess(Enum): SHARED = 'shared' # TO(HZ): I see currently, this collective mode is not used. It might be good to separate it out COLLECTIVE = 'collective' + MPIO = 'mpio' + POSIX = 'posix' def __str__(self): return self.value @staticmethod def get_enum(value): - if DatasetType.TRAIN.value == value: - return DatasetType.TRAIN - elif DatasetType.VALID.value == value: - return DatasetType.VALID - + if FileAccess.MPIO.value == value: + return FileAccess.MPIO + elif FileAccess.POSIX.value == value: + return FileAccess.POSIX + elif FileAccess.MULTI.value == value: + return FileAccess.MULTI + elif FileAccess.SHARED.value == value: + return FileAccess.SHARED + elif FileAccess.COLLECTIVE.value == value: + return FileAccess.COLLECTIVE + class Compression(Enum): """ Different Compression Libraries. diff --git a/dlio_benchmark/configs/workload/resnet50.yaml b/dlio_benchmark/configs/workload/resnet50.yaml index d8376ed9..5608d287 100644 --- a/dlio_benchmark/configs/workload/resnet50.yaml +++ b/dlio_benchmark/configs/workload/resnet50.yaml @@ -11,12 +11,14 @@ dataset: num_samples_per_file: 1 record_length: 150528 data_folder: data/resnet50 - format: jpeg + format: png + +train: + computation_time: 0.317 # this is for A100 -train: - computation_time: 0.1 reader: data_loader: pytorch read_threads: 8 computation_threads: 8 + batch_size: 1 diff --git a/dlio_benchmark/data_loader/base_data_loader.py b/dlio_benchmark/data_loader/base_data_loader.py index 4b613ff7..eac78ac3 100644 --- a/dlio_benchmark/data_loader/base_data_loader.py +++ b/dlio_benchmark/data_loader/base_data_loader.py @@ -1,3 +1,19 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" import math import os from abc import ABC, abstractmethod @@ -17,6 +33,8 @@ def __init__(self, format_type, dataset_type, epoch_number, data_loader_type): self.format_type = format_type self.epoch_number = epoch_number self.data_loader_type = data_loader_type + self.num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval + self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval @abstractmethod def read(self): diff --git a/dlio_benchmark/data_loader/dali_data_loader.py b/dlio_benchmark/data_loader/dali_data_loader.py index 73e5eba5..4412ad4b 100644 --- a/dlio_benchmark/data_loader/dali_data_loader.py +++ b/dlio_benchmark/data_loader/dali_data_loader.py @@ -1,3 +1,19 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" from time import time import logging import math @@ -5,8 +21,6 @@ from nvidia.dali.pipeline import Pipeline import nvidia.dali.fn as fn import nvidia.dali.types as types -import nvidia.dali as dali -from nvidia.dali.plugin.pytorch import DALIGenericIterator from dlio_benchmark.common.constants import MODULE_DATA_LOADER from dlio_benchmark.common.enumerations import Shuffle, DataLoaderType, DatasetType @@ -19,43 +33,44 @@ class DaliDataset(object): - def __init__(self, format_type, dataset_type, epoch, num_samples, batch_size, thread_index): + def __init__(self, format_type, dataset_type, epoch, thread_index, + total_num_workers, total_num_samples, samples_per_worker, batch_size): self.format_type = format_type self.dataset_type = dataset_type self.epoch = epoch - self.num_samples = num_samples - self.num_images_read = 0 + self.total_num_workers = total_num_workers + self.total_num_samples = total_num_samples + self.samples_per_worker = samples_per_worker self.batch_size = batch_size + self.worker_index = thread_index self.reader = ReaderFactory.get_reader(type=self.format_type, dataset_type=self.dataset_type, thread_index=thread_index, epoch_number=self.epoch) - self.item = self.reader.next() - self.is_last = 0 def __call__(self, sample_info): - self.num_images_read += 1 - step = int(math.ceil(self.num_images_read / self.batch_size)) - sample_idx = sample_info.idx_in_epoch - logging.debug(f"{utcnow()} Reading {sample_idx} {sample_info.iteration} {self.num_samples} {self.batch_size}") - if sample_info.iteration >= self.num_samples // self.batch_size: + logging.debug( + f"{utcnow()} Reading {sample_info.idx_in_epoch} out of {self.samples_per_worker} by worker {self.worker_index}") + sample_idx = sample_info.idx_in_epoch + self.total_num_workers * self.worker_index + logging.debug( + f"{utcnow()} Reading {sample_idx} on {sample_info.iteration} by worker {self.worker_index}") + if sample_info.iteration >= self.samples_per_worker or sample_idx >= self.total_num_samples: # Indicate end of the epoch raise StopIteration() - with Profile(MODULE_DATA_LOADER, epoch=self.epoch,image_idx=sample_idx, 
step=step): + + step = int(math.ceil(sample_idx / self.batch_size)) + with Profile(MODULE_DATA_LOADER, epoch=self.epoch, image_idx=sample_idx, step=step): image = self.reader.read_index(sample_idx, step) return image, np.uint8([sample_idx]) - class DaliDataLoader(BaseDataLoader): @dlp.log_init def __init__(self, format_type, dataset_type, epoch): super().__init__(format_type, dataset_type, epoch, DataLoaderType.DALI) - self.pipeline = None + self.pipelines = [] @dlp.log def read(self): - num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval - batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval parallel = True if self._args.read_threads > 0 else False self.pipelines = [] num_threads = 1 @@ -64,35 +79,47 @@ def read(self): prefetch_size = 2 if self._args.prefetch_size > 0: prefetch_size = self._args.prefetch_size - # None executes pipeline on CPU and the reader does the batching - dataset = DaliDataset(self.format_type, self.dataset_type, self.epoch_number, num_samples, batch_size, 0) - self.pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads, - prefetch_queue_depth=prefetch_size, py_start_method='fork', exec_async=True) - with self.pipeline: - images, labels = fn.external_source(source=dataset, num_outputs=2, dtype=[types.UINT8, types.UINT8], - parallel=True, batch=False) - self.pipeline.set_outputs(images, labels) - - self.pipeline.build() + num_pipelines = 1 + samples_per_worker = self.num_samples // num_pipelines // self._args.comm_size + + for worker_index in range(num_pipelines): + global_worker_index = self._args.my_rank * num_pipelines + worker_index + # None executes pipeline on CPU and the reader does the batching + dataset = DaliDataset(self.format_type, self.dataset_type, self.epoch_number, global_worker_index, + self._args.comm_size * num_pipelines, self.num_samples, samples_per_worker, self.batch_size) + pipeline = Pipeline(batch_size=self.batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads//num_pipelines, + prefetch_queue_depth=prefetch_size, py_start_method=self._args.multiprocessing_context, exec_async=True) + with pipeline: + images, labels = fn.external_source(source=dataset, num_outputs=2, dtype=[types.UINT8, types.UINT8], + parallel=True, batch=False) + pipeline.set_outputs(images, labels) + self.pipelines.append(pipeline) + for pipe in self.pipelines: + pipe.start_py_workers() + for pipe in self.pipelines: + pipe.build() + for pipe in self.pipelines: + pipe.schedule_run() logging.debug(f"{utcnow()} Starting {num_threads} pipelines by {self._args.my_rank} rank ") @dlp.log def next(self): super().next() - num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval - batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval - #DALIGenericIterator(self.pipelines, ['data', 'label']) + # DALIGenericIterator(self.pipelines, ['data', 'label']) logging.debug(f"{utcnow()} Iterating pipelines by {self._args.my_rank} rank ") - for step in range(num_samples // batch_size): - outputs = self.pipeline.run() - logging.debug(f"{utcnow()} Output batch {step} {len(outputs)}") - for batch in outputs: + step = 0 + while step <= self.num_samples // self.batch_size: + for pipe in self.pipelines: + outputs = pipe.share_outputs() + logging.debug(f"{utcnow()} Output batch 
{step} {len(outputs)}") yield outputs - - - + step += 1 + pipe.release_outputs() + pipe.schedule_run() + + @dlp.log def finalize(self): pass diff --git a/dlio_benchmark/data_loader/tf_data_loader.py b/dlio_benchmark/data_loader/tf_data_loader.py index 304d10a3..f851a7aa 100644 --- a/dlio_benchmark/data_loader/tf_data_loader.py +++ b/dlio_benchmark/data_loader/tf_data_loader.py @@ -1,3 +1,19 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" from time import time import logging import math @@ -64,12 +80,11 @@ def read(self): options.experimental_threading.private_threadpool_size = read_threads options.experimental_threading.max_intra_op_parallelism = read_threads - batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval self._dataset = tf.data.Dataset.from_tensor_slices(np.arange(read_threads)).with_options(options) self._dataset = self._dataset.interleave(lambda x: TensorflowDataset(self.format_type, self.dataset_type, self.epoch_number, ( - batch_size, + self.batch_size, self._args.max_dimension, self._args.max_dimension), x), cycle_length=read_threads, diff --git a/dlio_benchmark/data_loader/torch_data_loader.py b/dlio_benchmark/data_loader/torch_data_loader.py index 0f42806e..35c9b32d 100644 --- a/dlio_benchmark/data_loader/torch_data_loader.py +++ b/dlio_benchmark/data_loader/torch_data_loader.py @@ -1,3 +1,19 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" from time import time import logging import math @@ -57,14 +73,12 @@ def __init__(self, format_type, dataset_type, epoch_number): @dlp.log def read(self): do_shuffle = True if self._args.sample_shuffle != Shuffle.OFF else False - num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval - batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval - dataset = TorchDataset(self.format_type, self.dataset_type, self.epoch_number, num_samples, self._args.read_threads, batch_size) + dataset = TorchDataset(self.format_type, self.dataset_type, self.epoch_number, self.num_samples, self._args.read_threads, self.batch_size) if do_shuffle: sampler = RandomSampler(dataset) else: sampler = SequentialSampler(dataset) - if self._args.read_threads > 1: + if self._args.read_threads >= 1: prefetch_factor = math.ceil(self._args.prefetch_size / self._args.read_threads) else: prefetch_factor = self._args.prefetch_size @@ -73,28 +87,37 @@ def read(self): logging.debug( f"{utcnow()} Prefetch size is {self._args.prefetch_size}; prefetch factor of {prefetch_factor} will be set to Torch DataLoader.") else: + prefetch_factor = 2 if self._args.my_rank == 0: logging.debug( f"{utcnow()} Prefetch size is 0; a default prefetch factor of 2 will be set to Torch DataLoader.") logging.debug(f"{utcnow()} Setup dataloader with {self._args.read_threads} workers {torch.__version__}") + if self._args.read_threads==0: + kwargs={} + else: + kwargs={'multiprocessing_context':self._args.multiprocessing_context, + 'prefetch_factor': prefetch_factor, + 'persistent_workers': True} if torch.__version__ == '1.3.1': + if 'prefetch_factor' in kargs: + del kwargs['prefetch_factor'] self._dataset = DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - num_workers=self._args.read_threads, - pin_memory=True, - drop_last=True, - worker_init_fn=dataset.worker_init) + batch_size=self.batch_size, + sampler=sampler, + num_workers=self._args.read_threads, + pin_memory=True, + drop_last=True, + worker_init_fn=dataset.worker_init, **kwargs) else: self._dataset = DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - num_workers=self._args.read_threads, - pin_memory=True, - drop_last=True, - worker_init_fn=dataset.worker_init, - prefetch_factor=prefetch_factor if prefetch_factor > 0 else 2) # 2 is the default value - logging.debug(f"{utcnow()} Rank {self._args.my_rank} will read {len(self._dataset) * batch_size} files") + batch_size=self.batch_size, + sampler=sampler, + num_workers=self._args.read_threads, + pin_memory=True, + drop_last=True, + worker_init_fn=dataset.worker_init, + **kwargs) # 2 is the default value + logging.debug(f"{utcnow()} Rank {self._args.my_rank} will read {len(self._dataset) * self.batch_size} files") # self._dataset.sampler.set_epoch(epoch_number) diff --git a/dlio_benchmark/main.py b/dlio_benchmark/main.py index 7b490bbb..81eba422 100644 --- a/dlio_benchmark/main.py +++ b/dlio_benchmark/main.py @@ -205,10 +205,10 @@ def initialize(self): file_list_train = fullpaths elif dataset_type is DatasetType.VALID: file_list_eval = fullpaths - if not self.generate_only: - assert(self.num_files_train <=len(file_list_train)) - if self.do_eval: - assert(self.num_files_eval <=len(file_list_eval)) + if not self.generate_only and self.num_files_train > len(file_list_train): + raise Exception("Not enough training dataset is found; Please run the code with ++workload.workflow.generate_data=True") + if 
self.do_eval and self.num_files_eval > len(file_list_eval): + raise Exception("Not enough evaluation data found; please rerun with ++workload.workflow.generate_data=True to generate the dataset") if (self.num_files_train < len(file_list_train)): logging.warning(f"Number of files for training in {os.path.join(self.args.data_folder, f'{DatasetType.TRAIN}')} ({len(file_list_train)}) is more than requested ({self.num_files_train}). A subset of files will be used ") file_list_train = file_list_train[:self.num_files_train] diff --git a/dlio_benchmark/reader/tf_reader.py b/dlio_benchmark/reader/tf_reader.py index b4147206..8a881af7 100644 --- a/dlio_benchmark/reader/tf_reader.py +++ b/dlio_benchmark/reader/tf_reader.py @@ -36,7 +36,8 @@ class TFReader(FormatReader): def __init__(self, dataset_type, thread_index, epoch): super().__init__(dataset_type, thread_index) self._dataset = None - + self._file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else self._args.file_list_eval + self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval @dlp.log def open(self, filename): pass @@ -79,17 +80,15 @@ def parse_image(self, serialized): @dlp.log def next(self): - _file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else self._args.file_list_eval - batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval logging.debug( - f"{utcnow()} Reading {len(_file_list)} files thread {self.thread_index} rank {self._args.my_rank}") - self._dataset = tf.data.TFRecordDataset(filenames=_file_list, buffer_size=self._args.transfer_size) + f"{utcnow()} Reading {len(self._file_list)} files thread {self.thread_index} rank {self._args.my_rank}") + self._dataset = tf.data.TFRecordDataset(filenames=self._file_list, buffer_size=self._args.transfer_size) self._dataset = self._dataset.shard(num_shards=self._args.comm_size, index=self._args.my_rank) self._dataset = self._dataset.map( lambda x: tf.py_function(func=self.parse_image, inp=[x], Tout=[tf.uint8]) , num_parallel_calls=self._args.computation_threads) - self._dataset = self._dataset.batch(batch_size, drop_remainder=True) - total = math.ceil(len(_file_list)/self._args.comm_size / batch_size * self._args.num_samples_per_file) + self._dataset = self._dataset.batch(self.batch_size, drop_remainder=True) + total = math.ceil(len(self._file_list)/self._args.comm_size / self.batch_size * self._args.num_samples_per_file) step = 1 for batch in self._dataset: is_last = 0 if step <= total else 1 @@ -104,4 +103,4 @@ def read_index(self, image_idx, step): @dlp.log def finalize(self): - return super().finalize() \ No newline at end of file + return super().finalize() diff --git a/dlio_benchmark/storage/file_storage.py b/dlio_benchmark/storage/file_storage.py index c4bcd6a5..e2557cdc 100644 --- a/dlio_benchmark/storage/file_storage.py +++ b/dlio_benchmark/storage/file_storage.py @@ -1,3 +1,19 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. +""" from abc import ABC, abstractmethod from time import time diff --git a/dlio_benchmark/storage/s3_storage.py b/dlio_benchmark/storage/s3_storage.py index 8c6491ee..8381deb5 100644 --- a/dlio_benchmark/storage/s3_storage.py +++ b/dlio_benchmark/storage/s3_storage.py @@ -1,3 +1,19 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" from time import time from dlio_benchmark.common.constants import MODULE_STORAGE diff --git a/dlio_benchmark/storage/storage_factory.py b/dlio_benchmark/storage/storage_factory.py index 38eaac15..fd532794 100644 --- a/dlio_benchmark/storage/storage_factory.py +++ b/dlio_benchmark/storage/storage_factory.py @@ -1,3 +1,19 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" from dlio_benchmark.storage.file_storage import FileStorage from dlio_benchmark.storage.s3_storage import S3Storage from dlio_benchmark.common.enumerations import StorageType diff --git a/dlio_benchmark/storage/storage_handler.py b/dlio_benchmark/storage/storage_handler.py index 26a18798..c83efb1c 100644 --- a/dlio_benchmark/storage/storage_handler.py +++ b/dlio_benchmark/storage/storage_handler.py @@ -1,3 +1,19 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" from abc import ABC, abstractmethod from dlio_benchmark.framework.framework_factory import FrameworkFactory from dlio_benchmark.utils.config import ConfigArguments diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py index 449637d3..acc43c66 100644 --- a/dlio_benchmark/utils/config.py +++ b/dlio_benchmark/utils/config.py @@ -69,7 +69,7 @@ class ConfigArguments: log_file: str = "dlio.log" file_prefix: str = "img" keep_files: bool = True - do_profiling: bool = True + do_profiling: bool = False profiler: Profiler = Profiler.IOSTAT seed: int = 123 do_checkpoint: bool = False @@ -105,6 +105,7 @@ class ConfigArguments: data_loader_classname = None data_loader_sampler: DataLoaderSampler = None reader_classname: str = None + multiprocessing_context: str = "fork" # derived fields required_samples: int = 1 @@ -123,6 +124,7 @@ class ConfigArguments: global_index_map = None data_loader_class = None reader_class = None + def __init__(self): """ Virtually private constructor. """ @@ -148,9 +150,9 @@ def validate(self): if ('LD_PRELOAD' not in os.environ or os.environ["LD_PRELOAD"].find("libdarshan") == -1): raise Exception("Please set darshan runtime library in LD_PRELOAD") if self.format is FormatType.TFRECORD and self.framework is not FrameworkType.TENSORFLOW: - raise Exception("Imcompatible between format and framework setup.") + raise Exception(f"{self.framework} support for tfrecord is not implemented.") if self.format is FormatType.TFRECORD and self.data_loader is not DataLoaderType.TENSORFLOW: - raise Exception("Imcompatible between format and data loader setup.") + raise Exception(f"{self.data_loader} support for tfrecord is not implemented.") if (self.framework == FrameworkType.TENSORFLOW and self.data_loader == DataLoaderType.PYTORCH) or ( self.framework == FrameworkType.PYTORCH and self.data_loader == DataLoaderType.TENSORFLOW): raise Exception("Imcompatible between framework and data_loader setup.") @@ -351,6 +353,8 @@ def LoadConfig(args, config): if reader is not None: if 'reader_classname' in reader: args.reader_classname = reader['reader_classname'] + if 'multiprocessing_context' in reader: + args.multiprocessing_context = reader['multiprocessing_context'] if 'data_loader' in reader: args.data_loader = DataLoaderType(reader['data_loader']) if 'data_loader_classname' in reader: @@ -369,6 +373,8 @@ def LoadConfig(args, config): args.prefetch_size = reader['prefetch_size'] if 'file_shuffle' in reader: args.file_shuffle = reader['file_shuffle'] + if 'file_access' in reader: + args.file_access = FileAccess(reader['file_access']) if 'shuffle_size' in reader: args.shuffle_size = reader['shuffle_size'] if 'sample_shuffle' in reader: diff --git a/dlio_benchmark/utils/utility.py b/dlio_benchmark/utils/utility.py index 47dc9124..0442a0d8 100644 --- a/dlio_benchmark/utils/utility.py +++ b/dlio_benchmark/utils/utility.py @@ -188,6 +188,8 @@ def initialize_log(logdir, data_dir): import dlio_profiler_py as dlio_logger instance.logger = dlio_logger instance.logger.initialize(instance.log_file, f"{data_dir}", process_id=get_rank()) + with open(instance.log_file, 'w') as f: + f.write("[") else: instance.logger = logging.getLogger("perftrace") instance.logger.setLevel(logging.DEBUG) @@ -198,7 +200,7 @@ def initialize_log(logdir, data_dir): fh.setFormatter(formatter) instance.logger.addHandler(fh) instance.logger.debug("[") - + def get_time(self): if self.logger_type == LoggerType.DLIO_PROFILER: return self.logger.get_time() @@ -217,6 +219,8 @@ def log_event(self, name, cat, 
start_time, duration, int_args=None): def finalize(self): if self.logger_type == LoggerType.DLIO_PROFILER: self.logger.finalize() + with open(self.log_file, 'a') as f: + f.write("]") else: self.logger.debug("]") diff --git a/requirements.txt b/requirements.txt index 636f609f..1b37b4c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,12 +47,12 @@ tensorflow==2.11.0 tensorflow-io==0.28.0 tensorflow-estimator==2.11.0 termcolor==2.1.1 -torch==1.13.0 -torchaudio==0.13.0 -torchvision==0.14.0 +torch==2.0.1 +torchaudio +torchvision typing_extensions==4.4.0 urllib3==1.26.12 Werkzeug==2.2.2 wrapt==1.14.1 nvidia-dali-cuda110 -psutil \ No newline at end of file +psutil
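
A note on the `torch_data_loader.py` change above: worker-only options (`multiprocessing_context`, `prefetch_factor`, `persistent_workers`) are now collected into a `kwargs` dict so they are only passed when workers are in use, and version-incompatible keys are dropped for torch 1.3.1. A minimal standalone sketch of that pattern, assuming a toy dataset (`_ToyDataset` and `build_loader` are illustrative names, not part of the patch):

```python
import math
import torch
from torch.utils.data import DataLoader, Dataset

class _ToyDataset(Dataset):
    """Stand-in for TorchDataset; returns fixed-size tensors."""
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return torch.zeros(1)

def build_loader(read_threads=2, prefetch_size=4,
                 multiprocessing_context="fork", batch_size=2):
    # Derive the per-worker prefetch factor as the patch does.
    if read_threads >= 1:
        prefetch_factor = math.ceil(prefetch_size / read_threads)
    else:
        prefetch_factor = prefetch_size
    # Worker-only options are invalid when num_workers == 0.
    if read_threads == 0:
        kwargs = {}
    else:
        kwargs = {"multiprocessing_context": multiprocessing_context,
                  "prefetch_factor": prefetch_factor,
                  "persistent_workers": True}
    if torch.__version__ == "1.3.1":
        # These keyword arguments did not exist yet in torch 1.3.1.
        kwargs.pop("prefetch_factor", None)
        kwargs.pop("persistent_workers", None)
    return DataLoader(_ToyDataset(), batch_size=batch_size,
                      num_workers=read_threads, drop_last=True, **kwargs)

loader = build_loader()
```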
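Similarly, the `utility.py` hunks write `[` at `initialize_log` and `]` at `finalize` so the DLIO profiler event records form a single parseable JSON array. A toy illustration of that framing, with a single made-up event record (the `trace.json` path and payload are assumptions; real events come from `dlio_profiler_py`):

```python
import json

log_file = "trace.json"  # illustrative path, not the actual log_file setting

with open(log_file, "w") as f:   # initialize_log(): open the JSON array
    f.write("[")
with open(log_file, "a") as f:   # log_event(): append one event record
    f.write(json.dumps({"name": "read", "cat": "IO", "dur": 12}))
with open(log_file, "a") as f:   # finalize(): close the JSON array
    f.write("]")

with open(log_file) as f:
    print(json.load(f))          # -> [{'name': 'read', 'cat': 'IO', 'dur': 12}]
```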