Skip to content

Commit

Permalink
Add pill to calls to client
Browse files Browse the repository at this point in the history
  • Loading branch information
JMGaljaard committed May 25, 2021
1 parent 81148a8 commit c33926a
Show file tree
Hide file tree
Showing 14 changed files with 187 additions and 80 deletions.
4 changes: 2 additions & 2 deletions configs/experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ tensor_board_active: true
clients_per_round: 1
system:
federator:
hostname: '131.180.40.72'
nic: 'wlo1'
hostname: '10.5.0.2'
nic: 'eth0'
clients:
amount: 1
# For a simple config is provided in configs/poison.example.yaml
Expand Down
7 changes: 4 additions & 3 deletions configs/local_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ cuda: false
experiment_prefix: 'experiment_single_machine'
output_location: 'output'
tensor_board_active: true
clients_per_round: 1
system:
federator:
hostname: '172.18.0.2'
Expand All @@ -24,10 +23,12 @@ system:
amount: 1
# For a simple config is provided in configs/poison.example.yaml
poison:
seed: 420
experiments:
min_ratio: 0.0
max_ratio: 0.25
steps: 2
attack:
scenario: "fraction"
type: "LABEL_FLIP"
type: "flip"
# Flip 9 and 0
config: { 0:9, 9:0 }
31 changes: 25 additions & 6 deletions docker-compose-adapted.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,40 @@ services:
command: python3 -m fltk single configs/local_experiment.yaml --rank=1
volumes:
- ./configs:/opt/federation-lab/configs
- ./docker_data:/opt/federation-lab/data
# - ./docker_data:/opt/federation-lab/data
- ./default_models:/opt/federation-lab/default_models
- ./data_loaders:/opt/federation-lab/data_loaders
# - ./data_loaders:/opt/federation-lab/data_loaders
environment:
- PYTHONUNBUFFERED=1
- RANK=1
- WORLD_SIZE=2
- GLOO_SOCKET_IFNAME=eth0
- TP_SOCKET_IFNAME=eth0
ports:
- 5001:5000
depends_on:
- fl_server
networks:
- local_network_dev
deploy:
resources:
limits:
cpus: '0.5'
memory: 1024M

client_slow_2:
restart: 'no'
build: .
command: python3 -m fltk single configs/local_experiment.yaml --rank=2
volumes:
- ./configs:/opt/federation-lab/configs
# - ./docker_data:/opt/federation-lab/data
- ./default_models:/opt/federation-lab/default_models
# - ./data_loaders:/opt/federation-lab/data_loaders
environment:
- PYTHONUNBUFFERED=1
- RANK=1
- WORLD_SIZE=2
- GLOO_SOCKET_IFNAME=eth0
- TP_SOCKET_IFNAME=eth0
depends_on:
- fl_server
deploy:
resources:
limits:
Expand Down
13 changes: 9 additions & 4 deletions fltk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import torch.multiprocessing as mp
from fltk.federator import Federator
from fltk.launch import run_single, run_spawn
from fltk.strategy.attack import create_attack
from fltk.util.base_config import BareConfig
from fltk.util.poison.poisonpill import FlipPill

logging.basicConfig(level=logging.DEBUG)

Expand Down Expand Up @@ -76,7 +78,7 @@ def main():
yaml_data = yaml.load(config_file, Loader=yaml.FullLoader)
cfg.merge_yaml(yaml_data)
if args.mode == 'poison':
perform_poison_experiment(args, cfg, yaml_data)
perform_poison_experiment(args, cfg, parser, yaml_data)
elif args.mode == 'single':
perform_single_experiment(args, cfg, parser, yaml_data)
else:
Expand All @@ -102,7 +104,9 @@ def perform_single_experiment(args, cfg, parser, yaml_data):
run_single(rank=args.rank, world_size=world_size, host=master_address, args=cfg, nic=nic)


def perform_poison_experiment(args, cfg, yaml_data):


def perform_poison_experiment(args, cfg, parser, yaml_data):
"""
Function to start poisoned experiment.
"""
Expand All @@ -117,16 +121,17 @@ def perform_poison_experiment(args, cfg, yaml_data):
master_address = args.host
nic = args.nic

attack = create_attack(cfg)
if not world_size:
world_size = yaml_data['system']['clients']['amount'] + 1
if not master_address:
master_address = yaml_data['system']['federator']['hostname']
if not nic:
nic = yaml_data['system']['federator']['nic']
print(f'rank={args.rank}, world_size={world_size}, host={master_address}, args=cfg, nic={nic}')
run_single(rank=args.rank, world_size=world_size, host=master_address, args=cfg, nic=nic)
run_single(rank=args.rank, world_size=world_size, host=master_address, args=cfg, nic=nic, attack=attack)


if __name__ == "__main__":
load_dotenv(Path("./"))
load_dotenv()
main()
4 changes: 2 additions & 2 deletions fltk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ def set_configuration(self, config: str):
def init(self):
pass

