diff --git a/trainer/dataset.py b/trainer/dataset.py index a6485eea3..958d3c9fe 100644 --- a/trainer/dataset.py +++ b/trainer/dataset.py @@ -98,12 +98,12 @@ def get_batch(self): for i, data_loader_iter in enumerate(self.dataloader_iter_list): try: - image, text = data_loader_iter.next() + image, text = next(data_loader_iter) balanced_batch_images.append(image) balanced_batch_texts += text except StopIteration: self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) - image, text = self.dataloader_iter_list[i].next() + image, text = next(self.dataloader_iter_list[i]) balanced_batch_images.append(image) balanced_batch_texts += text except ValueError: