From 7933c0c5cdd9aab4b98bfaf4f3f4f10cdf05bf9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Mon, 11 Sep 2023 13:59:07 +0300 Subject: [PATCH] Avoid dying if a rare bug occurs in the look ahead bucketing StopIteration when trying to pick a bucket in a smart way. Doing something stupid instead. Please check me. --- onmt/inputters/dataloader.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/onmt/inputters/dataloader.py b/onmt/inputters/dataloader.py index bcccbbad..4cde3e3e 100644 --- a/onmt/inputters/dataloader.py +++ b/onmt/inputters/dataloader.py @@ -139,19 +139,30 @@ def __iter__(self): if accum: yield self.collate_fn(accum) break - if not any(self._lens[current_bucket_idx:]): - # this was the largest bucket, so we'll need to pick the next smallest instead - smallest_bucket_idx = next( - bucket_idx - for bucket_idx in range(smallest_bucket_idx, -1, -1) - if self._lens[bucket_idx] != 0 + try: + if not any(self._lens[current_bucket_idx:]): + # this was the largest bucket, so we'll need to pick the next smallest instead + smallest_bucket_idx = next( + bucket_idx + for bucket_idx in range(smallest_bucket_idx, -1, -1) + if self._lens[bucket_idx] != 0 + ) + current_bucket_idx = smallest_bucket_idx + else: + # there was a larger bucket, shift the index by one + current_bucket_idx = next( + bucket_idx + for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1) + if self._lens[bucket_idx] != 0 + ) + except StopIteration: + logger.warning( + 'StopIteration when trying to pick a bucket in a smart way. ' + 'Doing something stupid instead. Please check me.' ) - current_bucket_idx = smallest_bucket_idx - else: - # there was a larger bucket, shift the index by one current_bucket_idx = next( bucket_idx - for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1) + for bucket_idx in range(len(self._lens)) if self._lens[bucket_idx] != 0 ) _ = self._choose_and_prepare_bucket(bucket_idx=current_bucket_idx)