Skip to content

Commit

Permalink
Resolve issues with label flipping
Browse files Browse the repository at this point in the history
  • Loading branch information
JMGaljaard committed May 25, 2021
1 parent 4e0b1cc commit 937622f
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 50 deletions.
9 changes: 4 additions & 5 deletions configs/local_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ system:
# For a simple config is provided in configs/poison.example.yaml
poison:
seed: 420
experiments:
min_ratio: 0.0
max_ratio: 0.25
steps: 2
ratio: 0.5
attack:
type: "flip"
config: [ { 0:9, 9:0 } ]
config:
- 0: 9
- 9: 0
42 changes: 14 additions & 28 deletions fltk/federator.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,19 @@
import abc
import datetime
import logging
import time
from typing import List, Any, Callable
from pathlib import Path
from typing import List, Callable

from dataclass_csv import DataclassWriter
from torch.distributed import rpc
from torch.utils.tensorboard import SummaryWriter

from fltk.client import Client
from fltk.datasets.data_distribution import distribute_batches_equally
from fltk.strategy.attack import Attack, LabelFlipAttack
from fltk.strategy.attack import Attack
from fltk.strategy.client_selection import random_selection
from fltk.util.arguments import Arguments
from fltk.util.base_config import BareConfig
from fltk.util.data_loader_utils import load_train_data_loader, load_test_data_loader, \
generate_data_loaders_from_distributed_dataset
from fltk.util.fed_avg import average_nn_parameters
from fltk.util.log import FLLogger
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import logging

from fltk.util.poison.poisonpill import FlipPill
from fltk.util.results import EpochData
from fltk.util.tensor_converter import convert_distributed_data_into_numpy

logging.basicConfig(level=logging.DEBUG)

Expand Down Expand Up @@ -84,7 +74,7 @@ class Federator(object):
client_data = {}
poisoned_workers = {}

def __init__(self, client_id_triple, num_epochs=3, config=None, attack=Attack):
def __init__(self, client_id_triple, num_epochs=3, config: BareConfig = None, attack: Attack = None):
log_rref = rpc.RRef(FLLogger())
self.log_rref = log_rref
self.num_epoch = num_epochs
Expand All @@ -105,7 +95,6 @@ def __init__(self, client_id_triple, num_epochs=3, config=None, attack=Attack):
# Poisoning
self.attack = attack