def init_dataloader(self, ):
def init_dataloader(self, pill: PoisonPill = None):
self.args.distributed = True
self.args.rank = self.rank

self.args.world_size = self.world_size

try:
self.dataset = self.args.DistDatasets[self.args.dataset_name](self.args)
self.dataset = self.args.DistDatasets[self.args.dataset_name](self.args, pill)
except Exception as e:
tb = traceback.format_exc()
print(tb)
Expand Down
15 changes: 12 additions & 3 deletions fltk/datasets/distributed/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@
from fltk.strategy.data_samplers import get_sampler
import logging

from fltk.util.poison.poisonpill import PoisonPill


class DistCIFAR10Dataset(DistDataset):

def __init__(self, args):
super(DistCIFAR10Dataset, self).__init__(args)
def __init__(self, args, pill: PoisonPill = None):
super(DistCIFAR10Dataset, self).__init__(args, pill)
self.init_train_dataset()
self.init_test_dataset()
if pill:
self.ingest_pill(pill)

def init_train_dataset(self):
def init_train_dataset(self, pill: PoisonPill = None):
dist_loader_text = "distributed" if self.args.get_distributed() else ""
self.get_args().get_logger().debug(f"Loading '{dist_loader_text}' CIFAR10 train data")
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Expand All @@ -26,6 +30,10 @@ def init_train_dataset(self):
])
self.train_dataset = datasets.CIFAR10(root=self.get_args().get_data_path(), train=True, download=True,
transform=transform)

# Poison
if pill:
self.ingest_pill(pill)
self.train_sampler = get_sampler(self.train_dataset, self.args)
self.train_loader = DataLoader(self.train_dataset, batch_size=16, sampler=self.train_sampler)
logging.info("this client gets {} samples".format(len(self.train_sampler)))
Expand All @@ -42,3 +50,4 @@ def init_test_dataset(self):
transform=transform)
self.test_sampler = get_sampler(self.test_dataset, self.args)
self.test_loader = DataLoader(self.test_dataset, batch_size=16, sampler=self.test_sampler)

15 changes: 14 additions & 1 deletion fltk/datasets/distributed/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy

from fltk.util.arguments import Arguments
from fltk.util.poison.poisonpill import PoisonPill


class DistDataset:
Expand All @@ -15,8 +16,9 @@ class DistDataset:
test_dataset = None
train_loader = None
test_loader = None
def __init__(self, args: Arguments):
def __init__(self, args: Arguments, pill: PoisonPill = None):
self.args = args
self.pill = pill
# self.train_dataset = self.load_train_dataset()
# self.test_dataset = self.load_test_dataset()

Expand Down Expand Up @@ -137,3 +139,14 @@ def init_test_dataset(self):
# :return: tuple
# """
# return (next(iter(data_loader))[0].numpy(), next(iter(data_loader))[1].numpy())
def ingest_pill(self, pill: PoisonPill):
"""
Drink the CoolAid, apply poison to the input regarding the pill. Note that the pill may implement a noop,
meaning that this has no real result.
@param pill:
@type pill:
@return:
@rtype:
"""
self.train_dataset.targets = pill.poison_targets(self.train_dataset.targets)
self.test_dataset.targets = pill.poison_targets(self.test_dataset)
31 changes: 20 additions & 11 deletions fltk/federator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class Federator(object):
client_data = {}
poisoned_workers = {}

def __init__(self, client_id_triple, num_epochs=3, config=None):
def __init__(self, client_id_triple, num_epochs=3, config=None, attack=Attack):
log_rref = rpc.RRef(FLLogger())
self.log_rref = log_rref
self.num_epoch = num_epochs
Expand All @@ -102,6 +102,10 @@ def __init__(self, client_id_triple, num_epochs=3, config=None):
self.test_data.init_dataloader()
config.data_sampler = copy_sampler

# Poisoning
self.attack = attack


