Commit
Add extendable classes to work with
JMGaljaard committed May 17, 2021
1 parent cf5adb6 commit fad5d85
Showing 5 changed files with 123 additions and 5 deletions.
6 changes: 1 addition & 5 deletions configs/poison.example.yaml
@@ -1,10 +1,6 @@
# A simple example config is provided in configs/poison.example.yaml
poison:
  experiments:
    min_ratio: 0.0
    max_ratio: 0.25
    steps: 2
  ratio: 0.2
  attack:
    scenario: "fraction"
    # TODO: timed, early, late
    type: "LABEL_FLIP"
1 change: 1 addition & 0 deletions fltk/client.py
@@ -192,6 +192,7 @@ def train(self, epoch):
        self.dataset.train_sampler.set_epoch(epoch)

        for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
            # TODO: Implement swap based on received attack.
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            # zero the parameter gradients
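The TODO above could plausibly be filled in with the PoisonPill classes added further down in this commit. A minimal sketch, assuming the client holds a pill (or None) and knows the number of classes; neither assumption is part of this commit.

# Hypothetical sketch for the TODO above: poison a batch with a received PoisonPill.
# `pill` and `num_classes` are assumed to be available on the client; they are not part of this commit.
import torch
from torch.nn.functional import one_hot

from fltk.util.poison.poisonpill import PoisonPill


def maybe_poison_batch(pill: PoisonPill, inputs: torch.Tensor, labels: torch.Tensor, num_classes: int):
    if pill is None:
        return inputs, labels
    inputs = pill.poison_input(inputs)
    inputs, flipped = pill.poison_output(inputs, one_hot(labels, num_classes), classification=True)
    # Return plain class indices so the rest of the training loop stays unchanged.
    return inputs, flipped.argmax(-1)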
55 changes: 55 additions & 0 deletions fltk/strategy/attack.py
@@ -0,0 +1,55 @@
import logging
from abc import abstractmethod, ABC
from logging import ERROR, WARNING, INFO
from math import floor
from typing import List, Dict

from numpy import random


class Attack(ABC):

    def __init__(self, max_rounds: int, seed=42):
        self.logger = logging.getLogger()
        self.round = 1
        self.max_rounds = max_rounds
        self.seed = seed

    def advance_round(self):
        """
        Function to advance to the next round of the attack.
        """
        self.round += 1
        if self.round > self.max_rounds:
            self.logger.log(WARNING, f'Advancing outside of preset number of rounds {self.round} / {self.max_rounds}')

    @abstractmethod
    def select_poisoned_workers(self, workers: List, ratio: float) -> List:
        pass

    @abstractmethod
    def build_attack(self):
        pass


class LabelFlipAttack(Attack):

    def __init__(self, max_rounds: int, ratio: float, label_shuffle: Dict, seed: int = 42, random=False):
        """
        Label flip attack that poisons a fraction (ratio) of the workers.
        """
        Attack.__init__(self, max_rounds, seed)
        if not 0 <= ratio <= 1:
            self.logger.log(ERROR, f'Cannot run with a ratio of {ratio}, needs to be in range [0, 1]')
            raise Exception("ratio is out of bounds")
        self.ratio = ratio
        self.random = random

    def select_poisoned_workers(self, workers: List, ratio: float):
        """
        Randomly select workers from a list of workers provided by the Federator.
        """
        self.logger.log(INFO, f'Selecting {floor(len(workers) * ratio)} poisoned workers')
        if not self.random:
            random.seed(self.seed)
        return random.choice(workers, floor(len(workers) * ratio), replace=False)
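A small usage sketch of the class above; the worker identifiers and parameters are made up for illustration and do not come from this commit.

# Hypothetical usage of LabelFlipAttack; worker ids and the flip mapping are illustrative only.
from fltk.strategy.attack import LabelFlipAttack

attack = LabelFlipAttack(max_rounds=10, ratio=0.2, label_shuffle={0: 9, 9: 0})
poisoned = attack.select_poisoned_workers(['client_0', 'client_1', 'client_2', 'client_3', 'client_4'], 0.2)
# floor(5 * 0.2) == 1, so exactly one client id is selected (reproducibly, because seed=42 is reused).
attack.advance_round()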
Empty file added fltk/util/poison/__init__.py
66 changes: 66 additions & 0 deletions fltk/util/poison/poisonpill.py
@@ -0,0 +1,66 @@
import logging
from abc import abstractmethod, ABC
from typing import Dict
from torch.nn.functional import one_hot
import torch
from logging import WARNING, ERROR


class PoisonPill(ABC):

    def __init__(self):
        self.logger = logging.getLogger()

    @abstractmethod
    def poison_input(self, X: torch.Tensor, *args, **kwargs):
        """
        Poison the input according to the corresponding attack.
        """
        pass

    @abstractmethod
    def poison_output(self, X: torch.Tensor, Y: torch.Tensor, *args, **kwargs):
        """
        Poison the output according to the corresponding attack.
        """
        pass


class FlipPill(PoisonPill):

    @staticmethod
    def check_consistency(flips) -> None:
        for attack in flips.keys():
            if flips.get(flips[attack], -1) != attack:
                # -1 because ONE_HOT encoding can never represent a negative number
                logging.getLogger().log(ERROR,
                                        f'Cyclic inconsistency, {attack} resolves back to {flips[flips[attack]]}')
                raise Exception('Inconsistent flip attack!')

    def __init__(self, flip_description: Dict[int, int]):
        """
        Implements the flip attack scenario, where one or multiple label flips are applied.
        """
        super().__init__()
        # check_consistency returns None and raises on error, so call it directly instead of asserting its result.
        FlipPill.check_consistency(flip_description)
        self.flips = flip_description

    def poison_output(self, X: torch.Tensor, Y: torch.Tensor, *args, **kwargs) -> (torch.Tensor, torch.Tensor):
        """
        Apply the flip attack; assumes a ONE_HOT encoded label tensor Y (see torch.nn.functional.one_hot).
        The input X is returned unchanged.
        """
        if kwargs['classification']:
            decoded: torch.Tensor = Y.argmax(-1).cpu()
            updated_decoded = decoded.apply_(lambda x: self.flips.get(x, x)).to(Y.device)
            # Keep the original number of classes so the shape of Y is preserved.
            new_Y = one_hot(updated_decoded, num_classes=Y.shape[-1])
        else:
            self.logger.log(WARNING, f'Label flip attack only supports classification')
            new_Y = Y
        return X, new_Y

    def poison_input(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Flip attack does not change the input during training.
        """
        return X
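To make the behaviour concrete, a short sketch of FlipPill in isolation; the 0 <-> 9 mapping and the example batch are assumptions for illustration, not part of this commit.

# Hypothetical FlipPill usage; the 0 <-> 9 mapping and the example labels are illustrative only.
import torch
from torch.nn.functional import one_hot

from fltk.util.poison.poisonpill import FlipPill

pill = FlipPill({0: 9, 9: 0})              # consistent mapping: 0 -> 9 and 9 -> 0
inputs = torch.randn(3, 1, 28, 28)         # e.g. a small MNIST-like batch
labels = torch.tensor([0, 3, 9])
inputs, flipped = pill.poison_output(inputs, one_hot(labels, num_classes=10), classification=True)
print(flipped.argmax(-1))                  # tensor([9, 3, 0])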
