Skip to content

Commit

Permalink
update metadata format to support more label information
Browse files Browse the repository at this point in the history
  • Loading branch information
fjxmlzn committed Dec 20, 2024
1 parent 46baff4 commit 4e1a83c
Show file tree
Hide file tree
Showing 13 changed files with 33 additions and 31 deletions.
2 changes: 1 addition & 1 deletion doc/source/getting_started/details/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ To apply **Private Evolution** to your own private dataset, you need to create a

* ``metadata``: A dictionary that holds the metadata of the samples. The following keys must be included:

* ``label_names``: A list of strings that holds the names of the classes. The length of the list must be equal to K.
* ``label_info``: A list of dictionaries that hold the information of the classes. The length of the list must be equal to K. Each dictionary must include the key ``name``, whose value is the name of the class.

In addition, you can include any other keys that hold the metadata of the samples if needed.

Expand Down
6 changes: 3 additions & 3 deletions pe/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ class API(ABC):
"""The abstract class that defines the APIs for the synthetic data generation."""

@abstractmethod
def random_api(self, label_name, num_samples):
def random_api(self, label_info, num_samples):
"""The abstract method that generates random synthetic data.
:param label_name: The name of the label
:type label_name: str
:param label_info: The info of the label
:type label_info: dict
:param num_samples: The number of random samples to generate
:type num_samples: int
"""
Expand Down
7 changes: 4 additions & 3 deletions pe/api/image/improved_diffusion_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,17 @@ def __init__(
else:
self._variation_degrees = variation_degrees

def random_api(self, label_name, num_samples):
def random_api(self, label_info, num_samples):
"""Generating random synthetic data.
:param label_name: The name of the label, not utilized in this API
:type label_name: str
:param label_info: The info of the label, not utilized in this API
:type label_info: dict
:param num_samples: The number of random samples to generate
:type num_samples: int
:return: The data object of the generated synthetic data
:rtype: :py:class:`pe.data.data.Data`
"""
label_name = label_info.name
execution_logger.info(f"RANDOM API: creating {num_samples} samples for label {label_name}")
samples, labels = sample(
sampler=self._timestep_respacing_to_sampler[self._timestep_respacing[0]],
Expand Down
7 changes: 4 additions & 3 deletions pe/api/image/stable_diffusion_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,17 @@ def __init__(
self._variation_api_pipe.safety_checker = None
self._variation_api_pipe = self._variation_api_pipe.to(self._device)

def random_api(self, label_name, num_samples):
def random_api(self, label_info, num_samples):
"""Generating random synthetic data.
:param label_name: The name of the label, not utilized in this API
:type label_name: str
:param label_info: The info of the label
:type label_info: dict
:param num_samples: The number of random samples to generate
:type num_samples: int
:return: The data object of the generated synthetic data
:rtype: :py:class:`pe.data.data.Data`
"""
label_name = label_info.name
execution_logger.info(f"RANDOM API: creating {num_samples} samples for label {label_name}")

prompt = self._prompt[label_name]
Expand Down
2 changes: 1 addition & 1 deletion pe/callback/image/sample_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __call__(self, syn_data):
:rtype: list[:py:class:`pe.metric_item.ImageListMetricItem`]
"""
all_image_list = []
num_classes = len(syn_data.metadata.label_names)
num_classes = len(syn_data.metadata.label_info)
for class_id in range(num_classes):
image_list = syn_data.data_frame[syn_data.data_frame[LABEL_ID_COLUMN_NAME] == class_id][
IMAGE_DATA_COLUMN_NAME
Expand Down
2 changes: 1 addition & 1 deletion pe/callback/image/save_all_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __call__(self, syn_data):
for i in iterator:
image = syn_data.data_frame[IMAGE_DATA_COLUMN_NAME][i]
label_id = int(syn_data.data_frame[LABEL_ID_COLUMN_NAME][i])
label_name = syn_data.metadata.label_names[label_id]
label_name = syn_data.metadata.label_info[label_id].name
index = syn_data.data_frame.index[i]
self._save_image(
image=image,
Expand Down
2 changes: 1 addition & 1 deletion pe/data/image/camelyon17.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ def __init__(self, split="train", root_dir="data", res=64):
LABEL_ID_COLUMN_NAME: labels,
}
)
metadata = {"label_names": CAMELYON17_LABEL_NAMES}
metadata = {"label_info": [{"name": n} for n in CAMELYON17_LABEL_NAMES]}
super().__init__(data_frame=data_frame, metadata=metadata)
2 changes: 1 addition & 1 deletion pe/data/image/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, root_dir="data", res=512):
LABEL_ID_COLUMN_NAME: labels,
}
)
metadata = {"label_names": CAT_LABEL_NAMES}
metadata = {"label_info": [{"name": n} for n in CAT_LABEL_NAMES]}
super().__init__(data_frame=data_frame, metadata=metadata)

def _download(self):
Expand Down
2 changes: 1 addition & 1 deletion pe/data/image/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ def __init__(self, split="train"):
LABEL_ID_COLUMN_NAME: dataset.targets,
}
)
metadata = {"label_names": CIFAR10_LABEL_NAMES}
metadata = {"label_info": [{"name": n} for n in CIFAR10_LABEL_NAMES]}
super().__init__(data_frame=data_frame, metadata=metadata)
2 changes: 1 addition & 1 deletion pe/data/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,5 @@ def load_image_folder(path, image_size, class_cond=True, num_images=-1, num_work
LABEL_ID_COLUMN_NAME: list(all_labels),
}
)
metadata = {"label_names": dataset.class_names if class_cond else ["None"]}
metadata = {"label_info": [{"name": n} for n in dataset.class_names] if class_cond else [{"name": "None"}]}
return Data(data_frame=data_frame, metadata=metadata)
12 changes: 6 additions & 6 deletions pe/population/pe_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,29 +52,29 @@ def __init__(
"synthetic data will be empty."
)

def initial(self, label_name, num_samples):
def initial(self, label_info, num_samples):
"""Generate the initial synthetic data.
:param label_name: The label name
:type label_name: str
:param label_info: The label info
:type label_info: dict
:param num_samples: The number of samples to generate
:type num_samples: int
:return: The initial synthetic data
:rtype: :py:class:`pe.data.data.Data`
"""
execution_logger.info(
f"Population: generating {num_samples}*{self._initial_variation_api_fold + 1} initial "
f"synthetic samples for label {label_name}"
f"synthetic samples for label {label_info['name']}"
)
random_data = self._api.random_api(label_name=label_name, num_samples=num_samples)
random_data = self._api.random_api(label_info=label_info, num_samples=num_samples)
variation_data_list = []
for _ in range(self._initial_variation_api_fold):
variation_data = self._api.variation_api(syn_data=random_data)
variation_data_list.append(variation_data)
data = Data.concat([random_data] + variation_data_list)
execution_logger.info(
f"Population: finished generating {num_samples}*{self._initial_variation_api_fold + 1} initial "
f"synthetic samples for label {label_name}"
f"synthetic samples for label {label_info['name']}"
)
return data

Expand Down
6 changes: 3 additions & 3 deletions pe/population/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ class Population(ABC):
"""The abstract class that generates synthetic data."""

@abstractmethod
def initial(self, label_name, num_samples):
def initial(self, label_info, num_samples):
"""Generate the initial synthetic data.
:param label_name: The label name
:type label_name: str
:param label_info: The label info
:type label_info: dict
:param num_samples: The number of samples to generate
:type num_samples: int
"""
Expand Down
12 changes: 6 additions & 6 deletions pe/runner/pe.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def _get_num_samples_per_label_id(self, num_samples, fraction_per_label_id):
fraction_per_label_id = self._priv_data.data_frame[LABEL_ID_COLUMN_NAME].value_counts().to_dict()
fraction_per_label_id = [
0 if i not in fraction_per_label_id else fraction_per_label_id[i]
for i in range(len(self._priv_data.metadata.label_names))
for i in range(len(self._priv_data.metadata.label_info))
]
if len(fraction_per_label_id) != len(self._priv_data.metadata.label_names):
if len(fraction_per_label_id) != len(self._priv_data.metadata.label_info):
raise ValueError("fraction_per_label_id should have the same length as the number of labels.")
fraction_per_label_id = np.array(fraction_per_label_id)
fraction_per_label_id = fraction_per_label_id / np.sum(fraction_per_label_id)
Expand Down Expand Up @@ -162,17 +162,17 @@ def run(
fraction_per_label_id=fraction_per_label_id,
)
syn_data_list = []
for label_id, label_name in enumerate(self._priv_data.metadata.label_names):
for label_id, label_info in enumerate(self._priv_data.metadata.label_info):
syn_data = self._population.initial(
label_name=label_name,
label_info=label_info,
num_samples=num_samples_per_label_id[label_id],
)
syn_data.set_label_id(label_id)
syn_data_list.append(syn_data)
syn_data = Data.concat(syn_data_list, metadata=self._priv_data.metadata)
syn_data.data_frame.reset_index(drop=True, inplace=True)
syn_data.metadata.iteration = 0
syn_data.metadata.label_names = self._priv_data.metadata.label_names
syn_data.metadata.label_info = self._priv_data.metadata.label_info
self._log_metrics(syn_data)

# Run PE iterations.
Expand All @@ -186,7 +186,7 @@ def run(
priv_data_list = []

# Generate synthetic data for each label.
for label_id in range(len(self._priv_data.metadata.label_names)):
for label_id in range(len(self._priv_data.metadata.label_info)):
execution_logger.info(f"Label {label_id}")
sub_priv_data = self._priv_data.filter_label_id(label_id=label_id)
sub_syn_data = syn_data.filter_label_id(label_id=label_id)
Expand Down

0 comments on commit 4e1a83c

Please sign in to comment.