Skip to content

Commit

Permalink
Update analysis script
Browse files Browse the repository at this point in the history
  • Loading branch information
bacox committed Mar 17, 2022
1 parent d2b7d80 commit 2a22902
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 11 deletions.
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
These examples are outdated!
1 change: 1 addition & 0 deletions experiments/example_docker/descr.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
---
# Experiment configuration
total_epochs: 3
rounds: 5
epochs_per_cycle: 1
wait_for_clients: true
net: MNISTCNN
Expand Down
5 changes: 3 additions & 2 deletions experiments/example_native/descr.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
---
# Experiment configuration
total_epochs: 3
rounds: 5
epochs_per_cycle: 1
wait_for_clients: true
net: MNISTCNN
Expand All @@ -20,5 +21,5 @@ sampler: "uniform" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (
sampler_args:
- 0.07 # label limit || q probability || alpha || unused
- 42 # random seed || random seed || random seed || unused
num_clients: 2
replications: 2
num_clients: 10
replications: 5
2 changes: 1 addition & 1 deletion fltk/core/federator.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,6 @@ def all_futures_done(futures: List[torch.Future])->bool:

end_time = time.time()
duration = end_time - start_time
self.exp_data.append(FederatorRecord(len(selected_clients), 0, duration, test_loss, test_accuracy))
self.exp_data.append(FederatorRecord(len(selected_clients), id, duration, test_loss, test_accuracy))
self.logger.info(f'[Round {id:>3}] Round duration is {duration} seconds')

60 changes: 52 additions & 8 deletions fltk/util/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import seaborn as sns
import re

# alt.renderers.enable('mimetype')
from matplotlib.lines import Line2D


def get_cwd() -> Path:
return Path.cwd()
Expand Down Expand Up @@ -46,21 +47,64 @@ def plot_client_duration(df: pd.DataFrame):
plt.tight_layout()
plt.show()

def plot_federator_accuracy(df: pd.DataFrame):
plt.figure()
g = sns.lineplot(data=df, x='round_id', y='test_accuracy')
# df.plot(x="date", y="column2", ax=ax2, legend=False, color="r")
sns.lineplot(ax=g.axes.twinx(), data=df, x='round_id', y='test_loss', color='r')
plt.title('Federator test accuracy')
g.legend(handles=[Line2D([], [], marker='_', color="r", label='test_loss'),
Line2D([], [], marker='_', color="b", label='test_accuracy')])
plt.tight_layout()
plt.show()

def analyse(path: Path):
cwd = get_cwd()
output_path = cwd / get_exp_name(path)
ensure_path_exists(output_path)
def plot_clients_accuracy(df: pd.DataFrame):
plt.figure()
g = sns.lineplot(data=df, x='round_id', y='accuracy', hue='node_name')
plt.title('Client test accuracy')
plt.tight_layout()
plt.show()


def load_replication(path: Path, replication_id: int):
all_files = [x for x in path.iterdir() if x.is_file()]
federator_files = [x for x in all_files if 'federator' in x.name]
client_files = [x for x in all_files if x.name.startswith('client')]

federator_data = load_and_merge_dfs(federator_files)
federator_data['replication'] = replication_id
client_data = load_and_merge_dfs(client_files)
client_data['replication'] = replication_id
return federator_data, client_data

# print(len(client_data), len(federator_data))
plot_client_duration(client_data)
# What do we want to plot in terms of data?
def analyse(path: Path):
# cwd = get_cwd()
# output_path = cwd / get_exp_name(path)
# ensure_path_exists(output_path)
replications = [x for x in path.iterdir() if x.is_dir()]
print(replications)
client_dfs = []
federator_dfs = []
for replication_path in replications:
replication_id = int(replication_path.name.split('_')[-1][1:])
federator_data, client_data = load_replication(replication_path, replication_id)
client_dfs.append(client_data)
federator_dfs.append(federator_data)

federator_df = pd.concat(federator_dfs, ignore_index=True)
client_df = pd.concat(client_dfs, ignore_index=True)
# all_files = [x for x in path.iterdir() if x.is_file()]
# federator_files = [x for x in all_files if 'federator' in x.name]
# client_files = [x for x in all_files if x.name.startswith('client')]
#
# federator_data = load_and_merge_dfs(federator_files)
# client_data = load_and_merge_dfs(client_files)
#
# # print(len(client_data), len(federator_data))
plot_client_duration(client_df)
plot_federator_accuracy(federator_df)
plot_clients_accuracy(client_df)
# # What do we want to plot in terms of data?


if __name__ == '__main__':
Expand Down

0 comments on commit 2a22902

Please sign in to comment.