Commit

improve checks of dataset type
KarhouTam committed Aug 24, 2024
1 parent c80a6db commit 28b1c6d
Showing 1 changed file with 66 additions and 15 deletions.
81 changes: 66 additions & 15 deletions datasets/flwr_datasets/partitioner/image_semantic_partitioner.py
@@ -136,6 +136,7 @@ def __init__( # pylint: disable=R0913
sample_seed: int = 42,
pca_seed: Optional[int] = None,
gmm_seed: Optional[int] = None,
+        image_column_name: Optional[str] = None,
) -> None:
super().__init__()
# Attributes based on the constructor
@@ -156,14 +157,12 @@ def __init__( # pylint: disable=R0913
self._sample_seed = sample_seed
self._pca_seed = pca_seed
self._gmm_seed = gmm_seed
+        self._image_column_name = image_column_name

self._check_variable_validation()

self._rng_numpy = np.random.default_rng(seed=self._sample_seed)
-        # defaults, but some datasets have different names, e.g. cifar10 is "img"
-        # So this variable might be changed in self._check_dataset_type_if_needed()
-        self._data_column_name = "image"
# Utility attributes

# The attributes below are determined during the first call to load_partition
self._unique_classes: Optional[Union[List[int], List[str]]] = None
self._partition_id_to_indices: Dict[int, List[int]] = {}
@@ -386,7 +385,7 @@ def _determine_partition_id_to_indices_if_needed(self) -> None:
self._partition_id_to_indices_determined = True

def _preprocess_dataset_images(self) -> NDArrayFloat:
-        images = np.array(self.dataset[self._data_column_name], dtype=float)
+        images = np.array(self.dataset[self._image_column_name], dtype=float)
if len(images.shape) == 3: # [B, H, W]
images = np.reshape(
images, (images.shape[0], 1, images.shape[1], images.shape[2])
@@ -429,29 +428,45 @@ def _check_data_type_if_needed(self) -> None:
"""Test whether data is image-like."""
if not self._partition_id_to_indices_determined:
features_dict = self.dataset.features.to_dict()
-            self._data_column_name = list(features_dict.keys())[0]
+            if self._image_column_name is None:
+                self._image_column_name = list(features_dict.keys())[0]
+            if self._image_column_name not in features_dict:
+                raise ValueError(
+                    "The image column name is not found in the dataset feature dict: ",
+                    list(features_dict.keys()),
+                    f"Now: {self._image_column_name}. ",
+                )
+            if not isinstance(
+                self.dataset.features[self._image_column_name], datasets.Image
+            ):
+                warnings.warn(
+                    "Image semantic partitioner only supports image-like data. "
+                    f"But column '{self._image_column_name}' is not datasets.Image. "
+                    "So the partitioning might fail.",
+                    stacklevel=1,
+                )
try:
-                data = np.array(
-                    self.dataset[self._data_column_name][0], dtype=np.float32
+                image = np.array(
+                    self.dataset.take(1)[self._image_column_name][0], dtype=np.float32
)
-            except ValueError:
+            except ValueError as err:
raise ValueError(
"The data needs to be able to transform to np.ndarray. "
-                ) from None
+                ) from err

-            if not 2 <= len(data.shape) <= 3:
+            if not 2 <= len(image.shape) <= 3:
                raise ValueError(
                    "The image shape is not supported. "
                    "The image shape should be among {[H, W], [C, H, W], [H, W, C]}. "
-                    f"Now: {data.shape}. "
+                    f"Now: {image.shape}. "
)
-            if len(data.shape) == 3:
-                smallest_axis = min(enumerate(data.shape), key=lambda x: x[1])[0]
+            if len(image.shape) == 3:
+                smallest_axis = min(enumerate(image.shape), key=lambda x: x[1])[0]
                # The smallest axis (C) should be in the first or the last position.
if smallest_axis not in [0, 2]:
raise ValueError(
"The 3D image shape should be [C, H, W] or [H, W, C]. "
f"Now: {data.shape}. "
f"Now: {image.shape}. "
)

def _check_variable_validation(self) -> None:
@@ -484,3 +499,39 @@ def _check_variable_validation(self) -> None:
raise TypeError("The pca seed needs to be an integer.")
if not isinstance(self._gmm_seed, int):
raise TypeError("The gmm seed needs to be an integer.")


+if __name__ == "__main__":
+    # ===================== Test with custom Dataset =====================
+    from datasets import Dataset
+
+    dataset = {
+        "image": [np.random.randn(28, 28) for _ in range(50)],
+        "label": [i % 3 for i in range(50)],
+    }
+    dataset = Dataset.from_dict(dataset)
+    partitioner = ImageSemanticPartitioner(
+        num_partitions=5, partition_by="label", pca_components=30
+    )
+    partitioner.dataset = dataset
+    partition = partitioner.load_partition(0)
+    partition_sizes = [
+        len(partitioner.load_partition(partition_id)) for partition_id in range(5)
+    ]
+    print(sorted(partition_sizes))
+    # ====================================================================
+
+    # ===================== Test with FederatedDataset =====================
+    # from flwr_datasets import FederatedDataset
+
+    # partitioner = ImageSemanticPartitioner(
+    #     num_partitions=5, partition_by="label", pca_components=128
+    # )
+    # fds = FederatedDataset(dataset="cifar10", partitioners={"train": partitioner})
+    # partition = fds.load_partition(0)
+    # print(partition[0])  # Print the first example
+    # partition_sizes = [
+    #     len(fds.load_partition(partition_id)) for partition_id in range(5)
+    # ]
+    # print(sorted(partition_sizes))
+    # ======================================================================
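
In short, after this commit the image column no longer has to be the first feature of the dataset: it can be named explicitly via image_column_name, and a column that is not datasets.Image now triggers an explicit warning instead of an opaque failure later. A minimal usage sketch of the new parameter (the import path is an assumption based on the file location; "img" is the CIFAR-10 column name mentioned in the removed comment):

# Sketch only: assumes ImageSemanticPartitioner is exported from
# flwr_datasets.partitioner (the class lives under
# datasets/flwr_datasets/partitioner/image_semantic_partitioner.py).
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import ImageSemanticPartitioner

# CIFAR-10 stores images under "img", not "image"; before this commit the
# partitioner always used the first feature column, now the name can be given.
partitioner = ImageSemanticPartitioner(
    num_partitions=5,
    partition_by="label",
    pca_components=128,
    image_column_name="img",
)
fds = FederatedDataset(dataset="cifar10", partitioners={"train": partitioner})
print(sorted(len(fds.load_partition(pid)) for pid in range(5)))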
