diff --git a/mammoth/inputters/dataloader.py b/mammoth/inputters/dataloader.py index 3a402a74..8b7e95e7 100644 --- a/mammoth/inputters/dataloader.py +++ b/mammoth/inputters/dataloader.py @@ -191,8 +191,9 @@ def __iter__(self): while self.bucket_is_empty(*current_bucket_idx): current_bucket_idx = next(next_indices) # retrieve and process the example - example = self._buckets[current_bucket_idx].pop() - self._lens[current_bucket_idx] -= 1 + s_bucket, t_bucket = current_bucket_idx + example = self._buckets[s_bucket][t_bucket].pop() + self._lens[s_bucket][t_bucket] -= 1 accum.append(example) numel = self.numel_fn(example) cur_batch_size += numel