Skip to content

Commit

Permalink
Avoid dying if a rare bug occurs in the look ahead bucketing
Browse files Browse the repository at this point in the history
StopIteration when trying to pick a bucket in a smart way.
Doing something stupid instead. Please check me.
  • Loading branch information
Waino committed Sep 11, 2023
1 parent 173d283 commit 7933c0c
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions onmt/inputters/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,30 @@ def __iter__(self):
if accum:
yield self.collate_fn(accum)
break
if not any(self._lens[current_bucket_idx:]):
# this was the largest bucket, so we'll need to pick the next smallest instead
smallest_bucket_idx = next(
bucket_idx
for bucket_idx in range(smallest_bucket_idx, -1, -1)
if self._lens[bucket_idx] != 0
try:
if not any(self._lens[current_bucket_idx:]):
# this was the largest bucket, so we'll need to pick the next smallest instead
smallest_bucket_idx = next(
bucket_idx
for bucket_idx in range(smallest_bucket_idx, -1, -1)
if self._lens[bucket_idx] != 0
)
current_bucket_idx = smallest_bucket_idx
else:
# there was a larger bucket, shift the index by one
current_bucket_idx = next(
bucket_idx
for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1)
if self._lens[bucket_idx] != 0
)
except StopIteration:
logger.warning(
'StopIteration when trying to pick a bucket in a smart way. '
'Doing something stupid instead. Please check me.'
)
current_bucket_idx = smallest_bucket_idx
else:
# there was a larger bucket, shift the index by one
current_bucket_idx = next(
bucket_idx
for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1)
for bucket_idx in range(len(self._lens))
if self._lens[bucket_idx] != 0
)
_ = self._choose_and_prepare_bucket(bucket_idx=current_bucket_idx)
Expand Down

0 comments on commit 7933c0c

Please sign in to comment.