Skip to content

Commit 7338a4f

Browse files
authored
fix: copy collate_fn from the original dataloader (#124)
The bagging dataloaders do not copy the collate_fn from the original dataloaders.
1 parent 03177ab commit 7338a4f

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

torchensemble/bagging.py

+1
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ def _get_bagging_dataloaders(original_dataloader, n_estimators):
468468
sub_dataset,
469469
batch_size=original_dataloader.batch_size,
470470
num_workers=original_dataloader.num_workers,
471+
collate_fn=original_dataloader.collate_fn,
471472
shuffle=True,
472473
)
473474
dataloaders.append(dataloader)

0 commit comments

Comments
 (0)