Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Sep 25, 2023
1 parent 8092693 commit 821e26d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions datasets/flwr_datasets/partitioner/cid_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _create_int_idx_to_cid(self) -> None:
Client ids come from the columns specified in `partition_by`.
"""
unique_cid = self._dataset.unique(self._partition_by)
unique_cid = self.dataset.unique(self._partition_by)
self._index_to_cid = dict(zip(range(len(unique_cid)), unique_cid))

def _save_partition_indexing(self, idx: int, rows: List[int]) -> None:
Expand Down Expand Up @@ -64,7 +64,7 @@ def load_partition(self, idx: int) -> datasets.Dataset:
if len(self._index_to_cid) == 0:
self._create_int_idx_to_cid()

return self._dataset.filter(
return self.dataset.filter(
lambda row: row[self._partition_by] == self._index_to_cid[idx]
)

Expand All @@ -73,6 +73,7 @@ def index_to_cid(self) -> Dict[int, str]:
"""Index to corresponding cid from the dataset property."""
return self._index_to_cid

# pylint: disable=R0201
@index_to_cid.setter
def index_to_cid(self, value: Dict[int, str]) -> None:
raise AttributeError("Setting the index_to_cid dictionary is not allowed.")
4 changes: 2 additions & 2 deletions datasets/flwr_datasets/partitioner/cid_partitioner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ def test_correct_number_of_partitions(
self, num_rows: int, num_unique_cid: int
) -> None:
"""Test if the # of available partitions is equal to # of unique clients."""
dataset, partitioner = _dummy_setup(num_rows, num_unique_cid)
_, partitioner = _dummy_setup(num_rows, num_unique_cid)
_ = partitioner.load_partition(idx=0)
self.assertEqual(len(partitioner.index_to_cid), num_unique_cid)

def test_cannot_set_index_to_cid(self) -> None:
"""Test the lack of ability to set index_to_cid."""
dataset, partitioner = _dummy_setup(num_rows=10, n_unique_cids=2)
_, partitioner = _dummy_setup(num_rows=10, n_unique_cids=2)
with self.assertRaises(AttributeError):
partitioner.index_to_cid = {0: "0"}

Expand Down

0 comments on commit 821e26d

Please sign in to comment.