Skip to content

Commit

Permalink
Fix condition in get_collate_for_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed May 4, 2023
1 parent bf29472 commit 589f603
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions pytorch_toolbelt/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,8 @@ def get_collate_for_dataset(dataset: Union[Dataset, ConcatDataset]) -> Callable:
collate_fn = default_collate

if hasattr(dataset, "get_collate_fn"):
collate_fn = dataset.get_collate_fn()

if isinstance(dataset, ConcatDataset):
return dataset.get_collate_fn()
elif isinstance(dataset, ConcatDataset):
collates = set(get_collate_for_dataset(ds) for ds in dataset.datasets)
if len(collates) != 1:
raise RuntimeError(
Expand Down

0 comments on commit 589f603

Please sign in to comment.