diff --git a/configs/poison.example.yaml b/configs/poison.example.yaml
index 461ef328..c549fbb5 100644
--- a/configs/poison.example.yaml
+++ b/configs/poison.example.yaml
@@ -1,10 +1,6 @@
 # For a simple config is provided in configs/poison.example.yaml
 poison:
-  experiments:
-    min_ratio: 0.0
-    max_ratio: 0.25
-    steps: 2
+  ratio: 0.2
   attack:
     scenario: "fraction"
-    # TODO: timed, early, late
     type: "LABEL_FLIP"
diff --git a/fltk/client.py b/fltk/client.py
index 5de27e45..bde8cb45 100644
--- a/fltk/client.py
+++ b/fltk/client.py
@@ -192,6 +192,7 @@ def train(self, epoch):
         self.dataset.train_sampler.set_epoch(epoch)
 
         for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
+            # TODO: Implement swap based on received attack.
             inputs, labels = inputs.to(self.device), labels.to(self.device)
 
             # zero the parameter gradients
diff --git a/fltk/strategy/attack.py b/fltk/strategy/attack.py
new file mode 100644
index 00000000..cb8d1c6b
--- /dev/null
+++ b/fltk/strategy/attack.py
@@ -0,0 +1,57 @@
+import logging
+from abc import abstractmethod, ABC
+from logging import ERROR, WARNING, INFO
+from math import floor
+from typing import List, Dict
+
+from numpy import random
+
+
+class Attack(ABC):
+
+    def __init__(self, max_rounds: int, seed=42):
+        self.logger = logging.getLogger()
+        self.round = 1
+        self.max_rounds = max_rounds
+        self.seed = seed
+
+    def advance_round(self):
+        """
+        Advance the attack to the next federated learning round.
+        """
+        self.round += 1
+        if self.round > self.max_rounds:
+            self.logger.log(WARNING, f'Advancing outside of preset number of rounds {self.round} / {self.max_rounds}')
+
+    @abstractmethod
+    def select_poisoned_workers(self, workers: List, ratio: float) -> List:
+        pass
+
+    @abstractmethod
+    def build_attack(self):
+        pass
+
+
+class LabelFlipAttack(Attack):
+
+    def __init__(self, max_rounds: int, ratio: float, label_shuffle: Dict, seed: int = 42, random=False):
+        """
+        Label flip attack that poisons a fixed fraction of the workers using the provided label shuffle.
+        """
+        Attack.__init__(self, max_rounds, seed)
+        # Validate the ratio after the base class has set up the logger.
+        if not 0 <= ratio <= 1:
+            self.logger.log(ERROR, f'Cannot run with a ratio of {ratio}, needs to be in range [0, 1]')
+            raise Exception("ratio is out of bounds")
+        self.ratio = ratio
+        self.label_shuffle = label_shuffle
+        self.random = random
+
+    def select_poisoned_workers(self, workers: List, ratio: float):
+        """
+        Randomly select workers from a list of workers provided by the Federator.
+        """
+        self.logger.log(INFO, f'Selecting {floor(len(workers) * ratio)} poisoned workers out of {len(workers)}')
+        if not self.random:
+            random.seed(self.seed)
+        return random.choice(workers, floor(len(workers) * ratio), replace=False)
diff --git a/fltk/util/poison/__init__.py b/fltk/util/poison/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/fltk/util/poison/poisonpill.py b/fltk/util/poison/poisonpill.py
new file mode 100644
index 00000000..a0dc3dc0
--- /dev/null
+++ b/fltk/util/poison/poisonpill.py
@@ -0,0 +1,70 @@
+import logging
+from abc import abstractmethod, ABC
+from typing import Dict
+from torch.nn.functional import one_hot
+import torch
+from logging import WARNING, ERROR
+
+
+class PoisonPill(ABC):
+
+    def __init__(self):
+        self.logger = logging.getLogger()
+
+    @abstractmethod
+    def poison_input(self, X: torch.Tensor, *args, **kwargs):
+        """
+        Poison the input according to the corresponding attack.
+        """
+        pass
+
+    @abstractmethod
+    def poison_output(self, X: torch.Tensor, Y: torch.Tensor, *args, **kwargs):
+        """
+        Poison the output according to the corresponding attack.
+        """
+        pass
+
+
+class FlipPill(PoisonPill):
+
+    @staticmethod
+    def check_consistency(flips: Dict[int, int]) -> None:
+        """
+        Verify that the flip mapping is its own inverse: flipping an already flipped label returns the original.
+        """
+        for attack in flips.keys():
+            if flips.get(flips[attack], -1) != attack:
+                # -1 because ONE_HOT encoding can never represent a negative number
+                logging.getLogger().log(ERROR, f'Cyclic inconsistency, {attack} -> {flips[attack]} does not resolve back to {attack}')
+                raise Exception('Inconsistent flip attack!')
+
+    def __init__(self, flip_description: Dict[int, int]):
+        """
+        Implements the flip attack scenario, where one or multiple label flips are applied to the targets.
+        """
+        super().__init__()
+        # check_consistency raises on an inconsistent mapping, so no assertion on its return value is needed.
+        FlipPill.check_consistency(flip_description)
+        self.flips = flip_description
+
+    def poison_output(self, X: torch.Tensor, Y: torch.Tensor, *args, **kwargs) -> (torch.Tensor, torch.Tensor):
+        """
+        Apply the flip attack; assumes a ONE_HOT encoded target (see torch.nn.functional.one_hot). The
+        decoded labels are remapped according to the flip description and re-encoded with the original
+        number of classes.
+        """
+        if kwargs['classification']:
+            decoded: torch.Tensor = Y.argmax(-1).cpu()
+            updated_decoded = decoded.apply_(lambda x: self.flips.get(x, x)).to(Y.device)
+            # Preserve the original number of classes instead of letting one_hot infer it from the batch.
+            new_Y = one_hot(updated_decoded, num_classes=Y.shape[-1])
+        else:
+            self.logger.log(WARNING, 'Label flip attack only supports classification')
+            new_Y = Y
+        return X, new_Y
+
+    def poison_input(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        """
+        The flip attack does not change the input during training.
+        """
+        return X
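Usage sketch (not part of the patch): a minimal, self-contained example of how FlipPill could be applied to a training batch, roughly where the TODO in Client.train() sits. The flip mapping, batch shapes, and num_classes below are illustrative assumptions, not values taken from the repository.

    import torch
    from torch.nn.functional import one_hot

    from fltk.util.poison.poisonpill import FlipPill

    # Swap labels 0 and 1; the mapping must be its own inverse to pass check_consistency.
    pill = FlipPill(flip_description={0: 1, 1: 0})

    inputs = torch.randn(4, 3, 32, 32)    # dummy image batch (assumption)
    labels = torch.tensor([0, 1, 2, 3])   # dummy class indices (assumption)
    num_classes = 10                      # assumed number of classes

    # poison_input is a no-op for the flip attack; poison_output expects one-hot targets.
    inputs = pill.poison_input(inputs)
    inputs, targets = pill.poison_output(inputs, one_hot(labels, num_classes), classification=True)
    # targets now one-hot encodes [1, 0, 2, 3] over the original 10 classes.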