diff --git a/datasets/flwr_datasets/partitioner/iid_partitioner.py b/datasets/flwr_datasets/partitioner/iid_partitioner.py index 37b97468cadf..c8dbf8294fec 100644 --- a/datasets/flwr_datasets/partitioner/iid_partitioner.py +++ b/datasets/flwr_datasets/partitioner/iid_partitioner.py @@ -48,5 +48,5 @@ def load_partition(self, idx: int) -> datasets.Dataset: single dataset partition """ return self.dataset.shard( - num_shards=self._num_partitions, index=idx, contiguous=False + num_shards=self._num_partitions, index=idx, contiguous=True ) diff --git a/datasets/flwr_datasets/partitioner/iid_partitioner_test.py b/datasets/flwr_datasets/partitioner/iid_partitioner_test.py index 5f851807f4bd..64c37c4e7127 100644 --- a/datasets/flwr_datasets/partitioner/iid_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/iid_partitioner_test.py @@ -18,7 +18,6 @@ import unittest from typing import Tuple -import numpy as np from parameterized import parameterized from datasets import Dataset @@ -102,14 +101,15 @@ def test_load_partition_correct_data( ) -> None: """Test if the data in partition is equal to the expected.""" dataset, partitioner = _dummy_setup(num_partitions, num_rows) + partition_size = num_rows // num_partitions partition_index = 2 partition = partitioner.load_partition(partition_index) row_id = 0 self.assertEqual( - partition["features"][row_id], - dataset[np.arange(partition_index, len(dataset), num_partitions)][ - "features" - ][row_id], + partition[row_id]["features"], + # Note it's contiguous so partition_size * partition_index gets the first + # element of the partition of partition_index + dataset[partition_size * partition_index + row_id]["features"], ) @parameterized.expand( # type: ignore