diff --git a/fltk/federator.py b/fltk/federator.py index cbd6307f..3d05c5d7 100644 --- a/fltk/federator.py +++ b/fltk/federator.py @@ -2,7 +2,7 @@ import pathlib import time from pathlib import Path -from typing import List, Callable +from typing import List, Callable, Dict import torch from dataclass_csv import DataclassWriter @@ -11,8 +11,6 @@ from fltk.client import Client from fltk.nets.util.utils import flatten_params, save_model -from fltk.strategy.antidote import Antidote -from fltk.strategy.attack import Attack from fltk.util.base_config import BareConfig from fltk.util.log import FLLogger from fltk.util.results import EpochData @@ -71,44 +69,37 @@ class Federator(object): - Keep track of timing """ + # Dictionary containing active (e.g. deployed) tasks. + active_tasks: Dict[str, Dict[str, ClientRef]] + # List of active clients clients: List[ClientRef] = [] - epoch_counter = 0 - # TODO: Misnormer, but no time to refactor - client_data = {} - poisoned_clients = {} - healthy_clients = {} - - def __init__(self, client_id_triple, num_epochs=3, config: BareConfig = None, attack: Attack = None, - antidote: Antidote = None): + # epoch_counter = 0 + # client_data = {} + # poisoned_clients = {} + # healthy_clients = {} + + def __init__(self, client_id_triple, config: BareConfig = None): log_rref = rpc.RRef(FLLogger()) - # Poisoning - self.attack = attack - logging.info(f'Federator with attack {attack}') - self.antidote = antidote - logging.info(f'Federator with antidote {antidote}') + self.log_rref = log_rref - self.num_epoch = num_epochs self.config = config - self.tb_path_base = self.config.output_location - self.ensure_path_exists(self.tb_path_base) - self.tb_writer = SummaryWriter(f'{self.tb_path_base}/{self.config.experiment_prefix}_federator') + + # TODO: Change to Kubernetes spawning self.create_clients(client_id_triple) + # TODO: Decide on using a more persistent logging approach self.config.init_logger(logging) - logging.info("Creating test client") - 
copy_sampler = config.data_sampler - config.data_sampler = "uniform" - self.test_data = None - self.set_data() - config.data_sampler = copy_sampler - - def set_data(self): - self.test_data = Client("test", None, 1, 2, self.config) - self.test_data.init_dataloader() - def create_clients(self, client_id_triple): + """ + Function to spin up worker clients for a task. + @param client_id_triple: + @type client_id_triple: + @return: + @rtype: + """ + # TODO: Change to spinning up different clients. for id, rank, world_size in client_id_triple: client = rpc.remote(id, Client, kwargs=dict(id=id, log_rref=self.log_rref, rank=rank, world_size=world_size, config=self.config)) @@ -119,8 +110,14 @@ def create_clients(self, client_id_triple): self.client_data['federator'] = [] def update_clients(self, ratio): + """ + TODO remove functionality, move to new function & clean up + @param ratio: + @type ratio: + @return: + @rtype: + """ # Prevent abrupt ending of the client - self.tb_writer.close() self.tb_writer = SummaryWriter(f'{self.tb_path_base}/{self.config.experiment_prefix}_federator_{ratio}') for client in self.clients: # Create new writer and close old writer @@ -130,9 +127,6 @@ def update_clients(self, ratio): # Clear client updates ofteraf self.client_data[client.name] = [] - def select_clients(self, n): - return self.attack.select_clients(self.poisoned_clients, self.healthy_clients, n) - def ping_all(self): for client in self.clients: logging.info(f'Sending ping to {client}') @@ -366,14 +360,15 @@ def distribute_new_model(self, updated_model): res[1].wait() logging.info('Weights are updated') - def test_model(self) -> EpochData: + def test_model(self, model, writer) -> EpochData: """ + TODO: Move this function somewhere else. Maybe even the federator shouldn't be bothered with testing. Function to test the model on the test dataset. @return: @rtype: """ # Test interleaved to speed up execution, i.e. don't keep the clients waiting. 
- accuracy, loss, class_precision, class_recall = self.test_data.test() + accuracy, loss, class_precision, class_recall = model.test() data = EpochData(epoch_id=self.epoch_counter, duration_train=0, duration_test=0, @@ -383,6 +378,6 @@ def test_model(self) -> EpochData: class_precision=class_precision, class_recall=class_recall, client_id='federator') - self.tb_writer.add_scalar('accuracy', accuracy, self.epoch_counter * self.test_data.get_client_datasize()) - self.tb_writer.add_scalar('accuracy per epoch', accuracy, self.epoch_counter) + writer.add_scalar('accuracy', accuracy, self.epoch_counter * self.test_data.get_client_datasize()) + writer.add_scalar('accuracy per epoch', accuracy, self.epoch_counter) return data