add inverting audit
viktorvaladi committed Oct 21, 2024
1 parent baf79af commit bdf14dd
Showing 4 changed files with 218 additions and 0 deletions.
28 changes: 28 additions & 0 deletions examples/gia/cifar10_inverting_audit/cifar.py
@@ -0,0 +1,28 @@
"""Module with functions for preparing the dataset for training the target models."""
import torchvision
from torch import as_tensor
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from leakpro.fl_utils.data_utils import get_meanstd


def get_cifar10_dataset(pre_train_batch_size: int = 64, num_workers:int = 2) -> Dataset:
"""Get the full dataset for CIFAR10."""
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transforms.ToTensor())
client_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transforms.ToTensor())
data_mean, data_std = get_meanstd(trainset)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(data_mean, data_std)])
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transform])
trainset.transform = transform_train
client_dataset.transform = transform_train
data_mean = as_tensor(data_mean)[:, None, None]
data_std = as_tensor(data_std)[:, None, None]
pre_train_loader = DataLoader(trainset, batch_size=pre_train_batch_size,
shuffle=False, drop_last=True, num_workers=num_workers)
return pre_train_loader, client_dataset, data_mean, data_std
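
A minimal sketch of how these broadcastable statistics could map a normalized tensor back to pixel space; the reconstruction tensor below is a hypothetical placeholder for an attack output, not part of the commit:

import torch

from cifar import get_cifar10_dataset

pre_train_loader, client_dataset, data_mean, data_std = get_cifar10_dataset()
reconstruction = torch.rand(3, 32, 32)  # placeholder for a recovered, normalized image
# Undo (x - mean) / std channel-wise; the (C, 1, 1) shapes broadcast over (C, H, W).
pixels = torch.clamp(reconstruction * data_std + data_mean, 0, 1)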
35 changes: 35 additions & 0 deletions examples/gia/cifar10_inverting_audit/main.py
@@ -0,0 +1,35 @@
"""Inverting on a single image."""

import os

import torch
from cifar import get_cifar10_dataset
from model import ResNet
from pre_train import pre_train
from torchvision.models.resnet import BasicBlock

from leakpro.fl_utils.gia_train import train
from leakpro.run import run_inverting_audit
from leakpro.utils.logger import logger

if __name__ == "__main__":
model = ResNet(BasicBlock, [5, 5, 5], num_classes=10, base_width=16 * 10)
pre_train_loader, client_dataset, data_mean, data_std = get_cifar10_dataset()

# pre training the model is important since the attacks work better on later stages of model training.
# check out the transforms in the dataset, pre-training with those transformations make the attacks a lot stronger.
pre_train_epochs = 10
model_path = "model_epochs_" + str(pre_train_epochs) + ".pth"
if os.path.exists(model_path):
model.load_state_dict(torch.load(model_path))
logger.info(f"Model loaded from {model_path}")
else:
logger.info("No saved model found. Training from scratch...")
pre_train(model, pre_train_loader, epochs=10)
torch.save(model.state_dict(), model_path)
logger.info(f"Model trained and saved to {model_path}")

# meta train function designed to work with GIA
train_fn = train
# run audit with multiple client partitions, different epochs and total_variation scale.
result = run_inverting_audit(model, client_dataset, train_fn, data_mean, data_std)
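
For intuition, inverting attacks of this kind follow the gradient-inversion recipe of Geiping et al. (2020, "Inverting Gradients"): optimize dummy inputs so their gradients match the observed client gradients, regularized by a total-variation prior weighted by the scale the audit sweeps over. The sketch below illustrates that objective only; it is not LeakPro's implementation, and all names in it are local to the example:

import torch
import torch.nn.functional as F

def total_variation(x: torch.Tensor) -> torch.Tensor:
    # Rewards smooth, natural-looking reconstructions.
    dh = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
    dw = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
    return dh + dw

def inversion_loss(model, dummy_images, labels, observed_grads, tv_scale=1e-2):
    # Cosine distance between dummy-input gradients and observed gradients, plus the TV prior.
    task_loss = F.cross_entropy(model(dummy_images), labels)
    grads = torch.autograd.grad(task_loss, model.parameters(), create_graph=True)
    match = sum(1 - F.cosine_similarity(g.flatten(), o.flatten(), dim=0)
                for g, o in zip(grads, observed_grads))
    return match + tv_scale * total_variation(dummy_images)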
81 changes: 81 additions & 0 deletions examples/gia/cifar10_inverting_audit/model.py
@@ -0,0 +1,81 @@
"""ResNet model."""
from typing import Optional

import torch
import torchvision
from torch import nn
from torchvision.models.resnet import BasicBlock, Bottleneck

from leakpro.utils.import_helper import Self


class ResNet(torchvision.models.ResNet):
"""ResNet generalization for CIFAR thingies."""

def __init__(self: Self, block: BasicBlock, layers: list, num_classes: int=10, zero_init_residual: bool=False, # noqa: C901
groups: int=1, base_width: int=64, replace_stride_with_dilation: list=None,
norm_layer: Optional[nn.Module]=None, strides: list=[1, 2, 2, 2], pool: str="avg") -> None: # noqa: B006
"""Initialize as usual. Layers and strides are scriptable."""
super(torchvision.models.ResNet, self).__init__() # nn.Module
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer


self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False, False]
if len(replace_stride_with_dilation) != 4:
raise ValueError("replace_stride_with_dilation should be None "
"or a 4-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups

self.inplanes = base_width
self.base_width = 64 # Do this to circumvent BasicBlock errors. The value is not actually used.
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)

self.layers = torch.nn.ModuleList()
width = self.inplanes
for idx, layer in enumerate(layers):
self.layers.append(self._make_layer(block, width, layer, stride=strides[idx], dilate=replace_stride_with_dilation[idx]))
width *= 2

self.pool = nn.AdaptiveAvgPool2d((1, 1)) if pool == "avg" else nn.AdaptiveMaxPool2d((1, 1))
self.fc = nn.Linear(width // 2 * block.expansion, num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, torchvision.models.resnet.BasicBlock):
nn.init.constant_(m.bn2.weight, 0)


def _forward_impl(self: Self, x: torch.Tensor) -> None:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)

for layer in self.layers:
x = layer(x)

x = self.pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)

return x
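
As a quick sanity check, the constructor call used in main.py produces a CIFAR-sized network whose head matches the ten CIFAR10 classes; a minimal sketch, assuming the imports at the top of model.py:

model = ResNet(BasicBlock, [5, 5, 5], num_classes=10, base_width=16 * 10)
x = torch.randn(2, 3, 32, 32)  # a batch of two CIFAR10-sized images
logits = model(x)
assert logits.shape == (2, 10)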
74 changes: 74 additions & 0 deletions examples/gia/cifar10_inverting_audit/pre_train.py
@@ -0,0 +1,74 @@
"""To be used in the future maybe."""
from collections import defaultdict

import torch
from torch import nn
from torch.utils.data import DataLoader

from leakpro.utils.import_helper import Self
from leakpro.utils.logger import logger


def pre_train(model: nn.Module, trainloader: DataLoader, epochs: int = 10) -> None:
"""Pre train a model for a specified amount of epochs."""
loss_fn = Classification()
setup = {"dtype": torch.float, "device": torch.device("cuda" if torch.cuda.is_available() else "cpu")}
model.to(setup["device"])
stats = defaultdict(list)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
weight_decay=5e-4, nesterov=True)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=[120 // 2.667, 120 // 1.6,
120 // 1.142], gamma=0.1)

for _ in range(epochs):
logger.info(stats)
model.train()
epoch_metric = 0
for i, (inputs, targets) in enumerate(trainloader):
optimizer.zero_grad()
inputs = inputs.to(**setup)
targets = targets.to(device=setup["device"], non_blocking=False)
outputs = model(inputs)
loss, _, _ = loss_fn(outputs, targets)
if i % 200 ==0 :
logger.info(f"train loss {loss}")
loss.backward()
optimizer.step()

metric, name, _ = loss_fn.metric(outputs, targets)
epoch_metric += metric.item()
scheduler.step()

stats["train_" + name].append(epoch_metric / (len(trainloader) + 1))
logger.info(stats)


class Classification():
"""A classical NLL loss for classification. Evaluation has the softmax baked in.
The minimized criterion is cross entropy, the actual metric is total accuracy.
"""

def __init__(self: Self) -> None:
"""Init with torch MSE."""
self.loss_fn = torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction="mean")

def __call__(self, x=None, y=None): # noqa: ANN001, ANN101, ANN204
"""Return l(x, y)."""
name = "CrossEntropy"
format = "1.5f"
if x is None:
return name, format
value = self.loss_fn(x, y)
return value, name, format

def metric(self, x=None, y=None): # noqa: ANN001, ANN101, ANN201
"""The actually sought metric."""
name = "Accuracy"
format = "6.2%"
if x is None:
return name, format
value = (x.data.argmax(dim=1) == y).sum().float() / y.shape[0]
return value.detach(), name, format
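
A minimal usage sketch of the Classification helper, assuming only torch is imported:

loss_fn = Classification()
logits = torch.randn(8, 10)           # predictions for a batch of 8 over 10 classes
targets = torch.randint(0, 10, (8,))  # ground-truth labels
loss, loss_name, loss_fmt = loss_fn(logits, targets)
acc, acc_name, acc_fmt = loss_fn.metric(logits, targets)
print(f"{loss_name}: {loss:{loss_fmt}}, {acc_name}: {acc:{acc_fmt}}")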