def create_clients(self, client_id_triple):
for id, rank, world_size in client_id_triple:
client = rpc.remote(id, Client, kwargs=dict(id=id, log_rref=self.log_rref, rank=rank, world_size=world_size,
Expand Down Expand Up @@ -139,7 +128,8 @@ def rpc_test_all(self):

def client_load_data(self, poison_pill):
for client in self.clients:
_remote_method_async(Client.init_dataloader, client.ref, pill=None if poison_pill and client not in self.poisoned_workers else poison_pill)
_remote_method_async(Client.init_dataloader, client.ref,
pill=None if poison_pill and client not in self.poisoned_workers else poison_pill)

def clients_ready(self):
all_ready = False
Expand All @@ -162,7 +152,7 @@ def clients_ready(self):
time.sleep(2)
logging.info('All clients are ready')

def remote_run_epoch(self, epochs, attack: Attack):
def remote_run_epoch(self, epochs):
responses = []
client_weights = []
selected_clients = self.select_clients(self.config.clients_per_round)
Expand All @@ -174,7 +164,7 @@ def remote_run_epoch(self, epochs, attack: Attack):
"""
pill = None
if client in self.poisoned_workers:
pill = attack.get_poison_pill()
pill = self.attack.get_poison_pill()
responses.append((client, _remote_method_async(Client.run_epochs, client.ref, num_epoch=epochs, pill=pill)))
self.epoch_counter += epochs
for res in responses:
Expand Down Expand Up @@ -249,39 +239,35 @@ def save_epoch_data(self):
def ensure_path_exists(self, path):
Path(path).mkdir(parents=True, exist_ok=True)

def run(self, attack: Attack = None):
def run(self):
"""
Main loop of the Federator
:return:
"""

# # Select clients which will be poisened
# TODO: get attack type and ratio from config, temp solution now
ratio = 0.2
poison_pill = None
if attack:
self.poisoned_workers = attack.select_poisoned_workers(self.clients, ratio)
if self.attack:
self.poisoned_workers = self.attack.select_poisoned_workers(self.clients)
poison_pill = self.attack.get_poison_pill()

self.client_load_data(poison_pill)
self.ping_all()
self.clients_ready()
self.update_client_data_sizes()



epoch_to_run = self.num_epoch
addition = 0
epoch_to_run = self.config.epochs
epoch_size = self.config.epochs_per_cycle
for epoch in range(epoch_to_run):
print(f'Running epoch {epoch}')
self.remote_run_epoch(epoch_size, attack)
self.remote_run_epoch(epoch_size)
addition += 1
logging.info('Printing client data')
print(self.client_data)

logging.info(f'Saving data')
self.save_epoch_data()
logging.info(f'Federator is stopping')

6 changes: 3 additions & 3 deletions fltk/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
logging.basicConfig(level=logging.DEBUG)


def run_ps(rpc_ids_triple, args, attack=Attack):
def run_ps(rpc_ids_triple, args, attack: Attack = None):
print(f'Starting the federator...')
fed = Federator(rpc_ids_triple, config=args, attack=Attack)
fed.run(attack)
fed = Federator(rpc_ids_triple, config=args, attack=attack)
fed.run()

def run_single(rank, world_size, host = None, args = None, nic = None, attack=None):
logging.info(f'Starting with rank={rank} and world size={world_size}')
Expand Down
25 changes: 14 additions & 11 deletions fltk/strategy/attack.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
from abc import abstractmethod, ABC
from logging import ERROR, WARNING, INFO
from math import floor
from math import floor, ceil
from typing import List, Dict

from numpy import random

from collections import ChainMap
from fltk.util.base_config import BareConfig
from fltk.util.poison.poisonpill import FlipPill, PoisonPill

Expand All @@ -27,15 +27,15 @@ def advance_round(self):
self.logger.log(WARNING, f'Advancing outside of preset number of rounds {self.round} / {self.max_rounds}')

@abstractmethod
def select_poisoned_workers(self, workers: List, ratio: float) -> List:
def select_poisoned_workers(self, workers: List, ratio: float = None) -> List:
pass

@abstractmethod
def build_attack(self):
pass

@abstractmethod
def get_poison_pill(self, *args, **kwargs) -> PoisonPill:
def get_poison_pill(self) -> PoisonPill:
pass


Expand All @@ -51,7 +51,7 @@ def build_attack(self, flip_description=None) -> PoisonPill:
return FlipPill(flip_description=flip_description)

def __init__(self, max_rounds: int = 0, ratio: float = 0, label_shuffle: Dict = None, seed: int = 42, random=False,
cfg: dict = None):
cfg: BareConfig = None):
"""
@param max_rounds:
@type max_rounds: int
Expand All @@ -70,21 +70,23 @@ def __init__(self, max_rounds: int = 0, ratio: float = 0, label_shuffle: Dict =
if 0 > ratio > 1:
self.logger.log(ERROR, f'Cannot run with a ratio of {ratio}, needs to be in range [0, 1]')
raise Exception("ratio is out of bounds")
Attack.__init__(self, cfg.get('total_epochs', 0), cfg.get('poison', None).get('seed', None))
Attack.__init__(self, cfg.epochs, cfg.get_poison_config().get('seed', None))
else:
Attack.__init__(self, max_rounds, seed)
self.ratio = ratio
self.label_shuffle = label_shuffle
self.ratio = cfg.poison['ratio']
self.label_shuffle = dict(ChainMap(*cfg.get_attack_config()['config']))
self.random = random

def select_poisoned_workers(self, workers: List, ratio: float):
def select_poisoned_workers(self, workers: List, ratio: float = None):
"""
Randomly select workers from a list of workers provided by the Federator.
"""
self.logger.log(INFO, "Selecting workers to gather from")
if not self.random:
random.seed(self.seed)
return random.choice(workers, floor(len(workers) * ratio), replace=False)
cloned_workers = workers.copy()
random.shuffle(cloned_workers)
return cloned_workers[0:ceil(len(workers) * self.ratio)]

def get_poison_pill(self):
return FlipPill(self.label_shuffle)
Expand All @@ -101,6 +103,7 @@ def create_attack(cfg: BareConfig) -> Attack:
attack_class = attack_mapper.get(cfg.get_attack_type(), None)

if not attack_class is None:
attack = attack_class(cfg=cfg.get_attack_config())
attack = attack_class(cfg=cfg)
else:
raise Exception("Requested attack is not supported...")
return attack
2 changes: 1 addition & 1 deletion fltk/util/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def get_attack_type(self) -> str:
return self.poison['attack']['type']

def get_attack_config(self) -> dict:
return self.poison['attack']['config']
return self.poison['attack']


def __str__(self):
Expand Down
5 changes: 3 additions & 2 deletions fltk/util/poison/poisonpill.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@ def poison_targets(self, targets: List[int]) -> List[int]:
return list(map(lambda y: self.flips.get(y, y), targets))

@staticmethod
def check_consistency(flips) -> None:
def check_consistency(flips) -> bool:
for attack in flips.keys():
if flips.get(flips[attack], -1) != attack:
if flips.get(flips.get(attack, -2), -1) != attack:
# -1 because ONE_HOT encoding can never represent a negative number
logging.getLogger().log(ERROR,
f'Cyclic inconsistency, {attack} resolves back to {flips[flips[attack]]}')
raise Exception('Inconsistent flip attack!')
return True

def __init__(self, flip_description: Dict[int, int]):
"""
Expand Down

0 comments on commit 937622f

Please sign in to comment.