From dbcf49b0f9b1772f4c73495a1145ce779ab3aeeb Mon Sep 17 00:00:00 2001
From: bacox
Date: Thu, 10 Mar 2022 16:11:48 +0100
Subject: [PATCH] Make message calls opaque

---
 Dockerfile                             |   9 +-
 fltk/__main__.py                       | 282 +++++++++++++++++++------
 fltk/core/client.py                    |  24 ++-
 fltk/core/federator.py                 |  70 +++++-
 fltk/core/node.py                      |  30 ++-
 fltk/strategy/aggregation/FedAvg.py    |   2 +-
 fltk/strategy/aggregation/__init__.py  |  13 +-
 fltk/strategy/optimization/__init__.py |   2 +-
 fltk/util/config.py                    |  42 +++-
 fltk/util/data_container.py            |   1 +
 fltk/util/definitions.py               |  11 +-
 fltk/util/generate_docker_compose.py   |   9 +-
 12 files changed, 396 insertions(+), 99 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 6e79f4a4..09cdfe0b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -7,9 +7,6 @@ MAINTAINER Bart Cox
 # Run build without interactive dialogue
 ARG DEBIAN_FRONTEND=noninteractive

-ENV GLOO_SOCKET_IFNAME=eth0
-ENV TP_SOCKET_IFNAME=eth0
-
 # Define the working directory of the current Docker container
 WORKDIR /opt/federation-lab

@@ -26,6 +23,9 @@ COPY requirements.txt ./
 # Install all required packages for the generator
 RUN python3 -m pip install -r requirements.txt

+ENV GLOO_SOCKET_IFNAME=$NIC
+ENV TP_SOCKET_IFNAME=$NIC
+
 #RUN mkdir -p ./data/MNIST
 #COPY ./data/MNIST ../data/MNIST
 #ADD fltk ./fedsim
@@ -46,5 +46,6 @@ COPY fltk ./fltk
 COPY configs ./configs
 #CMD python3 ./fltk/__main__.py single configs/experiment.yaml --rank=$RANK
 # CMD python3 -m fltk single configs/experiment_vanilla.yaml --rank=$RANK
-CMD python3 -m fltk single $EXP_CONFIG --rank=$RANK
+#CMD python3 -m fltk single $EXP_CONFIG --rank=$RANK
+CMD python3 -m fltk remote $EXP_CONFIG $RANK --nic=$NIC --host=$MASTER_HOSTNAME
 #CMD python3 setup.py
\ No newline at end of file
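The new CMD moves all run-time wiring into environment variables: EXP_CONFIG, RANK, NIC and MASTER_HOSTNAME must be set when the container starts (the compose files emitted by fltk/util/generate_docker_compose.py are the intended source). Note that ENV GLOO_SOCKET_IFNAME=$NIC is resolved when the image is built, so it stays empty unless NIC is also declared as a build argument; the remote entrypoint does not depend on it, because run_remote re-exports both socket-interface variables from --nic at runtime (see retrieve_env_params below). An illustrative manual launch, with the image name and all values assumed:

    docker run -e EXP_CONFIG=configs/example.yaml -e RANK=1 \
        -e NIC=eth0 -e MASTER_HOSTNAME=federator fltk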
diff --git a/fltk/__main__.py b/fltk/__main__.py
index 263a7fa5..cf614012 100644
--- a/fltk/__main__.py
+++ b/fltk/__main__.py
@@ -1,87 +1,233 @@
+# import os
+# import random
+# import sys
+# import time
+#
+# import torch.distributed.rpc as rpc
+# import logging
+#
+# import yaml
+# import argparse
+#
+# import torch.multiprocessing as mp
+# from fltk.federator import Federator
+# from fltk.launch import run_single, run_spawn
+# from fltk.util.base_config import BareConfig
+#
+# logging.basicConfig(level=logging.DEBUG)
+#
+# def add_default_arguments(parser):
+#     parser.add_argument('--world_size', type=str, default=None,
+#                         help='Number of entities in the world. This is the number of clients + 1')
+#
+# def main():
+#     parser = argparse.ArgumentParser(description='Experiment launcher for the Federated Learning Testbed')
+#
+#     subparsers = parser.add_subparsers(dest="mode")
+#
+#     single_parser = subparsers.add_parser('single')
+#     single_parser.add_argument('config', type=str)
+#     single_parser.add_argument('--rank', type=int)
+#     single_parser.add_argument('--nic', type=str, default=None)
+#     single_parser.add_argument('--host', type=str, default=None)
+#     add_default_arguments(single_parser)
+#
+#     spawn_parser = subparsers.add_parser('spawn')
+#     spawn_parser.add_argument('config', type=str)
+#     add_default_arguments(spawn_parser)
+#
+#     remote_parser = subparsers.add_parser('remote')
+#     remote_parser.add_argument('--rank', type=int)
+#     remote_parser.add_argument('--nic', type=str, default=None)
+#     remote_parser.add_argument('--host', type=str, default=None)
+#     add_default_arguments(remote_parser)
+#     args = parser.parse_args()
+#     if args.mode == 'remote':
+#         if args.rank is None or args.host is None or args.world_size is None or args.nic is None:
+#             print('Missing rank, host, world-size, or nic argument when in \'remote\' mode!')
+#             parser.print_help()
+#             exit(1)
+#         world_size = int(args.world_size)
+#         master_address = args.host
+#         nic = args.nic
+#         rank = int(args.rank)
+#         if rank == 0:
+#             print('Remote mode only supports ranks > 0!')
+#             exit(1)
+#         print(f'rank={args.rank}, world_size={world_size}, host={master_address}, args=None, nic={nic}')
+#         run_single(rank=args.rank, world_size=world_size, host=master_address, args=None, nic=nic)
+#     else:
+#         with open(args.config) as file:
+#             sleep_time = random.uniform(0, 5.0)
+#             time.sleep(sleep_time)
+#             cfg = BareConfig()
+#             yaml_data = yaml.load(file, Loader=yaml.FullLoader)
+#             cfg.merge_yaml(yaml_data)
+#             if args.mode == 'single':
+#                 if args.rank is None:
+#                     print('Missing rank argument when in \'single\' mode!')
+#                     parser.print_help()
+#                     exit(1)
+#                 world_size = args.world_size
+#                 master_address = args.host
+#                 nic = args.nic
+#
+#                 if not world_size:
+#                     world_size = yaml_data['system']['clients']['amount'] + 1
+#                 if not master_address:
+#                     master_address = yaml_data['system']['federator']['hostname']
+#                 if not nic:
+#                     nic = yaml_data['system']['federator']['nic']
+#                 print(f'rank={args.rank}, world_size={world_size}, host={master_address}, args=cfg, nic={nic}')
+#                 run_single(rank=args.rank, world_size=world_size, host=master_address, args=cfg, nic=nic)
+#             else:
+#                 run_spawn(cfg)
+#
+# if __name__ == "__main__":
+#     main()
 import os
-import random
 import sys
-import time
+from pathlib import Path

-import torch.distributed.rpc as rpc
-import logging
+from torch.distributed import rpc

-import yaml
+from fltk.core.client import Client
+
+print(sys.path)
+# from fltk.core.federator import Federator as Fed
+print(list(Path.cwd().iterdir()))
 import argparse
+from enum import Enum
+from pathlib import Path

-import torch.multiprocessing as mp
-from fltk.federator import Federator
-from fltk.launch import run_single, run_spawn
-from fltk.util.base_config import BareConfig
+from fltk.core.federator import Federator
+from fltk.util.config import Config
+from fltk.util.definitions import Aggregations, Optimizations

-logging.basicConfig(level=logging.DEBUG)
+def run_single(config_path: Path):

-def add_default_arguments(parser):
-    parser.add_argument('--world_size', type=str, default=None,
-                        help='Number of entities in the world. This is the number of clients + 1')
+    # We can iterate over all the experiments in the directory and execute them, as long as the system stays the same!
+    # System = the machines and their configuration
+
+    print(config_path)
+    config = Config.FromYamlFile(config_path)
+    config.world_size = config.num_clients + 1
+    config.replication_id = 1
+    federator_node = Federator('federator', 0, config.world_size, config)
+    federator_node.run()
+
+
+def retrieve_env_params(nic=None, host=None):
+    if host:
+        os.environ['MASTER_ADDR'] = host
+    os.environ['MASTER_PORT'] = '5000'
+    if nic:
+        os.environ['GLOO_SOCKET_IFNAME'] = nic
+        os.environ['TP_SOCKET_IFNAME'] = nic
+
+def retrieve_network_params_from_config(config: Config, nic=None, host=None):
+    if hasattr(config, 'system'):
+        system_attr = getattr(config, 'system')
+        if 'federator' in system_attr:
+            if 'hostname' in system_attr['federator'] and not host:
+                host = system_attr['federator']['hostname']
+            if 'nic' in system_attr['federator'] and not nic:
+                nic = system_attr['federator']['nic']
+    return nic, host
+
+def run_remote(config_path: Path, rank: int, nic=None, host=None):
+    print(config_path, rank)
+    config = Config.FromYamlFile(config_path)
+    config.world_size = config.num_clients + 1
+    nic, host = retrieve_network_params_from_config(config, nic, host)
+    if not nic or not host:
+        print('Missing host or nic argument when in \'remote\' mode!')
+        parser.print_help()
+        exit(1)
+    retrieve_env_params(nic, host)
+    print(f'Starting with host={os.environ["MASTER_ADDR"]} and port={os.environ["MASTER_PORT"]} and interface={nic}')
+    options = rpc.TensorPipeRpcBackendOptions(
+        num_worker_threads=16,
+        rpc_timeout=0,  # infinite timeout
+        init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}'
+    )
+    if rank != 0:
+        print(f'Starting worker {rank}')
+        rpc.init_rpc(
+            f"client{rank}",
+            rank=rank,
+            world_size=config.world_size,
+            rpc_backend_options=options,
+        )
+        client_node = Client(f'client{rank}', rank, config.world_size, config)
+        client_node.remote_registration()
+
+        # trainer passively waiting for ps to kick off training iterations
+    else:
+        print(f'Starting the ps with world size={config.world_size}')
+        rpc.init_rpc(
+            "federator",
+            rank=rank,
+            world_size=config.world_size,
+            rpc_backend_options=options
+
+        )
+        federator_node = Federator('federator', 0, config.world_size, config)
+        # federator_node.create_clients()
+        federator_node.run()
+        federator_node.stop_all_clients()
+    print('Ending program')
+    # if rank == 0:
+    #     print('FEDERATOR!')
+    # else:
+    #     print(f'CLIENT {rank}')

 def main():
-    parser = argparse.ArgumentParser(description='Experiment launcher for the Federated Learning Testbed')
+    pass

-    subparsers = parser.add_subparsers(dest="mode")

-    single_parser = subparsers.add_parser('single')
-    single_parser.add_argument('config', type=str)
-    single_parser.add_argument('--rank', type=int)
-    single_parser.add_argument('--nic', type=str, default=None)
-    single_parser.add_argument('--host', type=str, default=None)
-    add_default_arguments(single_parser)

+def add_default_arguments(parser):
+    parser.add_argument('config', type=str,
+                        help='')

-    spawn_parser = subparsers.add_parser('spawn')
-    spawn_parser.add_argument('config', type=str)
-    add_default_arguments(spawn_parser)

+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(prog='fltk', description='Experiment launcher for the Federated Learning Testbed (fltk)')
+    subparsers = parser.add_subparsers(dest="action", required=True)
+    launch_parser = subparsers.add_parser('launch-util')
     remote_parser = subparsers.add_parser('remote')
-    remote_parser.add_argument('--rank', type=int)
+    single_machine_parser = subparsers.add_parser('single')
+    add_default_arguments(launch_parser)
+    add_default_arguments(remote_parser)
+    add_default_arguments(single_machine_parser)
+
+    remote_parser.add_argument('rank', type=int)
     remote_parser.add_argument('--nic', type=str, default=None)
     remote_parser.add_argument('--host', type=str, default=None)
-    add_default_arguments(remote_parser)
+
+    # single_parser = subparsers.add_parser('single', help='single help')
+    # single_parser.add_argument('config')
+    # util_parser = subparsers.add_parser('util', help='util help')
+    # util_parser.add_argument('action')
+    # print(sys.argv)
     args = parser.parse_args()
-    if args.mode == 'remote':
-        if args.rank is None or args.host is None or args.world_size is None or args.nic is None:
-            print('Missing rank, host, world-size, or nic argument when in \'remote\' mode!')
-            parser.print_help()
-            exit(1)
-        world_size = int(args.world_size)
-        master_address = args.host
-        nic = args.nic
-        rank = int(args.rank)
-        if rank == 0:
-            print('Remote mode only supports ranks > 0!')
-            exit(1)
-        print(f'rank={args.rank}, world_size={world_size}, host={master_address}, args=None, nic={nic}')
-        run_single(rank=args.rank, world_size=world_size, host=master_address, args=None, nic=nic)
+    if args.action == 'launch-util':
+        pass
+        # run_single(Path(args.config))
+    # elif keeps 'launch-util' from falling through into single-machine mode
+    elif args.action == 'remote':
+        run_remote(Path(args.config), args.rank, args.nic, args.host)
     else:
-        with open(args.config) as file:
-            sleep_time = random.uniform(0, 5.0)
-            time.sleep(sleep_time)
-            cfg = BareConfig()
-            yaml_data = yaml.load(file, Loader=yaml.FullLoader)
-            cfg.merge_yaml(yaml_data)
-            if args.mode == 'single':
-                if args.rank is None:
-                    print('Missing rank argument when in \'single\' mode!')
-                    parser.print_help()
-                    exit(1)
-                world_size = args.world_size
-                master_address = args.host
-                nic = args.nic
-
-                if not world_size:
-                    world_size = yaml_data['system']['clients']['amount'] + 1
-                if not master_address:
-                    master_address = yaml_data['system']['federator']['hostname']
-                if not nic:
-                    nic = yaml_data['system']['federator']['nic']
-                print(f'rank={args.rank}, world_size={world_size}, host={master_address}, args=cfg, nic={nic}')
-                run_single(rank=args.rank, world_size=world_size, host=master_address, args=cfg, nic=nic)
-            else:
-                run_spawn(cfg)
-
-if __name__ == "__main__":
-    main()
+        # Run single machine mode
+        run_single(Path(args.config))
+
+    # if args.mode == 'single':
+    #     print('Single')
+    #     c = Config(optimizer=Optimizations.fedprox)
+    #     print(isinstance(Config.aggregation, Enum))
+    #     config = Config.FromYamlFile(args.config)
+    #
+    #     auto = config.optimizer
+    #     print(config)
+    #     print('Parsed')
+
+    # print(args)
\ No newline at end of file
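With the reworked parser, config is a positional argument for every sub-command and rank is positional for remote, which matches the Dockerfile CMD above. Illustrative invocations (the config path and host name are assumptions):

    python3 -m fltk single configs/example.yaml
    python3 -m fltk remote configs/example.yaml 1 --nic=eth0 --host=federator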
diff --git a/fltk/core/client.py b/fltk/core/client.py
index 21e922f3..3f0ed23a 100644
--- a/fltk/core/client.py
+++ b/fltk/core/client.py
@@ -10,7 +10,7 @@


 class Client(Node):
-
+    running = False

     def __init__(self, id: int, rank: int, world_size: int, config: Config):
         super().__init__(id, rank, world_size, config)
@@ -22,6 +22,23 @@ def __init__(self, id: int, rank: int, world_size: int, config: Config):
                                         self.config.scheduler_gamma,
                                         self.config.min_lr)

+    def remote_registration(self):
+        self.logger.info('Sending registration')
+        self.message('federator', 'ping', 'new_sender', be_weird=True)
+        self.message('federator', 'register_client', self.id, self.rank)
+        self.running = True
+        self._event_loop()
+
+    def stop_client(self):
+        self.logger.info('Got call to stop event loop')
+        self.running = False
+
+    def _event_loop(self):
+        self.logger.info('Starting event loop')
+        while self.running:
+            time.sleep(0.1)
+        self.logger.info('Exiting node')
+
     def train(self, num_epochs: int):
         start_time = time.time()

@@ -47,10 +64,11 @@ def train(self, num_epochs: int):
                 running_loss += loss.item()
                 # Mark logging update step
                 if i % self.config.log_interval == 0:
-                    # self.logger.info(
-                    #     '[%d, %5d] loss: %.3f' % (num_epochs, i, running_loss / self.config.log_interval))
+                    self.logger.info(
+                        '[%s] [%d, %5d] loss: %.3f' % (self.id, num_epochs, i, running_loss / self.config.log_interval))
                     final_running_loss = running_loss / self.config.log_interval
                     running_loss = 0.0
+                    # break

         end_time = time.time()
         duration = end_time - start_time
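After registering, a remote client parks its main thread in _event_loop until the federator sends a stop_client message. A minimal, RPC-free sketch of that lifecycle, with a thread standing in for the RPC worker that would normally deliver the stop call:

    import threading
    import time

    class MiniClient:
        running = False

        def remote_registration(self):
            # the registration messages to the federator would go here
            self.running = True
            self._event_loop()

        def _event_loop(self):
            while self.running:   # spin until stop_client flips the flag
                time.sleep(0.1)

        def stop_client(self):
            self.running = False

    client = MiniClient()
    loop = threading.Thread(target=client.remote_registration)
    loop.start()
    time.sleep(0.5)
    client.stop_client()   # what Federator.stop_all_clients triggers via message()
    loop.join()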
diff --git a/fltk/core/federator.py b/fltk/core/federator.py
index 70b65a38..43d9c392 100644
--- a/fltk/core/federator.py
+++ b/fltk/core/federator.py
@@ -14,6 +14,7 @@
 from dataclasses import dataclass
 from fltk.util.data_container import DataContainer, FederatorRecord, ClientRecord
+from fltk.strategy import get_aggregation

 NodeReference = Union[Node, str]
 @dataclass
@@ -36,11 +37,14 @@ def __init__(self, id: int, rank: int, world_size: int, config: Config):
         self.loss_function = self.config.get_loss_function()()
         self.num_rounds = config.rounds
         self.config = config
+        config.output_path = Path(config.output_path) / config.exp_name / f'{config.name}_r{config.replication_id}'
         self.exp_data = DataContainer('federator', config.output_path, FederatorRecord, config.save_data_append)
+        self.aggregation_method = get_aggregation(config.aggregation)

     def create_clients(self):
+        self.logger.info('Creating clients')
         if self.config.single_machine:
             # Create direct clients
             world_size = self.config.num_clients + 1
@@ -49,13 +53,20 @@ def create_clients(self):
                 client = Client(client_name, client_id, world_size, copy.deepcopy(self.config))
                 self.clients.append(LocalClient(client_name, client, 0, DataContainer(client_name, self.config.output_path,
                                                                                       ClientRecord, self.config.save_data_append)))
+                self.logger.info(f'Client "{client_name}" created')

     def register_client(self, client_name, rank):
+        self.logger.info(f'Got new client registration from client {client_name}')
         if self.config.single_machine:
             self.logger.warning('This function should not be called when in single machine mode!')
-        self.clients.append(LocalClient(client_name, client_name, 0, DataContainer(client_name, self.config.output_path,
+        self.clients.append(LocalClient(client_name, client_name, rank, DataContainer(client_name, self.config.output_path,
                                                                                    ClientRecord, self.config.save_data_append)))

+    def stop_all_clients(self):
+        for client in self.clients:
+            self.message(client.ref, Client.stop_client)
+
+
     def _num_clients_online(self) -> int:
         return len(self.clients)

@@ -93,6 +104,9 @@ def run(self):
         while not self._all_clients_online():
             self.logger.info(f'Waiting for all clients to come online. Waiting for {self.world_size - 1 - self._num_clients_online()} clients')
             time.sleep(2)
+        self.logger.info('All clients are online')
+        # self.logger.info('Running')
+        # time.sleep(10)
         self.client_load_data()
         self.get_client_data_sizes()
         self.clients_ready()
@@ -169,18 +183,58 @@ def exec_round(self, id: int):
         # Actual training calls
         client_weights = {}
         client_sizes = {}
-        pbar = tqdm(selected_clients)
-        for client in pbar:
-            pbar.set_description(f'[Round {id:>3}] Running clients')
-            train_loss, weights, accuracy, test_loss, round_duration, train_duration, test_duration = self.message(client.ref, Client.exec_round, num_epochs)
+        # pbar = tqdm(selected_clients)
+        # for client in pbar:
+
+        # Client training
+        training_futures: List[torch.Future] = []
+
+        def training_cb(fut: torch.Future, client: LocalClient):
+            train_loss, weights, accuracy, test_loss, round_duration, train_duration, test_duration = fut.wait()
             client_weights[client.name] = weights
             client_data_size = self.message(client.ref, Client.get_client_datasize)
             client_sizes[client.name] = client_data_size
-            client.exp_data.append(ClientRecord(id, train_duration, test_duration, round_duration, num_epochs, 0, accuracy, train_loss, test_loss))
-            # self.logger.info(f'[Round {id:>3}] Client {client} has a accuracy of {accuracy}, train loss={train_loss}, test loss={test_loss},datasize={client_data_size}')
+            self.logger.info(f'Training callback for client {client.name}')
+            client.exp_data.append(
+                ClientRecord(id, train_duration, test_duration, round_duration, num_epochs, 0, accuracy, train_loss,
+                             test_loss))
+
+        for client in selected_clients:
+            # future: torch.Future
+            # if not self.real_time:
+            #     future = torch.futures.Future()
+            #     future.set_result(self.message(client.ref, Client.exec_round, num_epochs))
+            #     future.then(lambda x: training_cb(x, client))
+            #     training_futures.append(future)
+            # else:
+            future = self.message_async(client.ref, Client.exec_round, num_epochs)
+            # Bind the current client via a default argument; a bare closure
+            # over the loop variable would hand every callback the last client.
+            future.then(lambda x, c=client: training_cb(x, c))
+            training_futures.append(future)
+
+        def all_futures_done(futures: List[torch.Future]) -> bool:
+            return all(map(lambda x: x.done(), futures))
+
+        while not all_futures_done(training_futures):
+            time.sleep(0.1)
+            # self.logger.info(f'Waiting for other clients')
+
+        self.logger.info('All client training futures have completed')
+
+        # for client in selected_clients:
+        #     # pbar.set_description(f'[Round {id:>3}] Running clients')
+        #     train_loss, weights, accuracy, test_loss, round_duration, train_duration, test_duration = self.message(client.ref, Client.exec_round, num_epochs)
+        #     client_weights[client.name] = weights
+        #     client_data_size = self.message(client.ref, Client.get_client_datasize)
+        #     client_sizes[client.name] = client_data_size
+        #     client.exp_data.append(ClientRecord(id, train_duration, test_duration, round_duration, num_epochs, 0, accuracy, train_loss, test_loss))
+        #     # self.logger.info(f'[Round {id:>3}] Client {client} has an accuracy of {accuracy}, train loss={train_loss}, test loss={test_loss}, datasize={client_data_size}')

         # updated_model = FedAvg(client_weights, client_sizes)
-        updated_model = average_nn_parameters_simple(list(client_weights.values()))
+        updated_model = self.aggregation_method(client_weights, client_sizes)
+        # updated_model = average_nn_parameters_simple(list(client_weights.values()))
         self.update_nn_parameters(updated_model)

         test_accuracy, test_loss = self.test(self.net)
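exec_round now fans the training calls out with message_async and synchronises on the returned futures; the default-argument binding in the lambda above is what ties each callback to its own client. A self-contained sketch of the same pattern using only torch.futures (names and values are invented):

    import time
    import torch

    results = {}

    def fake_exec_round(value: int) -> torch.futures.Future:
        # Stand-in for message_async: a future that is already resolved.
        fut = torch.futures.Future()
        fut.set_result(value)
        return fut

    def training_cb(fut: torch.futures.Future, name: str):
        results[name] = fut.wait()  # wait() returns at once, the future is done

    futures = []
    for i, name in enumerate(['client1', 'client2']):
        fut = fake_exec_round(i)
        # n=name binds the current loop value; a bare closure would give
        # every callback the final value of `name`.
        fut.then(lambda f, n=name: training_cb(f, n))
        futures.append(fut)

    while not all(f.done() for f in futures):
        time.sleep(0.1)
    print(results)  # {'client1': 0, 'client2': 1}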
diff --git a/fltk/core/node.py b/fltk/core/node.py
index 17d594a0..6d8aac0c 100644
--- a/fltk/core/node.py
+++ b/fltk/core/node.py
@@ -46,6 +46,7 @@ def __init__(self, id: int, rank: int, world_size: int, config: Config):
         self.id = id
         self.rank = rank
         self.world_size = world_size
+        self.real_time = config.real_time
         global global_vars
         global_vars['self'] = self
         self._config(config)
@@ -77,8 +78,19 @@ def is_ready(self):
     @staticmethod
     def _receive(method: Callable, sender: str, *args, **kwargs):
         global global_vars
+        # print('_receive')
+        # print(global_vars)
         global_self = global_vars['self']
-        return method(global_self, *args, **kwargs)
+        # print(type(method))
+        # print(type(global_self))
+        if type(method) is str:
+            # print(f'Retrieving method from string: "{method}"')
+            method = getattr(global_self, method)
+            return method(*args, **kwargs)
+        else:
+            # print(method)
+            # print(global_self, *args, kwargs)
+            return method(global_self, *args, **kwargs)


     # def _lookup_reference(self, node_name: str):
@@ -153,13 +165,27 @@ def update_nn_parameters(self, new_params, is_offloaded_model = False):
         self.net.load_state_dict(copy.deepcopy(new_params), strict=True)
         # self.logger.info(f'Weights of the model are updated')

-    def message(self, other_node: str, method: Callable, *args, **kwargs):
+    def message(self, other_node: str, method: Callable, *args, **kwargs) -> torch.Future:
         if self.real_time:
             func = Node._receive
             args_list = [method, self.id] + list(args)
             return rpc.rpc_sync(other_node, func, args=args_list, kwargs=kwargs)
         return method(other_node, *args, **kwargs)

+    def message_async(self, other_node: str, method: Callable, *args, **kwargs) -> torch.Future:
+        if self.real_time:
+            func = Node._receive
+            args_list = [method, self.id] + list(args)
+            return rpc.rpc_async(other_node, func, args=args_list, kwargs=kwargs)
+        # Wrap inside a future to keep the logic the same
+        future = torch.futures.Future()
+        future.set_result(method(other_node, *args, **kwargs))
+        return future
+
+    # def register_client(self, client_name, rank):
+    #     print(f'self={self}')
+    #     self.logger.info(f'[Default Implementation!] Got new client registration from client {client_name}')
+
     def ping(self, sender: str, be_weird=False):
         self.logger.info(f'Pong from {self.id}, got call from {sender} [{self.counter}]')
         # print(f'Pong from {self.id}, got call from {sender} [{self.counter}]')
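This hunk is the change the commit title refers to: _receive now accepts either a callable or a plain string, so a node can invoke a method on a peer it only knows by name (as client.py does above with 'ping' and 'register_client'). One caveat visible in the diff: despite the -> torch.Future annotation, the synchronous message returns the remote call's result directly via rpc_sync; only message_async returns a future. A minimal sketch of the dispatch rule itself, outside the RPC machinery (the real _receive resolves the target from global_vars rather than taking it as a parameter):

    class Toy:
        def ping(self, sender):
            return f'pong to {sender}'

        @staticmethod
        def _receive(method, target, *args, **kwargs):
            if isinstance(method, str):
                # Opaque call: resolve the name on the receiving instance.
                return getattr(target, method)(*args, **kwargs)
            # Legacy call: an unbound method reference, invoked explicitly.
            return method(target, *args, **kwargs)

    node = Toy()
    print(Toy._receive('ping', node, 'client1'))    # by name
    print(Toy._receive(Toy.ping, node, 'client1'))  # by reference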
diff --git a/fltk/strategy/aggregation/FedAvg.py b/fltk/strategy/aggregation/FedAvg.py
index 98a72396..041f4628 100644
--- a/fltk/strategy/aggregation/FedAvg.py
+++ b/fltk/strategy/aggregation/FedAvg.py
@@ -1,6 +1,6 @@


-def FedAvg(parameters, sizes):
+def fed_avg(parameters, sizes):
     new_params = {}
     sum_size = 0
     for client in parameters:
diff --git a/fltk/strategy/aggregation/__init__.py b/fltk/strategy/aggregation/__init__.py
index 696cca51..fca94c72 100644
--- a/fltk/strategy/aggregation/__init__.py
+++ b/fltk/strategy/aggregation/__init__.py
@@ -1,2 +1,13 @@
-from .FedAvg import FedAvg
+from fltk.util.definitions import Aggregations
+from .FedAvg import fed_avg
 from .aggregation import average_nn_parameters, average_nn_parameters_simple
+
+
+def get_aggregation(name: Aggregations):
+    enum_type = Aggregations(name.value)
+    aggregations_dict = {
+        Aggregations.fedavg: fed_avg,
+        # Placeholders: not real aggregators yet, but they accept the same
+        # (parameters, sizes) signature that the federator calls with.
+        Aggregations.sum: lambda parameters, sizes: parameters,
+        Aggregations.avg: lambda parameters, sizes: parameters
+    }
+    return aggregations_dict[enum_type]
\ No newline at end of file
diff --git a/fltk/strategy/optimization/__init__.py b/fltk/strategy/optimization/__init__.py
index e4a501ce..a38c3de0 100644
--- a/fltk/strategy/optimization/__init__.py
+++ b/fltk/strategy/optimization/__init__.py
@@ -1,7 +1,7 @@
 import torch
 from .fedprox import FedProx
 from .FedNova import FedNova
-from ...util.definitions import Optimizations
+from fltk.util.definitions import Optimizations


 def get_optimizer(name: Optimizations):
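The federator now resolves its aggregator once, via get_aggregation(config.aggregation). fed_avg's body is only partially visible in this patch; assuming it computes the usual size-weighted FedAvg over the clients' state dicts, the lookup is used like this (tensors and sizes are invented):

    import torch
    from fltk.strategy import get_aggregation
    from fltk.util.definitions import Aggregations

    client_weights = {
        'client1': {'w': torch.tensor([0.0])},
        'client2': {'w': torch.tensor([3.0])},
    }
    client_sizes = {'client1': 100, 'client2': 200}

    aggregate = get_aggregation(Aggregations.fedavg)  # -> fed_avg
    new_params = aggregate(client_weights, client_sizes)
    # A size-weighted average would give (0*100 + 3*200) / 300 = 2.0 for 'w'.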
diff --git a/fltk/util/config.py b/fltk/util/config.py
index 86266a22..a6811735 100644
--- a/fltk/util/config.py
+++ b/fltk/util/config.py
@@ -1,13 +1,19 @@
+import copy
 from dataclasses import dataclass
+from enum import Enum, EnumMeta
 from pathlib import Path
+from typing import Type

 import torch
+import yaml

-from fltk.util.definitions import Dataset, Nets, DataSampler, Optimizations, LogLevel
+from fltk.util.definitions import Dataset, Nets, DataSampler, Optimizations, LogLevel, Aggregations


 @dataclass
 class Config:
+    # optimizer: Optimizations
+    name: str = ''
     batch_size: int = 1
     test_batch_size: int = 1000
     rounds: int = 2
@@ -20,36 +26,53 @@ class Config:
     scheduler_step_size: int = 50
     scheduler_gamma: float = 0.5
     min_lr: float = 1e-10
-    optimizer = Optimizations.sgd
+    # Enum
+    optimizer: Optimizations = Optimizations.sgd
     optimizer_args = {
         'lr': lr,
         'momentum': momentum
     }
     loss_function = torch.nn.CrossEntropyLoss
-
+    # Enum
     log_level: LogLevel = LogLevel.DEBUG

     num_clients: int = 10
     clients_per_round: int = 2
     distributed: bool = True
     single_machine: bool = False
-
+    # Enum
+    aggregation: Aggregations = Aggregations.fedavg
+    # Enum
     dataset_name: Dataset = Dataset.mnist
+    # Enum
     net_name: Nets = Nets.mnist_cnn
     default_model_folder_path: str = "default_models"
     data_path: str = "data"
+    # Enum
     data_sampler: DataSampler = DataSampler.uniform
     data_sampler_args = []

     rank: int = 0
     world_size: int = 0

+    replication_id: int = None
+    exp_name: str = 'experiment'
+
+    real_time: bool = False
+
+    # Save data in append mode, thereby flushing on every append to file.
+    # This can be useful when a system is likely to crash midway through an experiment
     save_data_append: bool = False
-
     output_path: Path = Path('output_test_2')

+    def __init__(self, **kwargs) -> None:
+        enum_fields = [x for x in self.__dataclass_fields__.items() if isinstance(x[1].type, Enum) or isinstance(x[1].type, EnumMeta)]
+        for name, field in enum_fields:
+            if name in kwargs and isinstance(kwargs[name], str):
+                kwargs[name] = field.type(kwargs[name])
+        for name, value in kwargs.items():
+            self.__setattr__(name, value)
+
     def get_default_model_folder_path(self):
         return self.default_model_folder_path

@@ -73,3 +96,12 @@ class Config:

     def get_loss_function(self):
         return self.loss_function
+
+    @classmethod
+    def FromYamlFile(cls, path: Path):
+        print(f'Loading yaml from {path.absolute()}')
+        with open(path) as file:
+            content = yaml.safe_load(file)
+            for k, v in content.items():
+                print(f'Inserting key "{k}" into config')
+            return cls(**content)
diff --git a/fltk/util/data_container.py b/fltk/util/data_container.py
index d0500820..6a6f1350 100644
--- a/fltk/util/data_container.py
+++ b/fltk/util/data_container.py
@@ -54,6 +54,7 @@ def __init__(self, name: str, output_location: Path, record_type: DataRecord, ap
         self.records = []
         self.file_name = f'{name}.csv'
         self.name = name
+        output_location = Path(output_location)
         output_location.mkdir(parents=True, exist_ok=True)
         self.file_path = output_location / self.file_name
         self.append_mode = append_mode
diff --git a/fltk/util/definitions.py b/fltk/util/definitions.py
index c81f9a4f..2492b062 100644
--- a/fltk/util/definitions.py
+++ b/fltk/util/definitions.py
@@ -7,9 +7,10 @@
 # 6. Optimizers #
 ###############################
 # Use enums instead of dataclasses?
-from enum import Enum
+from enum import Enum, unique


+@unique
 class DataSampler(Enum):
     uniform = "uniform"
     q_sampler = "q sampler"
@@ -21,19 +22,20 @@ class DataSampler(Enum):
     n_labels = "n labels"


+@unique
 class Optimizations(Enum):
     sgd = 'SGD'
     fedprox = 'FedProx'
     fednova = 'FedNova'


+@unique
 class Dataset(Enum):
     cifar10 = 'cifar10'
     cifar100 = 'cifar100'
     fashion_mnist = 'fashion-mnist'
     mnist = 'mnist'

-
 class LogLevel(Enum):
     CRITICAL = 50
     FATAL = CRITICAL
@@ -44,13 +46,14 @@ class LogLevel(Enum):
     DEBUG = 10
     NOTSET = 0

-
+@unique
 class Aggregations(Enum):
     avg = 'Avg'
-    fed_avg = 'FedAvg'
+    fedavg = 'FedAvg'
     sum = 'Sum'


+@unique
 class Nets(Enum):
     cifar100_resnet = "Cifar100ResNet"
     cifar100_vgg = "Cifar100VGG"
diff --git a/fltk/util/generate_docker_compose.py b/fltk/util/generate_docker_compose.py
index 8ac761a4..b58233d3 100644
--- a/fltk/util/generate_docker_compose.py
+++ b/fltk/util/generate_docker_compose.py
@@ -1,10 +1,15 @@
 import sys
+from pathlib import Path
+
 import yaml
 import copy
 import argparse

-global_template_path = './deploy/templates'
+# global_template_path = './deploy/templates'
+global_template_path = Path(__file__).absolute().parent.parent.parent / 'deploy' / 'templates'
+global_template_path = str(global_template_path)
+print(global_template_path)
 def load_system_template(template_path = global_template_path):
     print(f'Loading system template from {template_path}/system_stub.yml')
     with open(f'{template_path}/system_stub.yml') as file:
@@ -17,7 +22,7 @@ def load_client_template(type='default', template_path = global_template_path):
         return documents

 def get_deploy_path(name: str):
-    return f'./deploy/{name}'
+    return f'{Path(global_template_path).parent}/{name}'


 def generate_client(id, template: dict, world_size: int, type='default', cpu_set=''):
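Because the enum-typed fields are declared with annotations, the hand-written __init__ above can find them in __dataclass_fields__ and coerce YAML strings into enum members by value ('FedAvg' -> Aggregations.fedavg, 'SGD' -> Optimizations.sgd). An illustrative load, with the file path and contents assumed:

    from pathlib import Path
    from fltk.util.config import Config
    from fltk.util.definitions import Aggregations

    # configs/example.yaml (hypothetical):
    #   name: demo
    #   num_clients: 4
    #   aggregation: FedAvg
    #   optimizer: SGD
    config = Config.FromYamlFile(Path('configs/example.yaml'))
    assert config.aggregation is Aggregations.fedavg
    config.world_size = config.num_clients + 1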