From 2cc6b1df72a5cc60d75acea94002718dc72ea09a Mon Sep 17 00:00:00 2001 From: LouisDDN <77112282+LouisDDN@users.noreply.github.com> Date: Wed, 29 May 2024 16:26:58 +0200 Subject: [PATCH] Improve tfreader parsing performance (batch) (#194) * Improve tfreader parsing performance (batch) * Create batch only once --- dlio_benchmark/reader/tf_reader.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dlio_benchmark/reader/tf_reader.py b/dlio_benchmark/reader/tf_reader.py index 4b71a15f..894fe8af 100644 --- a/dlio_benchmark/reader/tf_reader.py +++ b/dlio_benchmark/reader/tf_reader.py @@ -85,10 +85,6 @@ def next(self): f"{utcnow()} Reading {len(self._file_list)} files thread {self.thread_index} rank {self._args.my_rank}") self._dataset = tf.data.TFRecordDataset(filenames=self._file_list, buffer_size=self._args.transfer_size, num_parallel_reads=self._args.read_threads) - self._dataset = self._dataset.map( - lambda x: tf.py_function(func=self._parse_image, inp=[x], Tout=[tf.uint8]), - num_parallel_calls=self._args.computation_threads) - self._dataset = self._dataset.shard(num_shards=self._args.comm_size, index=self._args.my_rank) if self._args.sample_shuffle != Shuffle.OFF: if self._args.sample_shuffle == Shuffle.SEED: @@ -96,7 +92,13 @@ def next(self): seed=self._args.seed) else: self._dataset = self._dataset.shuffle(buffer_size=self._args.shuffle_size) + + self._dataset = self._dataset.shard(num_shards=self._args.comm_size, index=self._args.my_rank) self._dataset = self._dataset.batch(self.batch_size, drop_remainder=True) + self._dataset = self._dataset.map( + lambda x: tf.py_function(func=self._parse_image, inp=[x], Tout=[tf.uint8]), + num_parallel_calls=self._args.computation_threads) + self._dataset = self._dataset.repeat(self._args.epochs) total = math.ceil(len(self._file_list)/self._args.comm_size / self.batch_size * self._args.num_samples_per_file) return self._dataset.take(total*self._args.epochs).prefetch(buffer_size=self._args.prefetch_size)