From 2a22902574c5a275e15bd2fa2969703103106c9c Mon Sep 17 00:00:00 2001 From: bacox Date: Thu, 17 Mar 2022 10:52:08 +0100 Subject: [PATCH] Update analysis script --- examples/README.md | 1 + experiments/example_docker/descr.yaml | 1 + experiments/example_native/descr.yaml | 5 ++- fltk/core/federator.py | 2 +- fltk/util/analysis.py | 60 +++++++++++++++++++++++---- 5 files changed, 58 insertions(+), 11 deletions(-) create mode 100644 examples/README.md diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..d3f2afac --- /dev/null +++ b/examples/README.md @@ -0,0 +1 @@ +These examples are outdated! \ No newline at end of file diff --git a/experiments/example_docker/descr.yaml b/experiments/example_docker/descr.yaml index 79695fcf..b1a7aaa2 100644 --- a/experiments/example_docker/descr.yaml +++ b/experiments/example_docker/descr.yaml @@ -1,6 +1,7 @@ --- # Experiment configuration total_epochs: 3 +rounds: 5 epochs_per_cycle: 1 wait_for_clients: true net: MNISTCNN diff --git a/experiments/example_native/descr.yaml b/experiments/example_native/descr.yaml index d5e5386f..c254640b 100644 --- a/experiments/example_native/descr.yaml +++ b/experiments/example_native/descr.yaml @@ -1,6 +1,7 @@ --- # Experiment configuration total_epochs: 3 +rounds: 5 epochs_per_cycle: 1 wait_for_clients: true net: MNISTCNN @@ -20,5 +21,5 @@ sampler: "uniform" # "limit labels" || "q sampler" || "dirichlet" || "uniform" ( sampler_args: - 0.07 # label limit || q probability || alpha || unused - 42 # random seed || random seed || random seed || unused -num_clients: 2 -replications: 2 +num_clients: 10 +replications: 5 diff --git a/fltk/core/federator.py b/fltk/core/federator.py index 4975e0bc..99fde847 100644 --- a/fltk/core/federator.py +++ b/fltk/core/federator.py @@ -245,6 +245,6 @@ def all_futures_done(futures: List[torch.Future])->bool: end_time = time.time() duration = end_time - start_time - self.exp_data.append(FederatorRecord(len(selected_clients), 0, duration, test_loss, test_accuracy)) + self.exp_data.append(FederatorRecord(len(selected_clients), id, duration, test_loss, test_accuracy)) self.logger.info(f'[Round {id:>3}] Round duration is {duration} seconds') diff --git a/fltk/util/analysis.py b/fltk/util/analysis.py index b35ce4c8..1fca083f 100644 --- a/fltk/util/analysis.py +++ b/fltk/util/analysis.py @@ -7,7 +7,8 @@ import seaborn as sns import re -# alt.renderers.enable('mimetype') +from matplotlib.lines import Line2D + def get_cwd() -> Path: return Path.cwd() @@ -46,21 +47,64 @@ def plot_client_duration(df: pd.DataFrame): plt.tight_layout() plt.show() +def plot_federator_accuracy(df: pd.DataFrame): + plt.figure() + g = sns.lineplot(data=df, x='round_id', y='test_accuracy') + # df.plot(x="date", y="column2", ax=ax2, legend=False, color="r") + sns.lineplot(ax=g.axes.twinx(), data=df, x='round_id', y='test_loss', color='r') + plt.title('Federator test accuracy') + g.legend(handles=[Line2D([], [], marker='_', color="r", label='test_loss'), + Line2D([], [], marker='_', color="b", label='test_accuracy')]) + plt.tight_layout() + plt.show() -def analyse(path: Path): - cwd = get_cwd() - output_path = cwd / get_exp_name(path) - ensure_path_exists(output_path) +def plot_clients_accuracy(df: pd.DataFrame): + plt.figure() + g = sns.lineplot(data=df, x='round_id', y='accuracy', hue='node_name') + plt.title('Client test accuracy') + plt.tight_layout() + plt.show() + + +def load_replication(path: Path, replication_id: int): all_files = [x for x in path.iterdir() if x.is_file()] federator_files = [x for x in all_files if 'federator' in x.name] client_files = [x for x in all_files if x.name.startswith('client')] federator_data = load_and_merge_dfs(federator_files) + federator_data['replication'] = replication_id client_data = load_and_merge_dfs(client_files) + client_data['replication'] = replication_id + return federator_data, client_data - # print(len(client_data), len(federator_data)) - plot_client_duration(client_data) - # What do we want to plot in terms of data? +def analyse(path: Path): + # cwd = get_cwd() + # output_path = cwd / get_exp_name(path) + # ensure_path_exists(output_path) + replications = [x for x in path.iterdir() if x.is_dir()] + print(replications) + client_dfs = [] + federator_dfs = [] + for replication_path in replications: + replication_id = int(replication_path.name.split('_')[-1][1:]) + federator_data, client_data = load_replication(replication_path, replication_id) + client_dfs.append(client_data) + federator_dfs.append(federator_data) + + federator_df = pd.concat(federator_dfs, ignore_index=True) + client_df = pd.concat(client_dfs, ignore_index=True) + # all_files = [x for x in path.iterdir() if x.is_file()] + # federator_files = [x for x in all_files if 'federator' in x.name] + # client_files = [x for x in all_files if x.name.startswith('client')] + # + # federator_data = load_and_merge_dfs(federator_files) + # client_data = load_and_merge_dfs(client_files) + # + # # print(len(client_data), len(federator_data)) + plot_client_duration(client_df) + plot_federator_accuracy(federator_df) + plot_clients_accuracy(client_df) + # # What do we want to plot in terms of data? if __name__ == '__main__':