-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cf5adb6
commit fad5d85
Showing
5 changed files
with
123 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,6 @@ | ||
# For a simple config is provided in configs/poison.example.yaml | ||
poison: | ||
experiments: | ||
min_ratio: 0.0 | ||
max_ratio: 0.25 | ||
steps: 2 | ||
ratio: 0.2 | ||
attack: | ||
scenario: "fraction" | ||
# TODO: timed, early, late | ||
type: "LABEL_FLIP" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import logging | ||
from abc import abstractmethod, ABC | ||
from logging import ERROR, WARNING, INFO | ||
from math import floor | ||
from typing import List, Dict | ||
|
||
from numpy import random | ||
|
||
|
||
class Attack(ABC): | ||
|
||
def __init__(self, max_rounds: int, seed=42): | ||
self.logger = logging.getLogger() | ||
self.round = 1 | ||
self.max_rounds = max_rounds | ||
self.seed = seed | ||
|
||
def advance_round(self): | ||
""" | ||
Function to advance to the ne | ||
""" | ||
self.round += 1 | ||
if self.round > self.max_rounds: | ||
self.logger.log(WARNING, f'Advancing outside of preset number of rounds {self.round} / {self.max_rounds}') | ||
|
||
@abstractmethod | ||
def select_poisoned_workers(self, workers: List, ratio: float) -> List: | ||
pass | ||
|
||
@abstractmethod | ||
def build_attack(self): | ||
pass | ||
|
||
|
||
class LabelFlipAttack(Attack): | ||
|
||
def __init__(self, max_rounds: int, ratio: float, label_shuffle: Dict, seed: int = 42, random=False): | ||
""" | ||
""" | ||
if 0 > ratio > 1: | ||
self.logger.log(ERROR, f'Cannot run with a ratio of {ratio}, needs to be in range [0, 1]') | ||
raise Exception("ratio is out of bounds") | ||
Attack.__init__(self, max_rounds, seed) | ||
self.ratio = ratio | ||
self.random = random | ||
|
||
def select_poisoned_workers(self, workers: List, ratio: float): | ||
""" | ||
Randomly select workers from a list of workers provided by the Federator. | ||
""" | ||
self.logger.log(INFO) | ||
if not self.random: | ||
random.seed(self.seed) | ||
return random.choice(workers, floor(len(workers) * ratio), replace=False) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import logging | ||
from abc import abstractmethod, ABC | ||
from typing import Dict | ||
from torch.nn.functional import one_hot | ||
import torch | ||
from logging import WARNING, ERROR | ||
|
||
|
||
class PoisonPill(ABC): | ||
|
||
def __init__(self): | ||
self.logger = logging.getLogger() | ||
|
||
@abstractmethod | ||
def poison_input(self, X: torch.Tensor, *args, **kwargs): | ||
""" | ||
Poison the output according to the corresponding attack. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def poison_output(self, X: torch.Tensor, Y: torch.Tensor, *args, **kwargs): | ||
""" | ||
Poison the output according to the corresponding attack. | ||
""" | ||
pass | ||
|
||
|
||
class FlipPill(PoisonPill): | ||
|
||
|
||
@staticmethod | ||
def check_consistency(flips) -> None: | ||
for attack in flips.keys(): | ||
if flips.get(flips[attack], -1) != attack: | ||
# -1 because ONE_HOT encoding can never represent a negative number | ||
logging.getLogger().log(ERROR, | ||
f'Cyclic inconsistency, {attack} resolves back to {flips[flips[attack]]}') | ||
raise Exception('Inconsistent flip attack!') | ||
|
||
def __init__(self, flip_description: Dict[int, int]): | ||
""" | ||
Implements the flip attack scenario, where one or multiple attacks are implemented | ||
""" | ||
super().__init__() | ||
assert FlipPill.check_consistency(flip_description) | ||
self.flips = flip_description | ||
|
||
def poison_output(self, X: torch.Tensor, Y: torch.Tensor, *args, **kwargs) -> (torch.Tensor, torch.Tensor): | ||
""" | ||
Apply flip attack, assumes a ONE_HOT encoded input (see torch.nn.functional.one_hot). The | ||
""" | ||
if kwargs['classification']: | ||
decoded: torch.Tensor = Y.argmax(-1).cpu() | ||
updated_decoded = decoded.apply_(lambda x: self.flips.get(x, x)).to(Y.device) | ||
new_Y = torch.nn.functional.one_hot(updated_decoded) | ||
else: | ||
self.logger.log(WARNING, f'Label flip attack only support classification') | ||
new_Y = Y | ||
return X, new_Y | ||
|
||
def poison_input(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor: | ||
""" | ||
Flip attack does not change the input during training. | ||
""" | ||
return X |