From 135ad2bbcdd40fafc4648513fdda028e46697f41 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Wed, 6 Dec 2023 17:53:09 +0100 Subject: [PATCH] Improve tests --- .../partitioner/dirichlet_partitioner_test.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py index 71f289309781..4f313bb9e5b3 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py @@ -85,10 +85,13 @@ def test_invalid_alpha(self): num_partitions=3, alpha=[0.5, 0.5], partition_by="labels" ) - def test_load_partition(self): + def test_min_partition_size_requirement(self): + """Test if partitions are created with min partition size required.""" _, partitioner = _dummy_setup(3, 0.5, 100, "labels") - partition_list = [partitioner.load_partition(node_id) for node_id in [0,1,2]] - self.assertGreaterEqual(all([len(p) for p in partition_list]), partitioner._min_partition_size) + partition_list = [partitioner.load_partition(node_id) for node_id in [0, 1, 2]] + self.assertTrue( + all([len(p) > partitioner._min_partition_size for p in partition_list]) + ) def test_load_invalid_partition_index(self): """Test if raises when the load_partition is above the num_partitions."""