From 13602bab9a9002ba915517e8a69f988b170cbc74 Mon Sep 17 00:00:00 2001 From: bacox Date: Thu, 24 Feb 2022 19:59:39 +0100 Subject: [PATCH] Record experiment data --- fltk/core/client.py | 10 +++-- fltk/core/federator.py | 42 +++++++++++++----- fltk/util/analysis.py | 70 ++++++++++++++++++++++++++++++ fltk/util/config.py | 7 +++ fltk/util/data_container.py | 86 +++++++++++++++++++++++++++++++++++++ 5 files changed, 200 insertions(+), 15 deletions(-) create mode 100644 fltk/util/analysis.py create mode 100644 fltk/util/data_container.py diff --git a/fltk/core/client.py b/fltk/core/client.py index 34820964..21e922f3 100644 --- a/fltk/core/client.py +++ b/fltk/core/client.py @@ -103,23 +103,25 @@ def test(self): def get_client_datasize(self): return len(self.dataset.get_train_sampler()) - def exec_round(self, num_epochs: int) -> Tuple[float, Any, float, float]: + def exec_round(self, num_epochs: int) -> Tuple[Any, Any, Any, Any, float, float, float]: start = time.time() loss, weights = self.train(num_epochs) - + time_mark_between = time.time() accuracy, test_loss = self.test() end = time.time() - duration = end - start + round_duration = end - start + train_duration = time_mark_between - start + test_duration = end - time_mark_between # self.logger.info(f'Round duration is {duration} seconds') if hasattr(self.optimizer, 'pre_communicate'): # aka fednova or fedprox self.optimizer.pre_communicate() for k, v in weights.items(): weights[k] = v.cpu() - return loss, weights, accuracy, test_loss + return loss, weights, accuracy, test_loss, round_duration, train_duration, test_duration def __del__(self): self.logger.info(f'Client {self.id} is stopping') \ No newline at end of file diff --git a/fltk/core/federator.py b/fltk/core/federator.py index 25d90857..70b65a38 100644 --- a/fltk/core/federator.py +++ b/fltk/core/federator.py @@ -1,8 +1,10 @@ import copy import time +from pathlib import Path from typing import List, Union import torch +from tqdm import tqdm from fltk.core.client import Client from fltk.core.node import Node @@ -11,25 +13,30 @@ from fltk.util.config import Config from dataclasses import dataclass +from fltk.util.data_container import DataContainer, FederatorRecord, ClientRecord + NodeReference = Union[Node, str] @dataclass class LocalClient: name: str ref: NodeReference data_size: int + exp_data: DataContainer + class Federator(Node): clients: List[LocalClient] = [] # clients: List[NodeReference] = [] num_rounds: int - + exp_data: DataContainer def __init__(self, id: int, rank: int, world_size: int, config: Config): super().__init__(id, rank, world_size, config) self.loss_function = self.config.get_loss_function()() self.num_rounds = config.rounds self.config = config + self.exp_data = DataContainer('federator', config.output_path, FederatorRecord, config.save_data_append) @@ -40,12 +47,14 @@ def create_clients(self): for client_id in range(1, self.config.num_clients+ 1): client_name = f'client{client_id}' client = Client(client_name, client_id, world_size, copy.deepcopy(self.config)) - self.clients.append(LocalClient(client_name, client, 0)) + self.clients.append(LocalClient(client_name, client, 0, DataContainer(client_name, self.config.output_path, + ClientRecord, self.config.save_data_append))) def register_client(self, client_name, rank): if self.config.single_machine: self.logger.warning('This function should not be called when in single machine mode!') - self.clients.append(LocalClient(client_name, client_name, 0)) + self.clients.append(LocalClient(client_name, client_name, 0, DataContainer(client_name, self.config.output_path, + ClientRecord, self.config.save_data_append))) def _num_clients_online(self) -> int: return len(self.clients) @@ -88,11 +97,18 @@ def run(self): self.get_client_data_sizes() self.clients_ready() - for communications_round in range(self.config.rounds): - self.exec_round() + for communication_round in range(self.config.rounds): + self.exec_round(communication_round) + self.save_data() self.logger.info('Federator is stopping') + + def save_data(self): + self.exp_data.save() + for client in self.clients: + client.exp_data.save() + def client_load_data(self): for client in self.clients: self.message(client.ref, Client.init_dataloader) @@ -138,7 +154,7 @@ def test(self, net): self.logger.info(f'Test duration is {duration} seconds') return accuracy, loss - def exec_round(self): + def exec_round(self, id: int): start_time = time.time() num_epochs = self.config.epochs @@ -153,21 +169,25 @@ def exec_round(self): # Actual training calls client_weights = {} client_sizes = {} - for client in selected_clients: - train_loss, weights, accuracy, test_loss = self.message(client.ref, Client.exec_round, num_epochs) + pbar = tqdm(selected_clients) + for client in pbar: + pbar.set_description(f'[Round {id:>3}] Running clients') + train_loss, weights, accuracy, test_loss, round_duration, train_duration, test_duration = self.message(client.ref, Client.exec_round, num_epochs) client_weights[client.name] = weights client_data_size = self.message(client.ref, Client.get_client_datasize) client_sizes[client.name] = client_data_size - self.logger.info(f'Client {client} has a accuracy of {accuracy}, train loss={train_loss}, test loss={test_loss},datasize={client_data_size}') + client.exp_data.append(ClientRecord(id, train_duration, test_duration, round_duration, num_epochs, 0, accuracy, train_loss, test_loss)) + # self.logger.info(f'[Round {id:>3}] Client {client} has a accuracy of {accuracy}, train loss={train_loss}, test loss={test_loss},datasize={client_data_size}') # updated_model = FedAvg(client_weights, client_sizes) updated_model = average_nn_parameters_simple(list(client_weights.values())) self.update_nn_parameters(updated_model) test_accuracy, test_loss = self.test(self.net) - self.logger.info(f'Federator has a accuracy of {test_accuracy} and loss={test_loss}') + self.logger.info(f'[Round {id:>3}] Federator has a accuracy of {test_accuracy} and loss={test_loss}') end_time = time.time() duration = end_time - start_time - self.logger.info(f'Round duration is {duration} seconds') + self.exp_data.append(FederatorRecord(len(selected_clients), 0, duration, test_loss, test_accuracy)) + self.logger.info(f'[Round {id:>3}] Round duration is {duration} seconds') diff --git a/fltk/util/analysis.py b/fltk/util/analysis.py new file mode 100644 index 00000000..b35ce4c8 --- /dev/null +++ b/fltk/util/analysis.py @@ -0,0 +1,70 @@ +from pathlib import Path +import argparse +from typing import List + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +import re + +# alt.renderers.enable('mimetype') + +def get_cwd() -> Path: + return Path.cwd() + + +def get_exp_name(path: Path) -> str: + return path.parent.name + + +def ensure_path_exists(path: Path): + path.mkdir(parents=True, exist_ok=True) + +def load_and_merge_dfs(files: List[Path]) -> pd.DataFrame: + dfs = [pd.read_csv(x) for x in files] + return pd.concat(dfs, ignore_index=True) + +def order_client_names(names: List[str]) -> List[str]: + return sorted(names, key=lambda x: float(re.findall(r'\d+', x)[0])) + +def plot_client_duration(df: pd.DataFrame): + small_df = df[['round_id', 'train_duration', 'test_duration', 'round_duration', 'node_name']].melt(id_vars=['round_id', 'node_name'], var_name='type') + ordered_clients = order_client_names(small_df['node_name'].unique()) + plt.figure() + g = sns.FacetGrid(small_df, col="type", sharey=False) + g.map(sns.boxplot, "node_name", "value", order=ordered_clients) + for axes in g.axes.flat: + _ = axes.set_xticklabels(axes.get_xticklabels(), rotation=90) + plt.tight_layout() + plt.show() + + plt.figure() + g = sns.FacetGrid(small_df, col="type", sharey=False, hue='node_name', hue_order=ordered_clients) + g.map(sns.lineplot, "round_id", "value") + for axes in g.axes.flat: + _ = axes.set_xticklabels(axes.get_xticklabels(), rotation=90) + plt.tight_layout() + plt.show() + + +def analyse(path: Path): + cwd = get_cwd() + output_path = cwd / get_exp_name(path) + ensure_path_exists(output_path) + all_files = [x for x in path.iterdir() if x.is_file()] + federator_files = [x for x in all_files if 'federator' in x.name] + client_files = [x for x in all_files if x.name.startswith('client')] + + federator_data = load_and_merge_dfs(federator_files) + client_data = load_and_merge_dfs(client_files) + + # print(len(client_data), len(federator_data)) + plot_client_duration(client_data) + # What do we want to plot in terms of data? + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Basic experiment analysis') + parser.add_argument('path', type=str, help='Path pointing to experiment results files') + args = parser.parse_args() + analyse(Path(args.path)) diff --git a/fltk/util/config.py b/fltk/util/config.py index 2b8ba98a..86266a22 100644 --- a/fltk/util/config.py +++ b/fltk/util/config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from pathlib import Path import torch @@ -43,6 +44,12 @@ class Config: rank: int = 0 world_size: int = 0 + # Save data in append mode. Thereby flushing on every append to file. + # This could be useful when a system is likely to crash midway an experiment + save_data_append: bool = False + + output_path: Path = Path('output_test_2') + def get_default_model_folder_path(self): return self.default_model_folder_path diff --git a/fltk/util/data_container.py b/fltk/util/data_container.py new file mode 100644 index 00000000..d0500820 --- /dev/null +++ b/fltk/util/data_container.py @@ -0,0 +1,86 @@ +import csv +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Union, List, Type +from typing.io import TextIO + + +@dataclass +class DataRecord: + pass + + +@dataclass +class FederatorRecord(DataRecord): + num_selected_clients: int + round_id: int + round_duration: int + test_loss: float + test_accuracy: float + # Accuracy per class? + timestamp: float = time.time() + node_name: str = '' + + +@dataclass +class ClientRecord(DataRecord): + round_id: int + train_duration: float + test_duration: float + round_duration: float + num_epochs: int + trained_items: int + accuracy: float + train_loss: float + test_loss: float + # Accuracy per class? + timestamp: float = time.time() + node_name: str = '' + + +class DataContainer: + records: List[DataRecord] + file_name: str + file_handle: TextIO + file_path: Path + append_mode: bool + record_type: DataRecord + delimiter = ',' + name: str + + def __init__(self, name: str, output_location: Path, record_type: DataRecord, append_mode: bool = False): + # print(f'Creating new Data container for client {name}') + self.records = [] + self.file_name = f'{name}.csv' + self.name = name + output_location.mkdir(parents=True, exist_ok=True) + self.file_path = output_location / self.file_name + self.append_mode = append_mode + file_flag = 'a' if append_mode else 'w' + self.file_handle = open(self.file_path, file_flag) + self.record_type = record_type + if self.append_mode: + open(self.file_path, 'w').close() + dw = csv.DictWriter(self.file_handle, self.record_type.__annotations__) + dw.writeheader() + self.file_handle.flush() + + def append(self, record: DataRecord): + record.node_name = self.name + self.records.append(record) + if self.append_mode: + dw = csv.DictWriter(self.file_handle, self.record_type.__annotations__) + dw.writerow(record.__dict__) + self.file_handle.flush() + + def save(self): + if self.append_mode: + return + dw = csv.DictWriter(self.file_handle, self.record_type.__annotations__) + dw.writeheader() + # print(f'Saving {len(self.records)} for node {self.name}') + for record in self.records: + record.node_name = self.name + dw.writerow(record.__dict__) + self.file_handle.flush()