From f0396cbb3dfeaa1dcab356f92dfa0f0999e334bc Mon Sep 17 00:00:00 2001 From: Harsha Date: Mon, 10 Jun 2024 09:01:19 -0400 Subject: [PATCH] resolved https://github.com/neuronets/nobrainer/issues/336 --- nobrainer/dataset.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 3b7b010d..f2d6e8bb 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -146,6 +146,8 @@ def from_tfrecords( ds_obj.map_labels( label_mapping=label_mapping, num_parallel_calls=num_parallel_calls ) + + ds_obj.filter_zero_volumes() # TODO automatically determine batch size ds_obj.batch(1) @@ -385,3 +387,9 @@ def repeat(self, n_repeats): # through once. self.dataset = self.dataset.repeat(n_repeats) return self + + def filter_zero_volumes(self): + self.dataset = self.dataset.filter( + lambda x, y: tf.cast(tf.math.reduce_sum(y), dtype="bool") + ) + return self