diff --git a/doc/source/getting_started/details/data.rst b/doc/source/getting_started/details/data.rst index 420a5a9..5d619a8 100644 --- a/doc/source/getting_started/details/data.rst +++ b/doc/source/getting_started/details/data.rst @@ -37,7 +37,7 @@ To apply **Private Evolution** to your own private dataset, you need to create a * ``metadata``: A dictionary that holds the metadata of the samples. The following keys must be included: - * ``label_names``: A list of strings that holds the names of the classes. The length of the list must be equal to K. + * ``label_info``: A list of dictionaries that hold the information of the classes. The length of the list must be equal to K. Inside each dictionary, a key ``name`` must be included, which refers to the name of the class. In addition, you can include any other keys that hold the metadata of the samples if needed. diff --git a/pe/api/api.py b/pe/api/api.py index 31481e6..0f11e84 100644 --- a/pe/api/api.py +++ b/pe/api/api.py @@ -5,11 +5,11 @@ class API(ABC): """The abstract class that defines the APIs for the synthetic data generation.""" @abstractmethod - def random_api(self, label_name, num_samples): + def random_api(self, label_info, num_samples): """The abstract method that generates random synthetic data. - :param label_name: The name of the label - :type label_name: str + :param label_info: The info of the label + :type label_info: dict :param num_samples: The number of random samples to generate :type num_samples: int """ diff --git a/pe/api/image/improved_diffusion_api.py b/pe/api/image/improved_diffusion_api.py index bc76121..b78cee7 100644 --- a/pe/api/image/improved_diffusion_api.py +++ b/pe/api/image/improved_diffusion_api.py @@ -123,16 +123,17 @@ def __init__( else: self._variation_degrees = variation_degrees - def random_api(self, label_name, num_samples): + def random_api(self, label_info, num_samples): """Generating random synthetic data. - :param label_name: The name of the label, not utilized in this API - :type label_name: str + :param label_info: The info of the label, not utilized in this API + :type label_info: dict :param num_samples: The number of random samples to generate :type num_samples: int :return: The data object of the generated synthetic data :rtype: :py:class:`pe.data.data.Data` """ + label_name = label_info.name execution_logger.info(f"RANDOM API: creating {num_samples} samples for label {label_name}") samples, labels = sample( sampler=self._timestep_respacing_to_sampler[self._timestep_respacing[0]], diff --git a/pe/api/image/stable_diffusion_api.py b/pe/api/image/stable_diffusion_api.py index 4450383..ad9ef18 100644 --- a/pe/api/image/stable_diffusion_api.py +++ b/pe/api/image/stable_diffusion_api.py @@ -116,16 +116,17 @@ def __init__( self._variation_api_pipe.safety_checker = None self._variation_api_pipe = self._variation_api_pipe.to(self._device) - def random_api(self, label_name, num_samples): + def random_api(self, label_info, num_samples): """Generating random synthetic data. - :param label_name: The name of the label, not utilized in this API - :type label_name: str + :param label_info: The info of the label + :type label_info: dict :param num_samples: The number of random samples to generate :type num_samples: int :return: The data object of the generated synthetic data :rtype: :py:class:`pe.data.data.Data` """ + label_name = label_info.name execution_logger.info(f"RANDOM API: creating {num_samples} samples for label {label_name}") prompt = self._prompt[label_name] diff --git a/pe/callback/image/sample_images.py b/pe/callback/image/sample_images.py index 8d60929..bd99a9b 100644 --- a/pe/callback/image/sample_images.py +++ b/pe/callback/image/sample_images.py @@ -26,7 +26,7 @@ def __call__(self, syn_data): :rtype: list[:py:class:`pe.metric_item.ImageListMetricItem`] """ all_image_list = [] - num_classes = len(syn_data.metadata.label_names) + num_classes = len(syn_data.metadata.label_info) for class_id in range(num_classes): image_list = syn_data.data_frame[syn_data.data_frame[LABEL_ID_COLUMN_NAME] == class_id][ IMAGE_DATA_COLUMN_NAME diff --git a/pe/callback/image/save_all_images.py b/pe/callback/image/save_all_images.py index 6429339..104261a 100644 --- a/pe/callback/image/save_all_images.py +++ b/pe/callback/image/save_all_images.py @@ -58,7 +58,7 @@ def __call__(self, syn_data): for i in iterator: image = syn_data.data_frame[IMAGE_DATA_COLUMN_NAME][i] label_id = int(syn_data.data_frame[LABEL_ID_COLUMN_NAME][i]) - label_name = syn_data.metadata.label_names[label_id] + label_name = syn_data.metadata.label_info[label_id].name index = syn_data.data_frame.index[i] self._save_image( image=image, diff --git a/pe/data/image/camelyon17.py b/pe/data/image/camelyon17.py index acfc37a..9063669 100644 --- a/pe/data/image/camelyon17.py +++ b/pe/data/image/camelyon17.py @@ -46,5 +46,5 @@ def __init__(self, split="train", root_dir="data", res=64): LABEL_ID_COLUMN_NAME: labels, } ) - metadata = {"label_names": CAMELYON17_LABEL_NAMES} + metadata = {"label_info": [{"name": n} for n in CAMELYON17_LABEL_NAMES]} super().__init__(data_frame=data_frame, metadata=metadata) diff --git a/pe/data/image/cat.py b/pe/data/image/cat.py index 1c0730b..1b6b01d 100644 --- a/pe/data/image/cat.py +++ b/pe/data/image/cat.py @@ -52,7 +52,7 @@ def __init__(self, root_dir="data", res=512): LABEL_ID_COLUMN_NAME: labels, } ) - metadata = {"label_names": CAT_LABEL_NAMES} + metadata = {"label_info": [{"name": n} for n in CAT_LABEL_NAMES]} super().__init__(data_frame=data_frame, metadata=metadata) def _download(self): diff --git a/pe/data/image/cifar10.py b/pe/data/image/cifar10.py index c54eaef..196efed 100644 --- a/pe/data/image/cifar10.py +++ b/pe/data/image/cifar10.py @@ -41,5 +41,5 @@ def __init__(self, split="train"): LABEL_ID_COLUMN_NAME: dataset.targets, } ) - metadata = {"label_names": CIFAR10_LABEL_NAMES} + metadata = {"label_info": [{"name": n} for n in CIFAR10_LABEL_NAMES]} super().__init__(data_frame=data_frame, metadata=metadata) diff --git a/pe/data/image/image.py b/pe/data/image/image.py index de495a2..194c280 100644 --- a/pe/data/image/image.py +++ b/pe/data/image/image.py @@ -121,5 +121,5 @@ def load_image_folder(path, image_size, class_cond=True, num_images=-1, num_work LABEL_ID_COLUMN_NAME: list(all_labels), } ) - metadata = {"label_names": dataset.class_names if class_cond else ["None"]} + metadata = {"label_info": [{"name": n} for n in dataset.class_names] if class_cond else [{"name": "None"}]} return Data(data_frame=data_frame, metadata=metadata) diff --git a/pe/population/pe_population.py b/pe/population/pe_population.py index 24180b5..d2a8964 100644 --- a/pe/population/pe_population.py +++ b/pe/population/pe_population.py @@ -52,11 +52,11 @@ def __init__( "synthetic data will be empty." ) - def initial(self, label_name, num_samples): + def initial(self, label_info, num_samples): """Generate the initial synthetic data. - :param label_name: The label name - :type label_name: str + :param label_info: The label info + :type label_info: dict :param num_samples: The number of samples to generate :type num_samples: int :return: The initial synthetic data @@ -64,9 +64,9 @@ def initial(self, label_name, num_samples): """ execution_logger.info( f"Population: generating {num_samples}*{self._initial_variation_api_fold + 1} initial " - f"synthetic samples for label {label_name}" + f"synthetic samples for label {label_info['name']}" ) - random_data = self._api.random_api(label_name=label_name, num_samples=num_samples) + random_data = self._api.random_api(label_info=label_info, num_samples=num_samples) variation_data_list = [] for _ in range(self._initial_variation_api_fold): variation_data = self._api.variation_api(syn_data=random_data) @@ -74,7 +74,7 @@ def initial(self, label_name, num_samples): data = Data.concat([random_data] + variation_data_list) execution_logger.info( f"Population: finished generating {num_samples}*{self._initial_variation_api_fold + 1} initial " - f"synthetic samples for label {label_name}" + f"synthetic samples for label {label_info['name']}" ) return data diff --git a/pe/population/population.py b/pe/population/population.py index 22368e8..887b846 100644 --- a/pe/population/population.py +++ b/pe/population/population.py @@ -5,11 +5,11 @@ class Population(ABC): """The abstract class that generates synthetic data.""" @abstractmethod - def initial(self, label_name, num_samples): + def initial(self, label_info, num_samples): """Generate the initial synthetic data. - :param label_name: The label name - :type label_name: str + :param label_info: The label info + :type label_info: dict :param num_samples: The number of samples to generate :type num_samples: int """ diff --git a/pe/runner/pe.py b/pe/runner/pe.py index f212a61..cacbc10 100644 --- a/pe/runner/pe.py +++ b/pe/runner/pe.py @@ -87,9 +87,9 @@ def _get_num_samples_per_label_id(self, num_samples, fraction_per_label_id): fraction_per_label_id = self._priv_data.data_frame[LABEL_ID_COLUMN_NAME].value_counts().to_dict() fraction_per_label_id = [ 0 if i not in fraction_per_label_id else fraction_per_label_id[i] - for i in range(len(self._priv_data.metadata.label_names)) + for i in range(len(self._priv_data.metadata.label_info)) ] - if len(fraction_per_label_id) != len(self._priv_data.metadata.label_names): + if len(fraction_per_label_id) != len(self._priv_data.metadata.label_info): raise ValueError("fraction_per_label_id should have the same length as the number of labels.") fraction_per_label_id = np.array(fraction_per_label_id) fraction_per_label_id = fraction_per_label_id / np.sum(fraction_per_label_id) @@ -162,9 +162,9 @@ def run( fraction_per_label_id=fraction_per_label_id, ) syn_data_list = [] - for label_id, label_name in enumerate(self._priv_data.metadata.label_names): + for label_id, label_info in enumerate(self._priv_data.metadata.label_info): syn_data = self._population.initial( - label_name=label_name, + label_info=label_info, num_samples=num_samples_per_label_id[label_id], ) syn_data.set_label_id(label_id) @@ -172,7 +172,7 @@ def run( syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata) syn_data.data_frame.reset_index(drop=True, inplace=True) syn_data.metadata.iteration = 0 - syn_data.metadata.label_names = self._priv_data.metadata.label_names + syn_data.metadata.label_info = self._priv_data.metadata.label_info self._log_metrics(syn_data) # Run PE iterations. @@ -186,7 +186,7 @@ def run( priv_data_list = [] # Generate synthetic data for each label. - for label_id in range(len(self._priv_data.metadata.label_names)): + for label_id in range(len(self._priv_data.metadata.label_info)): execution_logger.info(f"Label {label_id}") sub_priv_data = self._priv_data.filter_label_id(label_id=label_id) sub_syn_data = syn_data.filter_label_id(label_id=label_id)