From 9864746ad5d05b12304fd20abf27dc764c463f5c Mon Sep 17 00:00:00 2001 From: Huihuo Zheng Date: Tue, 28 May 2024 21:52:22 +0000 Subject: [PATCH 1/6] added some more metric log --- .../data_loader/synthetic_data_loader.py | 8 ++--- dlio_benchmark/framework/framework.py | 12 ++++++- dlio_benchmark/framework/tf_framework.py | 5 ++- dlio_benchmark/main.py | 35 ++++++++++--------- dlio_benchmark/reader/tf_reader.py | 2 +- dlio_benchmark/utils/statscounter.py | 31 +++++++++------- 6 files changed, 55 insertions(+), 38 deletions(-) diff --git a/dlio_benchmark/data_loader/synthetic_data_loader.py b/dlio_benchmark/data_loader/synthetic_data_loader.py index ce54b54e..ab6796cc 100644 --- a/dlio_benchmark/data_loader/synthetic_data_loader.py +++ b/dlio_benchmark/data_loader/synthetic_data_loader.py @@ -36,6 +36,9 @@ class SyntheticDataLoader(BaseDataLoader): @dlp.log_init def __init__(self, format_type, dataset_type, epoch): super().__init__(format_type, dataset_type, epoch, DataLoaderType.SYNTHETIC) + shape = self._args.resized_image.shape + self.batch = np.zeros((self.batch_size, shape[0], shape[1])) + #self.batch = 1 @dlp.log def read(self, init=False): @@ -48,10 +51,7 @@ def next(self): step = 0 self.read(True) while step < self.num_samples // self.batch_size: - batch = [] - for i in range(self.batch_size): - batch.append(self._args.resized_image) - yield batch + yield self.batch step += 1 @dlp.log diff --git a/dlio_benchmark/framework/framework.py b/dlio_benchmark/framework/framework.py index 02731d8f..4c42bb57 100644 --- a/dlio_benchmark/framework/framework.py +++ b/dlio_benchmark/framework/framework.py @@ -20,11 +20,21 @@ from dlio_benchmark.common.enumerations import DatasetType from dlio_benchmark.data_loader.data_loader_factory import DataLoaderFactory from dlio_benchmark.storage.storage_factory import StorageFactory -from dlio_benchmark.utils.utility import utcnow +from dlio_benchmark.utils.utility import utcnow, DLIOMPI +comm = DLIOMPI.get_instance().comm() from time import sleep import os import logging +from multiprocessing import Process +def emulate_compute(computation_time): + sleep(computation_time) + comm.barrier() + +def async_compute(computation_time): + p = Process(target=emulate_compute, args=(computation_time,)) + p.start() + return p from dlio_benchmark.utils.config import ConfigArguments diff --git a/dlio_benchmark/framework/tf_framework.py b/dlio_benchmark/framework/tf_framework.py index 2e21e151..0a079af7 100644 --- a/dlio_benchmark/framework/tf_framework.py +++ b/dlio_benchmark/framework/tf_framework.py @@ -18,13 +18,12 @@ import os import logging from time import time, sleep - from dlio_benchmark.common.constants import MODULE_AI_FRAMEWORK from dlio_benchmark.data_loader.data_loader_factory import DataLoaderFactory from dlio_benchmark.utils.utility import utcnow, DLIOMPI from dlio_profiler.logger import fn_interceptor as Profile from dlio_benchmark.common.error_code import ErrorCodes -from dlio_benchmark.framework.framework import Framework +from dlio_benchmark.framework.framework import Framework, async_compute, emulate_compute from dlio_benchmark.reader.reader_factory import ReaderFactory from dlio_benchmark.profiler.profiler_factory import ProfilerFactory from dlio_benchmark.storage.storage_factory import StorageFactory @@ -87,7 +86,7 @@ def trace_object(self, string, step, r): @dlp.log def compute(self, x, epoch_number, step, computation_time): - sleep(computation_time) + emulate_compute(computation_time) # tf.function(self.model)(epoch_number, step, computation_time) 
@dlp.log diff --git a/dlio_benchmark/main.py b/dlio_benchmark/main.py index ebeb5a9f..5da214ca 100644 --- a/dlio_benchmark/main.py +++ b/dlio_benchmark/main.py @@ -224,22 +224,23 @@ def _eval(self, epoch): step = 1 total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size) loader = self.framework.get_loader(DatasetType.VALID) - t0 = time() + self.stats.start_loading() for batch in dlp.iter(loader.next()): - self.stats.eval_batch_loaded(epoch, step, t0) + self.stats.eval_batch_loaded(epoch, step) eval_time = 0.0 if self.eval_time > 0: if self.eval_time_stdev > 0: eval_time = random.normal(self.eval_time, self.eval_time_stdev) else: eval_time = self.eval_time + self.stats.start_compute() self.framework.compute(batch, epoch, step, eval_time) - self.stats.eval_batch_processed(epoch, step, t0, eval_time) + self.stats.eval_batch_processed(epoch, step) step += 1 if step > total: break - t0 = time() + self.stats.start_loading() return step - 1 @dlp.log @@ -256,22 +257,22 @@ def _train(self, epoch): self.stats.start_block(epoch, block) loader = self.framework.get_loader(dataset_type=DatasetType.TRAIN) - t0 = time() - for batch in dlp.iter(loader.next()): - self.stats.batch_loaded(epoch, overall_step, block, t0) + self.stats.start_loading() + for batch in loader.next(): + self.stats.batch_loaded(epoch, overall_step, block) # Log a new block, unless it's the first one which we've already logged before the loop if block_step == 1 and block != 1: self.stats.start_block(epoch, block) computation_time = self.computation_time - if self.computation_time > 0: - self.framework.trace_object("Train", overall_step, 1) - if self.computation_time_stdev > 0: - computation_time = random.normal(self.computation_time, self.computation_time_stdev) - else: - computation_time = self.computation_time - self.framework.compute(batch, epoch, block_step, computation_time) - self.stats.batch_processed(epoch, overall_step, block, t0, computation_time) - self.comm.barrier() +# if self.computation_time > 0: +# self.framework.trace_object("Train", overall_step, 1) +# if self.computation_time_stdev > 0: +# computation_time = random.normal(self.computation_time, self.computation_time_stdev) +# else: +# computation_time = self.computation_time + self.stats.start_compute() + self.framework.compute(batch, epoch, block_step, self.computation_time) + self.stats.batch_processed(epoch, overall_step, block) if self.do_checkpoint and ( self.steps_between_checkpoints >= 0) and overall_step == self.next_checkpoint_step: self.stats.end_block(epoch, block, block_step) @@ -292,7 +293,7 @@ def _train(self, epoch): self.stats.end_block(epoch, block, block_step - 1) break overall_step += 1 - t0 = time() + self.stats.start_loading() self.comm.barrier() if self.do_checkpoint and (self.steps_between_checkpoints < 0) and (epoch == self.next_checkpoint_epoch): self.stats.end_block(epoch, block, block_step) diff --git a/dlio_benchmark/reader/tf_reader.py b/dlio_benchmark/reader/tf_reader.py index 4b71a15f..17b5baa7 100644 --- a/dlio_benchmark/reader/tf_reader.py +++ b/dlio_benchmark/reader/tf_reader.py @@ -67,7 +67,7 @@ def _parse_image(self, serialized): 'image': tf.io.FixedLenFeature([], tf.string), 'size': tf.io.FixedLenFeature([], tf.int64) } - parsed_example = tf.io.parse_example(serialized=serialized, features=features) + #parsed_example = tf.io.parse_example(serialized=serialized, features=features) # Get the image as raw bytes. 
        #image_raw = parsed_example['image']
        #dimension = tf.cast(parsed_example['size'], tf.int32).numpy()

diff --git a/dlio_benchmark/utils/statscounter.py b/dlio_benchmark/utils/statscounter.py
index 8ba652e8..90fcd9de 100644
--- a/dlio_benchmark/utils/statscounter.py
+++ b/dlio_benchmark/utils/statscounter.py
@@ -303,8 +303,12 @@ def end_ckpt(self, epoch, block):
             self.per_epoch_stats[epoch][f'ckpt{block}']['end'] = ts
             self.per_epoch_stats[epoch][f'ckpt{block}']['duration'] = duration
 
-    def batch_loaded(self, epoch, step, block, t0):
-        duration = time() - t0
+    def start_loading(self):
+        self.start_time_loading = time()
+    def start_compute(self):
+        self.start_time_compute = time()
+    def batch_loaded(self, epoch, step, block):
+        duration = time() - self.start_time_loading
         key = f'block{block}'
         if key in self.output[epoch]['load']:
             self.output[epoch]['load'][key].append(duration)
@@ -312,17 +316,18 @@ def batch_loaded(self, epoch, step, block, t0):
             self.output[epoch]['load'][key] = [duration]
         logging.debug(f"{utcnow()} Rank {self.my_rank} step {step}: loaded {self.batch_size} samples in {duration} s")
 
-
-    def batch_processed(self, epoch, step, block, t0, computation_time):
-        duration = time() - t0
+    def batch_processed(self, epoch, step, block):
+        current_time = time()
+        duration = current_time - self.start_time_loading
         key = f'block{block}'
+        self.computation_time = current_time - self.start_time_compute
         if key in self.output[epoch]['proc']:
             self.output[epoch]['proc'][key].append(duration)
-            self.output[epoch]['compute'][key].append(computation_time)
+            self.output[epoch]['compute'][key].append(self.computation_time)
         else:
             self.output[epoch]['proc'] = [duration]
-            self.output[epoch]['compute']=[computation_time]
-        logging.info(f"{utcnow()} Rank {self.my_rank} step {step} processed {self.batch_size} samples in {duration} s")
+            self.output[epoch]['compute'] = [self.computation_time]
+        logging.info(f"{utcnow()} Rank {self.my_rank} step {step} processed {self.batch_size} samples in {duration} s")
 
     def compute_metrics_train(self, epoch, block):
         key = f"block{block}"
@@ -348,14 +353,16 @@ def compute_metrics_eval(self, epoch):
         self.output[epoch]['au'][key] = au*100
         self.output[epoch]['throughput'][key] = throughput
 
-    def eval_batch_loaded(self, epoch, step, t0):
-        duration = time() - t0
+    def eval_batch_loaded(self, epoch, step):
+        duration = time() - self.start_time_loading
         self.output[epoch]['load']['eval'].append(duration)
 
         logging.debug(f"{utcnow()} Rank {self.my_rank} step {step} loaded {self.batch_size_eval} samples in {duration} s")
 
-    def eval_batch_processed(self, epoch, step, t0, computation_time):
-        duration = time() - t0
+    def eval_batch_processed(self, epoch, step):
+        current_time = time()
+        duration = current_time - self.start_time_loading
+        computation_time = current_time -self.start_time_compute
         self.output[epoch]['proc']['eval'].append(duration)
         self.output[epoch]['compute']['eval'].append(computation_time)
         logging.info(f"{utcnow()} Rank {self.my_rank} step {step} processed {self.batch_size_eval} samples in {duration} s")
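Note on the timing refactor in PATCH 1/6: instead of threading a t0 timestamp into batch_loaded/batch_processed, the training and eval loops now mark start_loading()/start_compute() and the stats counter derives both durations itself. A minimal sketch of the pattern, with simplified names rather than the benchmark's actual StatsCounter interface:

    # Sketch: paired start/duration timing (hypothetical TimingStats class,
    # analogous to but not identical to StatsCounter's new API).
    from time import time, sleep

    class TimingStats:
        def __init__(self):
            self.load_times, self.compute_times = [], []
        def start_loading(self):
            self.start_time_loading = time()
        def start_compute(self):
            self.start_time_compute = time()
        def batch_loaded(self):
            # load duration: from start_loading() until the batch arrived
            self.load_times.append(time() - self.start_time_loading)
        def batch_processed(self):
            # compute duration: from start_compute() until compute returned
            self.compute_times.append(time() - self.start_time_compute)

    stats = TimingStats()
    stats.start_loading()
    for _ in range(3):              # stand-in for "for batch in loader.next()"
        stats.batch_loaded()
        stats.start_compute()
        sleep(0.01)                 # stand-in for framework.compute(...)
        stats.batch_processed()
        stats.start_loading()       # restart the load clock for the next step

The start_loading() at the bottom of the loop mirrors main.py above: the next step's load time covers everything between finishing this step and the arrival of the next batch.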
From e1ed0ab66175394bee6d0eb297ba3c6f00aa8313 Mon Sep 17 00:00:00 2001
From: Huihuo Zheng
Date: Wed, 29 May 2024 14:27:52 +0000
Subject: [PATCH 2/6] local change

---
 dlio_benchmark/framework/framework.py       | 2 +-
 dlio_benchmark/framework/tf_framework.py    | 2 +-
 dlio_benchmark/framework/torch_framework.py | 2 +-
 dlio_benchmark/reader/tf_reader.py          | 8 ++++----
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/dlio_benchmark/framework/framework.py b/dlio_benchmark/framework/framework.py
index 4c42bb57..7d73ca71 100644
--- a/dlio_benchmark/framework/framework.py
+++ b/dlio_benchmark/framework/framework.py
@@ -79,7 +79,7 @@ def stop_framework_profiler(self):
     def trace_object(self, string, step, r):
         pass
 
-    def model(epoch, epoch_number, step, computation_time):
+    def model(epoch, x, computation_time):
         sleep(computation_time)
 
     @abstractmethod
diff --git a/dlio_benchmark/framework/tf_framework.py b/dlio_benchmark/framework/tf_framework.py
index 0a079af7..1cbf8c19 100644
--- a/dlio_benchmark/framework/tf_framework.py
+++ b/dlio_benchmark/framework/tf_framework.py
@@ -86,7 +86,7 @@ def trace_object(self, string, step, r):
 
     @dlp.log
     def compute(self, x, epoch_number, step, computation_time):
-        emulate_compute(computation_time)
+        return self.model(x, computation_time)
         # tf.function(self.model)(epoch_number, step, computation_time)
 
     @dlp.log
diff --git a/dlio_benchmark/framework/torch_framework.py b/dlio_benchmark/framework/torch_framework.py
index 8660914c..6022a3f1 100644
--- a/dlio_benchmark/framework/torch_framework.py
+++ b/dlio_benchmark/framework/torch_framework.py
@@ -93,7 +93,7 @@ def trace_object(self, string, step, r):
 
     @dlp.log
     def compute(self, x, epoch_number, step, computation_time):
-        torch_sleep(computation_time)
+        return self.model(x, computation_time)
 
     @dlp.log
     def get_loader(self, dataset_type=DatasetType.TRAIN):
diff --git a/dlio_benchmark/reader/tf_reader.py b/dlio_benchmark/reader/tf_reader.py
index 17b5baa7..ece29d6f 100644
--- a/dlio_benchmark/reader/tf_reader.py
+++ b/dlio_benchmark/reader/tf_reader.py
@@ -67,7 +67,7 @@ def _parse_image(self, serialized):
             'image': tf.io.FixedLenFeature([], tf.string),
             'size': tf.io.FixedLenFeature([], tf.int64)
         }
-        #parsed_example = tf.io.parse_example(serialized=serialized, features=features)
+        parsed_example = tf.io.parse_example(serialized=serialized, features=features)
         # Get the image as raw bytes.
        #image_raw = parsed_example['image']
        #dimension = tf.cast(parsed_example['size'], tf.int32).numpy()
@@ -85,9 +85,6 @@ def next(self):
             f"{utcnow()} Reading {len(self._file_list)} files thread {self.thread_index} rank {self._args.my_rank}")
         self._dataset = tf.data.TFRecordDataset(filenames=self._file_list, buffer_size=self._args.transfer_size,
                                                 num_parallel_reads=self._args.read_threads)
-        self._dataset = self._dataset.map(
-            lambda x: tf.py_function(func=self._parse_image, inp=[x], Tout=[tf.uint8]),
-            num_parallel_calls=self._args.computation_threads)
         self._dataset = self._dataset.shard(num_shards=self._args.comm_size, index=self._args.my_rank)
         if self._args.sample_shuffle != Shuffle.OFF:
@@ -97,6 +94,9 @@ def next(self):
             else:
                 self._dataset = self._dataset.shuffle(buffer_size=self._args.shuffle_size)
         self._dataset = self._dataset.batch(self.batch_size, drop_remainder=True)
+        self._dataset = self._dataset.map(
+            lambda x: tf.py_function(func=self._parse_image, inp=[x], Tout=[tf.uint8]),
+            num_parallel_calls=self._args.computation_threads)
         self._dataset = self._dataset.repeat(self._args.epochs)
         total = math.ceil(len(self._file_list)/self._args.comm_size / self.batch_size * self._args.num_samples_per_file)
         return self._dataset.take(total*self._args.epochs).prefetch(buffer_size=self._args.prefetch_size)
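Note on the reader change in PATCH 2/6: the py_function-wrapped parser moves from before the shard to after the batch, so deserialization runs once per batch of records rather than once per record, and tf.io.parse_example can decode the whole batch in one call. A sketch of the reordered pipeline with made-up records and feature names (the real reader also shards by rank, repeats over epochs, and prefetches):

    import tensorflow as tf

    # Build a few serialized tf.train.Example records in memory so the
    # sketch runs without a .tfrecord file on disk.
    def make_record(payload):
        feature = {'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))}
        return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

    records = [make_record(bytes([i]) * 8) for i in range(8)]

    def parse_batch(serialized):
        # parse_example accepts a vector of serialized records, so one
        # call decodes an entire batch
        features = {'image': tf.io.FixedLenFeature([], tf.string)}
        return tf.io.parse_example(serialized=serialized, features=features)['image']

    dataset = tf.data.Dataset.from_tensor_slices(records)
    dataset = dataset.shuffle(buffer_size=8)
    dataset = dataset.batch(4, drop_remainder=True)   # batch first ...
    dataset = dataset.map(                            # ... then parse once per batch
        lambda x: tf.py_function(func=parse_batch, inp=[x], Tout=tf.string),
        num_parallel_calls=2)

Batching before the map amortizes the per-call Python dispatch cost of tf.py_function over batch_size records, which is the point of the reordering.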
From d8a9d058911d8d402b1700581938d0f0ec451d4e Mon Sep 17 00:00:00 2001
From: Huihuo Zheng
Date: Wed, 29 May 2024 15:38:50 +0000
Subject: [PATCH 3/6] remove computation_time_stdev input variable

---
 dlio_benchmark/framework/framework.py | 9 +--------
 dlio_benchmark/main.py                | 7 -------
 dlio_benchmark/utils/config.py        | 2 --
 docs/source/config.rst                | 3 ---
 docs/source/testedsystems.rst         | 2 +-
 5 files changed, 2 insertions(+), 21 deletions(-)

diff --git a/dlio_benchmark/framework/framework.py b/dlio_benchmark/framework/framework.py
index 7d73ca71..cc38e668 100644
--- a/dlio_benchmark/framework/framework.py
+++ b/dlio_benchmark/framework/framework.py
@@ -27,14 +27,6 @@
 import os
 import logging
 from multiprocessing import Process
-def emulate_compute(computation_time):
-    sleep(computation_time)
-    comm.barrier()
-
-def async_compute(computation_time):
-    p = Process(target=emulate_compute, args=(computation_time,))
-    p.start()
-    return p
 
 from dlio_benchmark.utils.config import ConfigArguments
 
@@ -81,6 +73,7 @@ def trace_object(self, string, step, r):
 
     def model(epoch, x, computation_time):
         sleep(computation_time)
+        comm.barrier()
 
     @abstractmethod
     def compute(self, x, epoch_number, step, computation_time):
diff --git a/dlio_benchmark/main.py b/dlio_benchmark/main.py
index 5da214ca..885b1cec 100644
--- a/dlio_benchmark/main.py
+++ b/dlio_benchmark/main.py
@@ -263,13 +263,6 @@ def _train(self, epoch):
             # Log a new block, unless it's the first one which we've already logged before the loop
             if block_step == 1 and block != 1:
                 self.stats.start_block(epoch, block)
-            computation_time = self.computation_time
-#            if self.computation_time > 0:
-#                self.framework.trace_object("Train", overall_step, 1)
-#                if self.computation_time_stdev > 0:
-#                    computation_time = random.normal(self.computation_time, self.computation_time_stdev)
-#                else:
-#                    computation_time = self.computation_time
             self.stats.start_compute()
             self.framework.compute(batch, epoch, block_step, self.computation_time)
             self.stats.batch_processed(epoch, overall_step, block)
diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py
index ab793d5b..d2093754 100644
--- a/dlio_benchmark/utils/config.py
+++ b/dlio_benchmark/utils/config.py
@@ -496,8 +496,6 @@ def LoadConfig(args, config):
             args.seed_change_epoch = config['train']['seed_change_epoch']
         if 'computation_time' in config['train']:
             args.computation_time = config['train']['computation_time']
-        if 'computation_time_stdev' in config['train']:
-            args.computation_time_stdev = config['train']['computation_time_stdev']
         if 'seed' in config['train']:
             args.seed = config['train']['seed']
 
diff --git a/docs/source/config.rst b/docs/source/config.rst
index b8c6f25e..02917366 100644
--- a/docs/source/config.rst
+++ b/docs/source/config.rst
@@ -253,9 +253,6 @@ train
    * - computation_time
      - 0.0
      - emulated computation time per step in second
-   * - computation_time_stdev
-     - 0.0
-     - standard deviation of the emulated computation time per step in second
    * - total_training_steps
      - -1
      - number of training steps to simulate, assuming running the benchmark less than one epoch.
diff --git a/docs/source/testedsystems.rst b/docs/source/testedsystems.rst
index 4ced28af..265aaaac 100644
--- a/docs/source/testedsystems.rst
+++ b/docs/source/testedsystems.rst
@@ -4,4 +4,4 @@ Tested systems
 ================
 So far we have tested DLIO on the following systems: 
   * Personal workstation, laptops including both MacOSX and Linux OS system. 
-  * Supercomputers (Linux), such as Theta @ ALCF, Summit @ OLCF, Lassen @ LLNL (please turn to: `instructions_lassen.rst`_ for instructions)
+  * Supercomputers (Linux), such as Polaris @ ALCF, Summit @ OLCF, Lassen @ LLNL (please turn to: `instructions_lassen.rst`_ for instructions)

From e0bdf291efc5b5fa0834a733776f7e40c3f2b527 Mon Sep 17 00:00:00 2001
From: Huihuo Zheng
Date: Wed, 29 May 2024 16:06:03 +0000
Subject: [PATCH 4/6] fixed bugs for tests

---
 dlio_benchmark/framework/tf_framework.py | 2 +-
 dlio_benchmark/utils/statscounter.py     | 8 +++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/dlio_benchmark/framework/tf_framework.py b/dlio_benchmark/framework/tf_framework.py
index 1cbf8c19..f618b4fa 100644
--- a/dlio_benchmark/framework/tf_framework.py
+++ b/dlio_benchmark/framework/tf_framework.py
@@ -23,7 +23,7 @@
 from dlio_benchmark.utils.utility import utcnow, DLIOMPI
 from dlio_profiler.logger import fn_interceptor as Profile
 from dlio_benchmark.common.error_code import ErrorCodes
-from dlio_benchmark.framework.framework import Framework, async_compute, emulate_compute
+from dlio_benchmark.framework.framework import Framework
 from dlio_benchmark.reader.reader_factory import ReaderFactory
 from dlio_benchmark.profiler.profiler_factory import ProfilerFactory
 from dlio_benchmark.storage.storage_factory import StorageFactory
diff --git a/dlio_benchmark/utils/statscounter.py b/dlio_benchmark/utils/statscounter.py
index 66e1685f..f734640f 100644
--- a/dlio_benchmark/utils/statscounter.py
+++ b/dlio_benchmark/utils/statscounter.py
@@ -166,11 +166,17 @@ def end_run(self):
             metric = metric + f"[METRIC] Training Accelerator Utilization [AU] (%): {np.mean(train_au):.4f} ({np.std(train_au):.4f})\n"
             metric = metric + f"[METRIC] Training Throughput (samples/second): {np.mean(train_throughput):.4f} ({np.std(train_throughput):.4f})\n"
             metric = metric + f"[METRIC] Training I/O Throughput (MB/second): {np.mean(train_throughput)*self.record_size/1024/1024:.4f} ({np.std(train_throughput)*self.record_size/1024/1024:.4f})\n"
+            metric = metric + f"[METRIC] --\n"
+            metric = metric + f"[METRIC] Training Throughput (ideal) (samples/second): {self.batch_size / self.args.computation_time * self.MPI.size():.4f}\n"
+            metric = metric + f"[METRIC] Training I/O Throughput (ideal) (MB/second): {self.batch_size * self.record_size/1024/1024 / self.args.computation_time * self.MPI.size():.4f}\n"
 
             if self.args.do_eval:
                 metric = metric + f"[METRIC] Eval Accelerator Utilization [AU] (%): {np.mean(eval_au):.4f} ({np.std(eval_au):.4f})\n"
                 metric = metric + f"[METRIC] Eval Throughput (samples/second): {np.mean(eval_throughput):.6f} ({np.std(eval_throughput):.6f})\n"
                 metric = metric + f"[METRIC] Eval Throughput (MB/second): {np.mean(eval_throughput)*self.record_size/1024/1024:.6f} ({np.std(eval_throughput)*self.record_size/1024/1024:.6f})\n"
+                metric = metric + f"[METRIC] --\n"
+                metric = metric + f"[METRIC] Eval Throughput (ideal) (samples/second): {self.batch_size / self.args.eval_time * self.MPI.size():.4f}\n"
+                metric = metric + f"[METRIC] Eval I/O Throughput (ideal) (MB/second): {self.batch_size * self.record_size/1024/1024 / self.args.eval_time * self.MPI.size():.4f}\n"
             metric+="[METRIC] ==========================================================\n"
             logging.info(metric)
     def start_train(self, epoch):
@@ -322,10 +328,10 @@ def batch_processed(self, epoch, step, block):
     def compute_metrics_train(self, epoch, block):
         key = f"block{block}"
         total_compute_time = np.sum(self.output[epoch]['compute'][key][1:-1])
+        total_time = self.end_timestamp - self.start_timestamp - self.output[epoch]['proc'][key][0] - self.output[epoch]['proc'][key][-1]
         if (total_compute_time==0):
             au=0.0
         else:
-            total_time = self.end_timestamp - self.start_timestamp - self.output[epoch]['proc'][key][0] - self.output[epoch]['proc'][key][-1]
             au = total_compute_time / total_time
         throughput = (len(self.output[epoch]['compute'][key]) - 2)/(total_time)*self.batch_size
         self.output[epoch]['au'][key] = au*100
         self.output[epoch]['throughput'][key] = throughput

From eeeae3b4fe63aee4cde98bd222a5b1cd1d922cec Mon Sep 17 00:00:00 2001
From: Huihuo Zheng
Date: Fri, 31 May 2024 14:12:27 +0000
Subject: [PATCH 5/6] updated expected throughput value

---
 dlio_benchmark/utils/statscounter.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/dlio_benchmark/utils/statscounter.py b/dlio_benchmark/utils/statscounter.py
index f734640f..84caff51 100644
--- a/dlio_benchmark/utils/statscounter.py
+++ b/dlio_benchmark/utils/statscounter.py
@@ -166,17 +166,17 @@ def end_run(self):
             metric = metric + f"[METRIC] Training Accelerator Utilization [AU] (%): {np.mean(train_au):.4f} ({np.std(train_au):.4f})\n"
             metric = metric + f"[METRIC] Training Throughput (samples/second): {np.mean(train_throughput):.4f} ({np.std(train_throughput):.4f})\n"
             metric = metric + f"[METRIC] Training I/O Throughput (MB/second): {np.mean(train_throughput)*self.record_size/1024/1024:.4f} ({np.std(train_throughput)*self.record_size/1024/1024:.4f})\n"
-            metric = metric + f"[METRIC] --\n"
-            metric = metric + f"[METRIC] Training Throughput (ideal) (samples/second): {self.batch_size / self.args.computation_time * self.MPI.size():.4f}\n"
-            metric = metric + f"[METRIC] Training I/O Throughput (ideal) (MB/second): {self.batch_size * self.record_size/1024/1024 / self.args.computation_time * self.MPI.size():.4f}\n"
+            metric = metric + f"[METRIC] **Expected Throughputs if compute-bound\n"
+            metric = metric + f"[METRIC] Training Throughput (expected) (samples/second): {np.mean(train_throughput/train_au)*100:.4f}\n"
+            metric = metric + f"[METRIC] Training I/O Throughput (expected) (MB/second): {np.mean(train_throughput/train_au)*100*self.record_size/1024/1024:.4f}\n"
 
             if self.args.do_eval:
                 metric = metric + f"[METRIC] Eval Accelerator Utilization [AU] (%): {np.mean(eval_au):.4f} ({np.std(eval_au):.4f})\n"
                 metric = metric + f"[METRIC] Eval Throughput (samples/second): {np.mean(eval_throughput):.6f} ({np.std(eval_throughput):.6f})\n"
                 metric = metric + f"[METRIC] Eval Throughput (MB/second): {np.mean(eval_throughput)*self.record_size/1024/1024:.6f} ({np.std(eval_throughput)*self.record_size/1024/1024:.6f})\n"
-                metric = metric + f"[METRIC] --\n"
-                metric = metric + f"[METRIC] Eval Throughput (ideal) (samples/second): {self.batch_size / self.args.eval_time * self.MPI.size():.4f}\n"
-                metric = metric + f"[METRIC] Eval I/O Throughput (ideal) (MB/second): {self.batch_size * self.record_size/1024/1024 / self.args.eval_time * self.MPI.size():.4f}\n"
+                metric = metric + f"[METRIC] **Expected Throughputs if compute-bound\n"
+                metric = metric + f"[METRIC] Eval Throughput (expected) (samples/second): {np.mean(eval_throughput/eval_au)*100:.4f}\n"
+                metric = metric + f"[METRIC] Eval I/O Throughput (expected) (MB/second): {np.mean(eval_throughput/eval_au)*100*self.record_size/1024/1024:.4f}\n"
             metric+="[METRIC] ==========================================================\n"
             logging.info(metric)
     def start_train(self, epoch):
@@ -336,6 +336,7 @@ def compute_metrics_train(self, epoch, block):
         throughput = (len(self.output[epoch]['compute'][key]) - 2)/(total_time)*self.batch_size
         self.output[epoch]['au'][key] = au*100
         self.output[epoch]['throughput'][key] = throughput
+        self.output[epoch]['compute'][key] = total_compute_time
 
     def compute_metrics_eval(self, epoch):
         key = 'eval'
@@ -348,6 +349,8 @@ def compute_metrics_eval(self, epoch):
         throughput = len(self.output[epoch]['compute'][key])/(self.end_timestamp - self.start_timestamp)*self.batch_size_eval
         self.output[epoch]['au'][key] = au*100
         self.output[epoch]['throughput'][key] = throughput
+        self.output[epoch]['compute'][key] = total_compute_time
+
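Note on the metric arithmetic in PATCH 4/6 and 5/6: AU is the fraction of block wall time spent in compute, and the "expected" throughput is what the same run would sustain if it were perfectly compute-bound, i.e. the measured throughput divided by AU. A worked sketch with made-up numbers (the real code trims the first and last step as warm-up/cool-down, which is why the [1:-1] slices and the "- 2" appear above):

    import numpy as np

    compute = np.array([0.10, 0.10, 0.10, 0.10])  # per-step compute seconds (interior steps)
    total_time = 0.50                             # block wall time over those steps
    batch_size = 8

    au = np.sum(compute) / total_time                    # 0.80, reported as 80%
    throughput = len(compute) / total_time * batch_size  # 64 samples/second measured

    # If AU were 100%, the same compute would take sum(compute) seconds,
    # so the compute-bound ceiling is the measured rate divided by AU:
    expected = throughput / au                           # 80 samples/second
    print(f"AU {au*100:.1f}%, measured {throughput:.1f}, expected {expected:.1f} samples/s")

Because the code stores AU as a percentage (au*100), the end_run lines above recover the ratio with throughput/au*100; this is why PATCH 5/6 can drop the "ideal" numbers derived from the configured computation_time.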
From 75bdee20cb08b041fabee575d3a6d5c17408c747 Mon Sep 17 00:00:00 2001
From: Huihuo Zheng
Date: Fri, 31 May 2024 22:22:52 +0000
Subject: [PATCH 6/6] fixed typo and a bug

---
 dlio_benchmark/reader/tf_reader.py   | 20 +++++++++++++-----
 dlio_benchmark/utils/config.py       |  4 ++--
 dlio_benchmark/utils/statscounter.py |  9 +++------
 3 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/dlio_benchmark/reader/tf_reader.py b/dlio_benchmark/reader/tf_reader.py
index 61eb8c10..493cde2d 100644
--- a/dlio_benchmark/reader/tf_reader.py
+++ b/dlio_benchmark/reader/tf_reader.py
@@ -67,7 +67,7 @@ def _parse_image(self, serialized):
             'image': tf.io.FixedLenFeature([], tf.string),
             'size': tf.io.FixedLenFeature([], tf.int64)
         }
-        parsed_example = tf.io.parse_example(serialized=serialized, features=features)
+        #parsed_example = tf.io.parse_example(serialized=serialized, features=features)
         # Get the image as raw bytes.
        #image_raw = parsed_example['image']
        #dimension = tf.cast(parsed_example['size'], tf.int32).numpy()
@@ -94,10 +94,20 @@ def next(self):
                 self._dataset = self._dataset.shuffle(buffer_size=self._args.shuffle_size)
         self._dataset = self._dataset.shard(num_shards=self._args.comm_size, index=self._args.my_rank)
-        self._dataset = self._dataset.batch(self.batch_size, drop_remainder=True)
-        self._dataset = self._dataset.map(
-            lambda x: tf.py_function(func=self._parse_image, inp=[x], Tout=[tf.uint8]),
-            num_parallel_calls=self._args.computation_threads)
+        if self._args.computation_threads==0:
+            self._dataset = self._dataset.batch(self.batch_size, drop_remainder=True)
+        else:
+            if self._args.computation_threads <= self.batch_size:
+                self._dataset = self._dataset.batch(self.batch_size, drop_remainder=True)
+                self._dataset = self._dataset.map(
+                    lambda x: tf.py_function(func=self._parse_image, inp=[x], Tout=[tf.uint8]),
+                    num_parallel_calls=self._args.computation_threads)
+            else:
+                self._dataset = self._dataset.batch(self._args.computation_threads)
+                self._dataset = self._dataset.map(
+                    lambda x: tf.py_function(func=self._parse_image, inp=[x], Tout=[tf.uint8]),
+                    num_parallel_calls=self._args.computation_threads)
+                self._dataset = self._dataset.unbatch().batch(self.batch_size, drop_remainder=True)
         self._dataset = self._dataset.repeat(self._args.epochs)
         total = math.ceil(len(self._file_list)/self._args.comm_size / self.batch_size * self._args.num_samples_per_file)
         return self._dataset.take(total*self._args.epochs).prefetch(buffer_size=self._args.prefetch_size)
diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py
index d2093754..d0929a91 100644
--- a/dlio_benchmark/utils/config.py
+++ b/dlio_benchmark/utils/config.py
@@ -461,8 +461,8 @@ def LoadConfig(args, config):
             args.data_loader_sampler = DataLoaderSampler(reader['data_loader_sampler'])
         if 'read_threads' in reader:
             args.read_threads = reader['read_threads']
-        if 'computatation_threads' in reader:
-            args.computatation_threads = reader['computatation_threads']
+        if 'computation_threads' in reader:
+            args.computation_threads = reader['computation_threads']
         if 'batch_size' in reader:
             args.batch_size = reader['batch_size']
         if 'batch_size_eval' in reader:
diff --git a/dlio_benchmark/utils/statscounter.py b/dlio_benchmark/utils/statscounter.py
index 84caff51..97ceea98 100644
--- a/dlio_benchmark/utils/statscounter.py
+++ b/dlio_benchmark/utils/statscounter.py
@@ -280,6 +280,7 @@ def end_block(self, epoch, block, steps_taken):
         self.per_epoch_stats[epoch][f'block{block}']['duration'] = duration
         logging.info(f"{utcnow()} Epoch {epoch} - Block {block} [Training] Accelerator Utilization [AU] (%): {self.output[epoch]['au'][f'block{block}']:.4f}")
         logging.info(f"{utcnow()} Epoch {epoch} - Block {block} [Training] Throughput (samples/second): {self.output[epoch]['throughput'][f'block{block}']*self.comm_size:.4f}")
+        logging.info(f"{utcnow()} Epoch {epoch} - Block {block} [Training] Computation time per step (second): {np.mean(self.output[epoch]['compute'][f'block{block}'][1:-1]):.4f}+/-{np.std(self.output[epoch]['compute'][f'block{block}'][1:-1]):.4f} (set value: {self.args.computation_time})")
 
     def start_ckpt(self, epoch, block, steps_taken):
         if self.my_rank == 0:
@@ -336,11 +337,10 @@ def compute_metrics_train(self, epoch, block):
         throughput = (len(self.output[epoch]['compute'][key]) - 2)/(total_time)*self.batch_size
         self.output[epoch]['au'][key] = au*100
         self.output[epoch]['throughput'][key] = throughput
-        self.output[epoch]['compute'][key] = total_compute_time
 
     def compute_metrics_eval(self, epoch):
         key = 'eval'
-        total_compute_time = np.sum(self.output[epoch]['compute'][key][1:])
+        total_compute_time = np.sum(self.output[epoch]['compute'][key][1:-1])
         if (total_compute_time==0):
             au=0.0
         else:
@@ -349,19 +349,16 @@ def compute_metrics_eval(self, epoch):
         throughput = len(self.output[epoch]['compute'][key])/(self.end_timestamp - self.start_timestamp)*self.batch_size_eval
         self.output[epoch]['au'][key] = au*100
         self.output[epoch]['throughput'][key] = throughput
-        self.output[epoch]['compute'][key] = total_compute_time
-
 
     def eval_batch_loaded(self, epoch, step):
         duration = time() - self.start_time_loading
         self.output[epoch]['load']['eval'].append(duration)
-
         logging.debug(f"{utcnow()} Rank {self.my_rank} step {step} loaded {self.batch_size_eval} samples in {duration} s")
 
-
     def eval_batch_processed(self, epoch, step):
         current_time = time()
         duration = current_time - self.start_time_loading
-        computation_time = current_time -self.start_time_compute
+        computation_time = current_time - self.start_time_compute
         self.output[epoch]['proc']['eval'].append(duration)
         self.output[epoch]['compute']['eval'].append(computation_time)
         logging.info(f"{utcnow()} Rank {self.my_rank} step {step} processed {self.batch_size_eval} samples in {duration} s")
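Note on the reader change in PATCH 6/6: when computation_threads exceeds the training batch size, the reader batches by the thread count so each parallel py_function call gets a full group of records, then flattens and re-batches to the training batch size. tf.data.Dataset.unbatch() takes no arguments, hence the explicit batch() after it in the hunk above. A sketch of that path with illustrative sizes and a stand-in map function:

    import tensorflow as tf

    batch_size = 4
    computation_threads = 8   # > batch_size, so take the second branch

    dataset = tf.data.Dataset.from_tensor_slices(tf.range(32))
    dataset = dataset.batch(computation_threads)              # group by thread count
    dataset = dataset.map(lambda x: x * 2,                    # stand-in for the parser
                          num_parallel_calls=computation_threads)
    dataset = dataset.unbatch()                               # back to single records
    dataset = dataset.batch(batch_size, drop_remainder=True)  # restore training batches
    for b in dataset.take(2):
        print(b.numpy())   # [0 2 4 6], then [8 10 12 14]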