def create_clients(self, client_id_triple):
for id, rank, world_size in client_id_triple:
client = rpc.remote(id, Client, kwargs=dict(id=id, log_rref=self.log_rref, rank=rank, world_size=world_size,
Expand Down Expand Up @@ -133,9 +137,9 @@ def rpc_test_all(self):
while not res.done():
pass

def client_load_data(self):
def client_load_data(self, poison_pill):
for client in self.clients:
_remote_method_async(Client.init_dataloader, client.ref)
_remote_method_async(Client.init_dataloader, client.ref, pill=None if poison_pill and client not in self.poisoned_workers else poison_pill)

def clients_ready(self):
all_ready = False
Expand Down Expand Up @@ -245,22 +249,26 @@ def save_epoch_data(self):
def ensure_path_exists(self, path):
Path(path).mkdir(parents=True, exist_ok=True)

def run(self):
def run(self, attack: Attack = None):
"""
Main loop of the Federator
:return:
"""
# # Make sure the clients have loaded all the data
self.client_load_data()
self.ping_all()
self.clients_ready()
self.update_client_data_sizes()

# # Select clients which will be poisened
# TODO: get attack type and ratio from config, temp solution now
ratio = 0.2
attack = LabelFlipAttack(10, ratio, {1: 2})
self.poisoned_workers = attack.select_poisoned_workers(self.clients, ratio)
poison_pill = None
if attack:
self.poisoned_workers = attack.select_poisoned_workers(self.clients, ratio)
poison_pill = self.attack.get_poison_pill()

self.client_load_data(poison_pill)
self.ping_all()
self.clients_ready()
self.update_client_data_sizes()



epoch_to_run = self.num_epoch
addition = 0
Expand All @@ -276,3 +284,4 @@ def run(self):
logging.info(f'Saving data')
self.save_epoch_data()
logging.info(f'Federator is stopping')

11 changes: 6 additions & 5 deletions fltk/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@

import torch.multiprocessing as mp
from fltk.federator import Federator
from fltk.strategy.attack import Attack
from fltk.util.base_config import BareConfig
from fltk.util.env.learner_environment import prepare_environment

logging.basicConfig(level=logging.DEBUG)


def run_ps(rpc_ids_triple, args):
def run_ps(rpc_ids_triple, args, attack=Attack):
print(f'Starting the federator...')
fed = Federator(rpc_ids_triple, config=args)
fed.run()
fed = Federator(rpc_ids_triple, config=args, attack=Attack)
fed.run(attack)

def run_single(rank, world_size, host = None, args = None, nic = None):
def run_single(rank, world_size, host = None, args = None, nic = None, attack=None):
logging.info(f'Starting with rank={rank} and world size={world_size}')
prepare_environment(host, nic)

Expand Down Expand Up @@ -48,7 +49,7 @@ def run_single(rank, world_size, host = None, args = None, nic = None):
rpc_backend_options=options

)
run_ps([(f"client{r}", r, world_size) for r in range(1, world_size)], args)
run_ps([(f"client{r}", r, world_size) for r in range(1, world_size)], args, attack)
# block until all rpc finish
rpc.shutdown()

Expand Down
46 changes: 39 additions & 7 deletions fltk/strategy/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from numpy import random

from fltk.util.base_config import BareConfig
from fltk.util.poison.poisonpill import FlipPill, PoisonPill


Expand Down Expand Up @@ -34,7 +35,7 @@ def build_attack(self):
pass

@abstractmethod
def get_poison_pill(self, *ars, **kwargs) -> PoisonPill:
def get_poison_pill(self, *args, **kwargs) -> PoisonPill:
pass


Expand All @@ -49,14 +50,29 @@ def build_attack(self, flip_description=None) -> PoisonPill:
flip_description = {0: 9, 9: 0}
return FlipPill(flip_description=flip_description)

def __init__(self, max_rounds: int, ratio: float, label_shuffle: Dict, seed: int = 42, random=False):
def __init__(self, max_rounds: int = 0, ratio: float = 0, label_shuffle: Dict = None, seed: int = 42, random=False,
cfg: dict = None):
"""
@param max_rounds:
@type max_rounds: int
@param ratio:
@type ratio: float
@param label_shuffle:
@type label_shuffle: dict
@param seed:
@type seed: int
@param random:
@type random: bool
@param cfg:
@type cfg: dict
"""
if 0 > ratio > 1:
self.logger.log(ERROR, f'Cannot run with a ratio of {ratio}, needs to be in range [0, 1]')
raise Exception("ratio is out of bounds")
Attack.__init__(self, max_rounds, seed)
if cfg is None:
if 0 > ratio > 1:
self.logger.log(ERROR, f'Cannot run with a ratio of {ratio}, needs to be in range [0, 1]')
raise Exception("ratio is out of bounds")
Attack.__init__(self, cfg.get('total_epochs', 0), cfg.get('poison', None).get('seed', None))
else:
Attack.__init__(self, max_rounds, seed)
self.ratio = ratio
self.label_shuffle = label_shuffle
self.random = random
Expand All @@ -72,3 +88,19 @@ def select_poisoned_workers(self, workers: List, ratio: float):

def get_poison_pill(self):
return FlipPill(self.label_shuffle)


def create_attack(cfg: BareConfig) -> Attack:
"""
Function to create Poison attack based on the configuration that was passed during execution.
Exception gets thrown when the configuration file is not correct.
"""
assert not cfg is None and not cfg.poison is None
attack_mapper = {'flip': LabelFlipAttack}

attack_class = attack_mapper.get(cfg.get_attack_type(), None)

if not attack_class is None:
attack = attack_class(cfg=cfg.get_attack_config())
else:
raise Exception("Requested attack is not supported...")
Loading

0 comments on commit c33926a

Please sign in to comment.