diff --git a/onmt/inputters/dataloader.py b/onmt/inputters/dataloader.py index bcccbbad..4cde3e3e 100644 --- a/onmt/inputters/dataloader.py +++ b/onmt/inputters/dataloader.py @@ -139,19 +139,30 @@ def __iter__(self): if accum: yield self.collate_fn(accum) break - if not any(self._lens[current_bucket_idx:]): - # this was the largest bucket, so we'll need to pick the next smallest instead - smallest_bucket_idx = next( - bucket_idx - for bucket_idx in range(smallest_bucket_idx, -1, -1) - if self._lens[bucket_idx] != 0 + try: + if not any(self._lens[current_bucket_idx:]): + # this was the largest bucket, so we'll need to pick the next smallest instead + smallest_bucket_idx = next( + bucket_idx + for bucket_idx in range(smallest_bucket_idx, -1, -1) + if self._lens[bucket_idx] != 0 + ) + current_bucket_idx = smallest_bucket_idx + else: + # there was a larger bucket, shift the index by one + current_bucket_idx = next( + bucket_idx + for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1) + if self._lens[bucket_idx] != 0 + ) + except StopIteration: + logger.warning( + 'StopIteration when trying to pick a bucket in a smart way. ' + 'Doing something stupid instead. Please check me.' ) - current_bucket_idx = smallest_bucket_idx - else: - # there was a larger bucket, shift the index by one current_bucket_idx = next( bucket_idx - for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1) + for bucket_idx in range(len(self._lens)) if self._lens[bucket_idx] != 0 ) _ = self._choose_and_prepare_bucket(bucket_idx=current_bucket_idx)