Commit bdf14dd (1 parent: baf79af)
Showing 4 changed files with 218 additions and 0 deletions.
@@ -0,0 +1,28 @@
"""Module with functions for preparing the dataset for training the target models."""
import torchvision
from torch import as_tensor
from torch.utils.data import DataLoader
from torchvision import transforms

from leakpro.fl_utils.data_utils import get_meanstd


def get_cifar10_dataset(pre_train_batch_size: int = 64, num_workers: int = 2) -> tuple:
    """Get the pre-training loader, client dataset, and normalization statistics for CIFAR10."""
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True,
                                            transform=transforms.ToTensor())
    client_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True,
                                                  transform=transforms.ToTensor())
    # Normalization statistics are computed on the training split.
    data_mean, data_std = get_meanstd(trainset)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_std)])
    # Standard CIFAR-10 augmentation: random crop with padding plus horizontal flip.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transform])
    trainset.transform = transform_train
    client_dataset.transform = transform_train
    # Reshape the statistics to (C, 1, 1) so they broadcast over image tensors.
    data_mean = as_tensor(data_mean)[:, None, None]
    data_std = as_tensor(data_std)[:, None, None]
    pre_train_loader = DataLoader(trainset, batch_size=pre_train_batch_size,
                                  shuffle=False, drop_last=True, num_workers=num_workers)
    return pre_train_loader, client_dataset, data_mean, data_std
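
A minimal usage sketch for the helper above, assuming the module is saved as cifar.py (the example script below imports it under that name). The shapes shown are what the code implies, not library guarantees.

from cifar import get_cifar10_dataset

pre_train_loader, client_dataset, data_mean, data_std = get_cifar10_dataset()
images, labels = next(iter(pre_train_loader))
print(images.shape)     # torch.Size([64, 3, 32, 32]) with the default batch size
print(data_mean.shape)  # torch.Size([3, 1, 1]), ready to broadcast: (img - mean) / std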
@@ -0,0 +1,35 @@
"""Inverting on a single image."""

import os

import torch
from cifar import get_cifar10_dataset
from model import ResNet
from pre_train import pre_train
from torchvision.models.resnet import BasicBlock

from leakpro.fl_utils.gia_train import train
from leakpro.run import run_inverting_audit
from leakpro.utils.logger import logger

if __name__ == "__main__":
    model = ResNet(BasicBlock, [5, 5, 5], num_classes=10, base_width=16 * 10)
    pre_train_loader, client_dataset, data_mean, data_std = get_cifar10_dataset()

    # Pre-training the model is important since the attacks work better on later stages of model training.
    # Pre-training with the augmentation transforms defined in the dataset module also makes the attacks
    # considerably stronger.
    pre_train_epochs = 10
    model_path = f"model_epochs_{pre_train_epochs}.pth"
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path))
        logger.info(f"Model loaded from {model_path}")
    else:
        logger.info("No saved model found. Training from scratch...")
        pre_train(model, pre_train_loader, epochs=pre_train_epochs)
        torch.save(model.state_dict(), model_path)
        logger.info(f"Model trained and saved to {model_path}")

    # Meta train function designed to work with GIA.
    train_fn = train
    # Run the audit with multiple client partitions, different epochs, and total-variation scales.
    result = run_inverting_audit(model, client_dataset, train_fn, data_mean, data_std)
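
For orientation, here is a self-contained sketch of the objective such an inverting attack optimizes: recover an input whose gradient matches the one a client reported, regularized by total variation (the total_variation scale mentioned in the comment above). This follows the cosine-distance formulation of "Inverting Gradients" (Geiping et al., 2020); all names and defaults here are illustrative assumptions, not leakpro's run_inverting_audit internals.

import torch
import torch.nn.functional as F

def total_variation(x: torch.Tensor) -> torch.Tensor:
    """Anisotropic total variation of a batch of images (B, C, H, W)."""
    dh = (x[..., 1:, :] - x[..., :-1, :]).abs().mean()
    dw = (x[..., :, 1:] - x[..., :, :-1]).abs().mean()
    return dh + dw

def invert_gradients(model, target_grads, label, shape, steps=2000, tv_scale=1e-2):
    """Recover an input whose gradient matches target_grads (hypothetical helper)."""
    x = torch.randn(shape, requires_grad=True)  # random initialization in image space
    opt = torch.optim.Adam([x], lr=0.1)
    params = [p for p in model.parameters() if p.requires_grad]
    for _ in range(steps):
        opt.zero_grad()
        loss = F.cross_entropy(model(x), label)
        # create_graph=True so the gradient-matching loss is differentiable w.r.t. x.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        # Cosine distance between the candidate's gradient and the observed gradient.
        num = sum((g * t).sum() for g, t in zip(grads, target_grads))
        denom = (sum(g.pow(2).sum() for g in grads).sqrt()
                 * sum(t.pow(2).sum() for t in target_grads).sqrt())
        rec_loss = 1 - num / denom + tv_scale * total_variation(x)
        rec_loss.backward()
        opt.step()
    return x.detach()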
@@ -0,0 +1,81 @@
"""ResNet model."""
from typing import Optional

import torch
import torchvision
from torch import nn
from torchvision.models.resnet import BasicBlock, Bottleneck

from leakpro.utils.import_helper import Self


class ResNet(torchvision.models.ResNet):
    """ResNet generalization for CIFAR-sized inputs."""

    def __init__(self: Self, block: BasicBlock, layers: list, num_classes: int = 10, zero_init_residual: bool = False,  # noqa: C901
                 groups: int = 1, base_width: int = 64, replace_stride_with_dilation: list = None,
                 norm_layer: Optional[nn.Module] = None, strides: list = [1, 2, 2, 2], pool: str = "avg") -> None:  # noqa: B006
        """Initialize as usual. Layers and strides are scriptable."""
        # Skip torchvision's ResNet.__init__ and initialize nn.Module directly,
        # so the stem and stages can be rebuilt for 32x32 inputs.
        super(torchvision.models.ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.dilation = 1
        if replace_stride_with_dilation is None:
            # Each element in the list indicates whether to replace
            # the 2x2 stride with a dilated convolution instead.
            replace_stride_with_dilation = [False, False, False, False]
        if len(replace_stride_with_dilation) != 4:
            raise ValueError("replace_stride_with_dilation should be None "
                             f"or a 4-element tuple, got {replace_stride_with_dilation}")
        self.groups = groups

        self.inplanes = base_width
        self.base_width = 64  # Do this to circumvent BasicBlock errors. The value is not actually used.
        # CIFAR stem: a 3x3 convolution with stride 1 instead of ImageNet's 7x7 stride-2 conv plus max pool.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.layers = torch.nn.ModuleList()
        width = self.inplanes
        for idx, layer in enumerate(layers):
            self.layers.append(self._make_layer(block, width, layer, stride=strides[idx],
                                                dilate=replace_stride_with_dilation[idx]))
            width *= 2

        self.pool = nn.AdaptiveAvgPool2d((1, 1)) if pool == "avg" else nn.AdaptiveMaxPool2d((1, 1))
        self.fc = nn.Linear(width // 2 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, torchvision.models.resnet.BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _forward_impl(self: Self, x: torch.Tensor) -> torch.Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        for layer in self.layers:
            x = layer(x)

        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
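
A quick shape check for the configuration the example script uses, assuming the class above is saved as model.py (which is how the script imports it):

import torch
from torchvision.models.resnet import BasicBlock
from model import ResNet

# [5, 5, 5] BasicBlocks with base_width=160 is the wide variant from the example script.
model = ResNet(BasicBlock, [5, 5, 5], num_classes=10, base_width=16 * 10)
out = model(torch.randn(2, 3, 32, 32))  # CIFAR-sized input
print(out.shape)  # torch.Size([2, 10])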
@@ -0,0 +1,74 @@
"""Pre-training utilities for the target model."""
from collections import defaultdict

import torch
from torch import nn
from torch.utils.data import DataLoader

from leakpro.utils.import_helper import Self
from leakpro.utils.logger import logger


def pre_train(model: nn.Module, trainloader: DataLoader, epochs: int = 10) -> None:
    """Pre-train a model for a specified number of epochs."""
    loss_fn = Classification()
    setup = {"dtype": torch.float, "device": torch.device("cuda" if torch.cuda.is_available() else "cpu")}
    model.to(setup["device"])
    stats = defaultdict(list)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                                weight_decay=5e-4, nesterov=True)
    # The milestones assume a ~120-epoch schedule (decays near epochs 44, 74, and 105);
    # with the default 10 epochs the learning rate never decays.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[120 // 2.667, 120 // 1.6,
                                                                 120 // 1.142], gamma=0.1)

    for _ in range(epochs):
        logger.info(stats)
        model.train()
        epoch_metric = 0
        for i, (inputs, targets) in enumerate(trainloader):
            optimizer.zero_grad()
            inputs = inputs.to(**setup)
            targets = targets.to(device=setup["device"], non_blocking=False)
            outputs = model(inputs)
            loss, _, _ = loss_fn(outputs, targets)
            if i % 200 == 0:
                logger.info(f"train loss {loss}")
            loss.backward()
            optimizer.step()

            metric, name, _ = loss_fn.metric(outputs, targets)
            epoch_metric += metric.item()
        scheduler.step()

        stats["train_" + name].append(epoch_metric / (len(trainloader) + 1))
    logger.info(stats)


class Classification:
    """A classical NLL loss for classification. Evaluation has the softmax baked in.

    The minimized criterion is cross entropy; the actual metric is total accuracy.
    """

    def __init__(self: Self) -> None:
        """Init with torch cross-entropy loss."""
        self.loss_fn = torch.nn.CrossEntropyLoss(reduction="mean")

    def __call__(self, x=None, y=None):  # noqa: ANN001, ANN101, ANN204
        """Return l(x, y), or the loss's name and format string when called without arguments."""
        name = "CrossEntropy"
        fmt = "1.5f"
        if x is None:
            return name, fmt
        value = self.loss_fn(x, y)
        return value, name, fmt

    def metric(self, x=None, y=None):  # noqa: ANN001, ANN101, ANN201
        """The actually sought metric: top-1 accuracy."""
        name = "Accuracy"
        fmt = "6.2%"
        if x is None:
            return name, fmt
        value = (x.data.argmax(dim=1) == y).sum().float() / y.shape[0]
        return value.detach(), name, fmt
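
A minimal sketch of the Classification helper's calling conventions, using random logits and labels as stand-in data:

import torch

loss_fn = Classification()
logits = torch.randn(8, 10)       # batch of 8, 10 classes
labels = torch.randint(0, 10, (8,))

print(loss_fn())                  # ("CrossEntropy", "1.5f"): name and format only
loss, name, fmt = loss_fn(logits, labels)
acc, name, fmt = loss_fn.metric(logits, labels)
print(f"{name}: {acc.item():{fmt}}")  # e.g. "Accuracy: 12.50%"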