ruff fixes on train.py
johanos1 committed Mar 25, 2024
1 parent 585fb11 commit 8d08a77
Showing 16 changed files with 400 additions and 302 deletions.
16 changes: 8 additions & 8 deletions leakpro.py
@@ -8,24 +8,24 @@
import torch
import yaml

import leakpro.dataset as dataset
import leakpro.models as models
import leakpro.train as util

from leakpro import dataset, models
from leakpro.mia_attacks.attack_scheduler import AttackScheduler

from leakpro.reporting.utils import prepare_priavcy_risk_report




def setup_log(name: str, save_file: bool):
def setup_log(name: str, save_file: bool) -> logging.Logger:
"""Generate the logger for the current run.
Args:
----
name (str): Logging file name.
save_file (bool): Flag about whether to save to file.
Returns:
-------
logging.Logger: Logger object for the current run.
"""
my_logger = logging.getLogger(name)
my_logger.setLevel(logging.INFO)
@@ -95,7 +95,7 @@ def setup_log(name: str, save_file: bool):
elif "cifar10" in configs["data"]["dataset"]:
model = models.ConvNet()
if RETRAIN:
model = util.train(model, train_loader, configs, test_loader, train_test_dataset)
model = util.train(model, train_loader, configs, test_loader, train_test_dataset, logger)


# ------------------------------------------------
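For orientation, a minimal usage sketch (not part of the commit) of the call site this hunk touches. It assumes setup_log and util.train keep the signatures shown above; names such as configs, train_loader, test_loader, train_test_dataset and RETRAIN come from the surrounding diff, and util is assumed to still resolve to leakpro's train module.

    # Hedged sketch, not part of the diff: threading the logger returned by
    # setup_log into the updated util.train signature.
    logger = setup_log("analysis", save_file=True)   # now annotated to return logging.Logger

    model = models.ConvNet()                          # cifar10 branch shown above
    if RETRAIN:
        model = util.train(model, train_loader, configs, test_loader, train_test_dataset, logger)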
166 changes: 86 additions & 80 deletions leakpro/dataset.py
@@ -1,23 +1,21 @@
import torch
import torchvision
import torchvision.transforms as transforms

import pickle
import os
from torch.utils.data import Dataset
import pandas as pd
import pickle
from typing import List

import numpy as np
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
import pandas as pd
import torch
import torchvision
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
from torch.utils.data import Dataset
from torchvision import transforms

from typing import List

class GeneralDataset(Dataset):
def __init__(self, data:np.ndarray, label:np.ndarray, transforms=None):
"""data_list: A list of GeneralData instances.
"""
data_list: A list of GeneralData instances.
"""

self.X = data # Convert to tensor and specify the data type
self.y = label # Assuming labels are for classification
self.transforms = transforms
@@ -26,18 +24,17 @@ def __len__(self):
return len(self.y)

def __getitem__(self, idx):
"""
Returns the data and label for a single instance indexed by idx.
"""Returns the data and label for a single instance indexed by idx.
"""
if self.transforms:
X = self.transforms(self.X[idx])
else:
X = self.X[idx]

# ensure that X is a tensor
if not isinstance(X, torch.Tensor):
X = torch.tensor(X, dtype=torch.float32)

y = torch.tensor(self.y[idx], dtype=torch.long)
return X, y

@@ -88,7 +85,7 @@ def __getitem__(self, idx):
# # If preprocessing functions were passed as parameters, execute them
# if not preprocessed and preproc_fn_dict is not None:
# self.preprocess()

# def __len__(self):
# return len(self.data_dict[self.default_output])

@@ -268,9 +265,12 @@ class TabularDataset(Dataset):

def __init__(self, X, y):
"""Initializes instance of class TabularDataset.
Args:
----
X (str): features
y (str): target
"""
super().__init__(
data_dict={"X": X, "y": y},
@@ -309,68 +309,67 @@ def get_dataset(dataset_name: str, data_dir: str):
with open(f"{path}.pkl", "rb") as file:
all_data = pickle.load(file)
print(f"Load data from {path}.pkl")
else:
if "adult" in dataset_name:
column_names = [
"age",
"workclass",
"fnlwgt",
"education",
"education-num",
"marital-status",
"occupation",
"relationship",
"race",
"sex",
"capital-gain",
"capital-loss",
"hours-per-week",
"native-country",
"income",
]
df_train = pd.read_csv(f"{path}/{dataset_name}.data", names=column_names)
df_test = pd.read_csv(
f"{path}/{dataset_name}.test", names=column_names, header=0
)
df_test["income"] = df_test["income"].str.replace(".", "", regex=False)
df = pd.concat([df_train, df_test], axis=0)
df = df.replace(" ?", np.nan)
df = df.dropna()
X, y = df.iloc[:, :-1], df.iloc[:, -1]

categorical_features = [col for col in X.columns if X[col].dtype == "object"]
numerical_features = [
col for col in X.columns if X[col].dtype in ["int64", "float64"]
]

onehot_encoder = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
X_categorical = onehot_encoder.fit_transform(X[categorical_features])

scaler = StandardScaler()
X_numerical = scaler.fit_transform(X[numerical_features])

X = np.hstack([X_numerical, X_categorical])

# label encode the target variable to have the classes 0 and 1
y = LabelEncoder().fit_transform(y)

all_data = GeneralDataset(X,y)
with open(f"{path}.pkl", "wb") as file:
pickle.dump(all_data, file)
print(f"Save data to {path}.pkl")
elif "cifar10" in dataset_name:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=False,download=True, transform=transform)
X = np.vstack([trainset.data, testset.data])
y = np.hstack([trainset.targets, testset.targets])

all_data = GeneralDataset(X, y, transform)

with open(f"{path}.pkl", "wb") as file:
pickle.dump(all_data, file)
print(f"Save data to {path}.pkl")

elif "adult" in dataset_name:
column_names = [
"age",
"workclass",
"fnlwgt",
"education",
"education-num",
"marital-status",
"occupation",
"relationship",
"race",
"sex",
"capital-gain",
"capital-loss",
"hours-per-week",
"native-country",
"income",
]
df_train = pd.read_csv(f"{path}/{dataset_name}.data", names=column_names)
df_test = pd.read_csv(
f"{path}/{dataset_name}.test", names=column_names, header=0
)
df_test["income"] = df_test["income"].str.replace(".", "", regex=False)
df = pd.concat([df_train, df_test], axis=0)
df = df.replace(" ?", np.nan)
df = df.dropna()
X, y = df.iloc[:, :-1], df.iloc[:, -1]

categorical_features = [col for col in X.columns if X[col].dtype == "object"]
numerical_features = [
col for col in X.columns if X[col].dtype in ["int64", "float64"]
]

onehot_encoder = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
X_categorical = onehot_encoder.fit_transform(X[categorical_features])

scaler = StandardScaler()
X_numerical = scaler.fit_transform(X[numerical_features])

X = np.hstack([X_numerical, X_categorical])

# label encode the target variable to have the classes 0 and 1
y = LabelEncoder().fit_transform(y)

all_data = GeneralDataset(X,y)
with open(f"{path}.pkl", "wb") as file:
pickle.dump(all_data, file)
print(f"Save data to {path}.pkl")
elif "cifar10" in dataset_name:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root="./data/cifar10", train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root="./data/cifar10", train=False,download=True, transform=transform)
X = np.vstack([trainset.data, testset.data])
y = np.hstack([trainset.targets, testset.targets])

all_data = GeneralDataset(X, y, transform)

with open(f"{path}.pkl", "wb") as file:
pickle.dump(all_data, file)
print(f"Save data to {path}.pkl")

return all_data


@@ -380,16 +379,19 @@ def get_split(
"""Select points based on the splitting methods
Args:
----
all_index (list): All the possible dataset index list
used_index (list): Index list of used points
size (int): Size of the points needs to be selected
split_method (str): Splitting (selection) method
Raises:
------
NotImplementedError: If the splitting the methods isn't implemented
ValueError: If there aren't enough points to select
Returns:
np.ndarray: List of index
"""
if split_method in "no_overlapping":
selected_index = np.setdiff1d(all_index, used_index, assume_unique=True)
@@ -414,14 +416,16 @@ def prepare_train_test_datasets(dataset_size: int, configs: dict):
"""Prepare the dataset for training the target models when the training data are sampled uniformly from the distribution (pool of all possible data).
Args:
----
dataset_size (int): Size of the whole dataset
num_datasets (int): Number of datasets we should generate
configs (dict): Data split configuration
Returns:
-------
dict: Data split information which saves the information of training points index and test points index for all target models.
"""
"""
# The index_list will save all the information about the train, test and auit for each target model.
all_index = np.arange(dataset_size)
train_size = int(configs["f_train"] * dataset_size)
@@ -437,8 +441,10 @@ def get_dataset_subset(dataset: Dataset, indices: List[int]):
"""Get a subset of the dataset.
Args:
----
dataset (torchvision.datasets): Whole dataset.
index (list): List of index.
"""
assert max(indices) < len(dataset) and min(indices) >= 0, "Index out of range"

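A hedged sketch of how the reorganized dataset.py is exercised end to end (illustrative only, not part of the commit): get_dataset caches the assembled GeneralDataset to {path}.pkl, GeneralDataset behaves as a plain torch Dataset returning (float32, long) tensors, and prepare_train_test_datasets splits indices by the f_train fraction from the config. The data_dir value and any config keys beyond f_train are assumptions.

    import numpy as np
    from torch.utils.data import DataLoader

    from leakpro.dataset import GeneralDataset, get_dataset, prepare_train_test_datasets

    # Build or load the pickled dataset; "adult" and "cifar10" are the two
    # branches handled in get_dataset above. The data_dir value is assumed.
    all_data = get_dataset("cifar10", data_dir="./data")

    # GeneralDataset also works standalone on in-memory arrays.
    X = np.random.rand(8, 5).astype(np.float32)
    y = np.random.randint(0, 2, size=8)
    toy = GeneralDataset(X, y)
    x0, y0 = toy[0]                    # tensors: float32 features, long label

    # Index split for target-model training; only f_train appears in the diff,
    # any further keys are assumptions.
    split_info = prepare_train_test_datasets(len(all_data), configs={"f_train": 0.5, "f_test": 0.3})

    loader = DataLoader(toy, batch_size=4, shuffle=True)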
22 changes: 10 additions & 12 deletions leakpro/metrics/attack_result.py
@@ -15,8 +15,7 @@


class AttackResult:
"""
Contains results related to the performance of the attack.
"""Contains results related to the performance of the attack.
"""

def __init__(
@@ -27,17 +26,18 @@ def __init__(
signal_values=None,
threshold: float = None,
):
"""
Constructor.
"""Constructor.
Computes and stores the accuracy, ROC AUC score, and the confusion matrix for a metric.
Args:
----
metric_id: ID of the metric that was used (c.f. the report_files/explanations.json file).
predicted_labels: Membership predictions of the metric.
true_labels: True membership labels used to evaluate the metric.
predictions_proba: Continuous version of the predicted_labels.
signal_values: Values of the signal used by the metric.
threshold: Threshold computed by the metric.
"""
self.predicted_labels = predicted_labels
self.true_labels = true_labels
@@ -62,8 +62,7 @@ def __init__(
).ravel()

def __str__(self):
"""
Returns a string describing the metric result.
"""Returns a string describing the metric result.
"""
txt = [
f'{" METRIC RESULT OBJECT ":=^48}',
@@ -76,8 +75,7 @@ def __str__(self):


class CombinedMetricResult:
"""
Contains results related to the performance of the metric. It contains the results for multiple fpr.
"""Contains results related to the performance of the metric. It contains the results for multiple fpr.
"""

def __init__(
@@ -88,16 +86,17 @@ def __init__(
signal_values=None,
threshold: list = None,
):
"""
Constructor.
"""Constructor.
Computes and stores the accuracy, ROC AUC score, and the confusion matrix for a metric.
Args:
----
predicted_labels: Membership predictions of the metric.
true_labels: True membership labels used to evaluate the metric.
predictions_proba: Continuous version of the predicted_labels.
signal_values: Values of the signal used by the metric.
threshold: Threshold computed by the metric.
"""
self.predicted_labels = predicted_labels
self.true_labels = true_labels
@@ -120,8 +119,7 @@ def __init__(
)

def __str__(self):
"""
Returns a string describing the metric result.
"""Returns a string describing the metric result.
"""
txt_list = []
for idx in range(len(self.accuracy)):
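As a hedged illustration (not part of the commit) of how these result objects are constructed, following the keyword names listed in the docstrings above; the per-threshold metrics are computed in the constructor as in the code, and the toy arrays below are hypothetical.

    import numpy as np

    from leakpro.metrics.attack_result import CombinedMetricResult

    # Hypothetical toy data: membership predictions at two thresholds (rows)
    # for six audited points (columns), plus the true membership labels.
    predicted_labels = np.array([[1, 1, 0, 1, 0, 0],
                                 [1, 0, 0, 1, 0, 0]])
    true_labels = np.array([1, 1, 0, 1, 0, 0])

    result = CombinedMetricResult(
        predicted_labels=predicted_labels,
        true_labels=true_labels,
        predictions_proba=None,
        signal_values=None,
        threshold=None,
    )
    print(result)   # __str__ formats accuracy and related metrics per entry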
3 changes: 1 addition & 2 deletions leakpro/mia_attacks/attack_factory.py
@@ -1,8 +1,7 @@
from leakpro.mia_attacks.attack_utils import AttackUtils
from leakpro.mia_attacks.attacks.attack_p import AttackP
from leakpro.mia_attacks.attacks.rmia import AttackRMIA

from leakpro.mia_attacks.attack_utils import AttackUtils


class AttackFactory:
attack_classes = {
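For context, a heavily hedged sketch (not in the diff) of what the import reorder above feeds into: AttackFactory keeps a name-to-class registry in attack_classes, so attacks such as AttackP and AttackRMIA are looked up by key. The registry keys and any construction arguments are assumptions, since this hunk only shows the imports.

    # Hedged sketch; key names and constructor arguments are assumptions.
    from leakpro.mia_attacks.attack_factory import AttackFactory

    attack_cls = AttackFactory.attack_classes["attack_p"]   # assumed key for AttackP
    print(AttackFactory.attack_classes.keys())              # inspect the registered attacks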
(Diffs for the remaining 12 changed files were not loaded on this page.)
