From ebebb26c5c3372c1d6311e7d59770f9d6a2d99f2 Mon Sep 17 00:00:00 2001 From: Bart Cox Date: Fri, 7 May 2021 16:07:35 +0200 Subject: [PATCH 1/6] Add timing callbacks --- fltk/client.py | 12 ++++--- fltk/federator.py | 84 ++++++++++++++++++--------------------------- fltk/util/remote.py | 54 +++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 54 deletions(-) create mode 100644 fltk/util/remote.py diff --git a/fltk/client.py b/fltk/client.py index de3095e3..99b17271 100644 --- a/fltk/client.py +++ b/fltk/client.py @@ -171,10 +171,14 @@ def update_nn_parameters(self, new_params): :param new_params: New weights for the neural network :type new_params: dict """ + start_time = time.time() self.net.load_state_dict(copy.deepcopy(new_params), strict=True) if self.log_rref: self.remote_log(f'Weights of the model are updated') + end_time = time.time() + return end_time - start_time + def train(self, epoch): """ :param epoch: Current epoch # @@ -248,10 +252,10 @@ def test(self): self.args.get_logger().debug('Test set: Accuracy: {}/{} ({:.0f}%)'.format(correct, total, accuracy)) self.args.get_logger().debug('Test set: Loss: {}'.format(loss)) - self.args.get_logger().debug("Classification Report:\n" + classification_report(targets_, pred_)) - self.args.get_logger().debug("Confusion Matrix:\n" + str(confusion_mat)) - self.args.get_logger().debug("Class precision: {}".format(str(class_precision))) - self.args.get_logger().debug("Class recall: {}".format(str(class_recall))) + # self.args.get_logger().debug("Classification Report:\n" + classification_report(targets_, pred_)) + # self.args.get_logger().debug("Confusion Matrix:\n" + str(confusion_mat)) + # self.args.get_logger().debug("Class precision: {}".format(str(class_precision))) + # self.args.get_logger().debug("Class recall: {}".format(str(class_recall))) return accuracy, loss, class_precision, class_recall diff --git a/fltk/federator.py b/fltk/federator.py index 88ccf31d..cbb7ac97 100644 --- a/fltk/federator.py +++ b/fltk/federator.py @@ -1,4 +1,3 @@ -import datetime import time from typing import List @@ -6,50 +5,18 @@ from torch.distributed import rpc from fltk.client import Client -from fltk.datasets.data_distribution import distribute_batches_equally from fltk.strategy.client_selection import random_selection -from fltk.util.arguments import Arguments -from fltk.util.base_config import BareConfig -from fltk.util.data_loader_utils import load_train_data_loader, load_test_data_loader, \ - generate_data_loaders_from_distributed_dataset from fltk.util.fed_avg import average_nn_parameters from fltk.util.log import FLLogger -from torchsummary import summary from torch.utils.tensorboard import SummaryWriter from pathlib import Path import logging +from fltk.util.remote import ClientRef, _remote_method, _remote_method_async, AsyncCall, time_remote_async_call from fltk.util.results import EpochData -from fltk.util.tensor_converter import convert_distributed_data_into_numpy logging.basicConfig(level=logging.DEBUG) -def _call_method(method, rref, *args, **kwargs): - return method(rref.local_value(), *args, **kwargs) - - -def _remote_method(method, rref, *args, **kwargs): - args = [method, rref] + list(args) - return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs) - -def _remote_method_async(method, rref, *args, **kwargs): - args = [method, rref] + list(args) - return rpc.rpc_async(rref.owner(), _call_method, args=args, kwargs=kwargs) - -class ClientRef: - ref = None - name = "" - data_size = 0 - tb_writer = None - - def 
__init__(self, name, ref, tensorboard_writer): - self.name = name - self.ref = ref - self.tb_writer = tensorboard_writer - - def __repr__(self): - return self.name - class Federator: """ Central component of the Federated Learning System: The Federator @@ -137,31 +104,42 @@ def clients_ready(self): logging.info('All clients are ready') def remote_run_epoch(self, epochs): - responses = [] + + responses: List[AsyncCall] = [] client_weights = [] selected_clients = self.select_clients(self.config.clients_per_round) for client in selected_clients: - responses.append((client, _remote_method_async(Client.run_epochs, client.ref, num_epoch=epochs))) + response = time_remote_async_call(client, Client.run_epochs, client.ref, num_epoch=epochs) + responses.append(response) + self.epoch_counter += epochs + durations = [] for res in responses: - epoch_data, weights = res[1].wait() + res.future.wait() + epoch_data, weights = res.future.wait() + fed_stop_time = time.time() self.client_data[epoch_data.client_id].append(epoch_data) - logging.info(f'{res[0]} had a loss of {epoch_data.loss}') - logging.info(f'{res[0]} had a epoch data of {epoch_data}') + logging.info(f'{res.client.name} had a loss of {epoch_data.loss}') + logging.info(f'{res.client.name} had a epoch data of {epoch_data}') + logging.info(f'[TIMING FUT]\t{res.client.name} had a epoch duration of {res.duration()}') + fed_duration = fed_stop_time - res.end_time + logging.info(f'[TIMING LOCAL]\t{res.client.name} had a epoch duration of {fed_duration}') + durations.append((res.client.name, res.duration(), fed_duration)) - res[0].tb_writer.add_scalar('training loss', + + res.client.tb_writer.add_scalar('training loss', epoch_data.loss_train, # for every 1000 minibatches - self.epoch_counter * res[0].data_size) + self.epoch_counter * res.client.data_size) - res[0].tb_writer.add_scalar('accuracy', + res.client.tb_writer.add_scalar('accuracy', epoch_data.accuracy, # for every 1000 minibatches - self.epoch_counter * res[0].data_size) + self.epoch_counter * res.client.data_size) - res[0].tb_writer.add_scalar('training loss per epoch', + res.client.tb_writer.add_scalar('training loss per epoch', epoch_data.loss_train, # for every 1000 minibatches self.epoch_counter) - res[0].tb_writer.add_scalar('accuracy per epoch', + res.client.tb_writer.add_scalar('accuracy per epoch', epoch_data.accuracy, # for every 1000 minibatches self.epoch_counter) @@ -172,20 +150,26 @@ def remote_run_epoch(self, epochs): logging.info("Testing on global test set") self.test_data.update_nn_parameters(updated_model) accuracy, loss, class_precision, class_recall = self.test_data.test() - # self.tb_writer.add_scalar('training loss', loss, self.epoch_counter * self.test_data.get_client_datasize()) # does not seem to work :( ) self.tb_writer.add_scalar('accuracy', accuracy, self.epoch_counter * self.test_data.get_client_datasize()) self.tb_writer.add_scalar('accuracy per epoch', accuracy, self.epoch_counter) - responses = [] for client in self.clients: - responses.append( - (client, _remote_method_async(Client.update_nn_parameters, client.ref, new_params=updated_model))) + response = time_remote_async_call(client, Client.update_nn_parameters, client.ref, new_params=updated_model) + responses.append(response) for res in responses: - res[1].wait() + func_duration = res.future.wait() + print(f'[Client:: {res.client.name}] internal weights copied in {func_duration}') + print(f'[Client:: {res.client.name}] model transfer time: {res.duration()}') logging.info('Weights are updated') + 
print('Duration timing') + for name, fut_time, fed_time in durations: + print(f'Client: {name} has these timings:') + print(f'FUT:\t{fut_time}') + print(f'Fed:\t{fed_time}') + def update_client_data_sizes(self): responses = [] for client in self.clients: diff --git a/fltk/util/remote.py b/fltk/util/remote.py new file mode 100644 index 00000000..9e5475c7 --- /dev/null +++ b/fltk/util/remote.py @@ -0,0 +1,54 @@ +import time + +from torch.distributed import rpc +from dataclasses import dataclass +from torch.futures import Future + +def _call_method(method, rref, *args, **kwargs): + return method(rref.local_value(), *args, **kwargs) + +def _remote_method(method, rref, *args, **kwargs): + args = [method, rref] + list(args) + return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs) + +def _remote_method_async(method, rref, *args, **kwargs): + args = [method, rref] + list(args) + return rpc.rpc_async(rref.owner(), _call_method, args=args, kwargs=kwargs) + +class ClientRef: + ref = None + name = "" + data_size = 0 + tb_writer = None + + def __init__(self, name, ref, tensorboard_writer): + self.name = name + self.ref = ref + self.tb_writer = tensorboard_writer + + def __repr__(self): + return self.name + +@dataclass +class AsyncCall: + future: Future + client: ClientRef + start_time: float = 0 + end_time: float = 0 + + def duration(self): + return self.end_time - self.start_time + + +def bind_timing_cb(response_obj: AsyncCall): + def callback(fut): + stop_time = time.time() + response_obj.end_time = stop_time + response_obj.future.then(callback) + +def time_remote_async_call(client, method, rref, *args, **kwargs): + start_time = time.time() + fut = _remote_method_async(method, rref, *args, **kwargs) + response = AsyncCall(fut, client, start_time=start_time) + bind_timing_cb(response) + return response \ No newline at end of file From 9754e0cd0d429bb344cbc6d8eccddd4b1a5b5371 Mon Sep 17 00:00:00 2001 From: Bart Cox Date: Sat, 8 May 2021 16:22:15 +0200 Subject: [PATCH 2/6] Add timing profiling --- fltk/client.py | 4 +- fltk/federator.py | 86 ++++++++++++++++++++++++---------------- fltk/util/base_config.py | 4 ++ fltk/util/remote.py | 16 +++++++- 4 files changed, 72 insertions(+), 38 deletions(-) diff --git a/fltk/client.py b/fltk/client.py index 99b17271..ffcb1534 100644 --- a/fltk/client.py +++ b/fltk/client.py @@ -214,6 +214,8 @@ def train(self, epoch): final_running_loss = running_loss / self.args.get_log_interval() running_loss = 0.0 + break + self.scheduler.step() # save model @@ -308,4 +310,4 @@ def get_client_datasize(self): return len(self.dataset.get_train_sampler()) def __del__(self): - print(f'Client {self.id} is stopping') + logging.info(f'Client {self.id} is stopping') diff --git a/fltk/federator.py b/fltk/federator.py index cbb7ac97..e073b720 100644 --- a/fltk/federator.py +++ b/fltk/federator.py @@ -12,7 +12,7 @@ from pathlib import Path import logging -from fltk.util.remote import ClientRef, _remote_method, _remote_method_async, AsyncCall, time_remote_async_call +from fltk.util.remote import ClientRef, AsyncCall, timed_remote_async_call, _remote_method, TimingRecord from fltk.util.results import EpochData logging.basicConfig(level=logging.DEBUG) @@ -70,17 +70,18 @@ def ping_all(self): answer = _remote_method(Client.ping, client.ref) t_end = time.time() duration = (t_end - t_start)*1000 + client.timing_data.append(TimingRecord(f'{client.name}', 'ping', duration)) logging.info(f'Ping to {client} is {duration:.3}ms') def rpc_test_all(self): for client in self.clients: - 
res = _remote_method_async(Client.rpc_test, client.ref) - while not res.done(): + res = timed_remote_async_call(client, Client.rpc_test, client.ref) + while not res.future.done(): pass def client_load_data(self): for client in self.clients: - _remote_method_async(Client.init_dataloader, client.ref) + timed_remote_async_call(client, Client.init_dataloader, client.ref) def clients_ready(self): all_ready = False @@ -89,15 +90,16 @@ def clients_ready(self): responses = [] for client in self.clients: if client.name not in ready_clients: - responses.append((client, _remote_method_async(Client.is_ready, client.ref))) + response = timed_remote_async_call(client, Client.is_ready, client.ref) + responses.append(response) all_ready = True for res in responses: - result = res[1].wait() + result = res.future.wait() if result: - logging.info(f'{res[0]} is ready') - ready_clients.append(res[0]) + logging.info(f'{res.client} is ready') + ready_clients.append(res.client) else: - logging.info(f'Waiting for {res[0]}') + logging.info(f'Waiting for {res.client}') all_ready = False time.sleep(2) @@ -109,7 +111,7 @@ def remote_run_epoch(self, epochs): client_weights = [] selected_clients = self.select_clients(self.config.clients_per_round) for client in selected_clients: - response = time_remote_async_call(client, Client.run_epochs, client.ref, num_epoch=epochs) + response = timed_remote_async_call(client, Client.run_epochs, client.ref, num_epoch=epochs) responses.append(response) self.epoch_counter += epochs @@ -121,11 +123,7 @@ def remote_run_epoch(self, epochs): self.client_data[epoch_data.client_id].append(epoch_data) logging.info(f'{res.client.name} had a loss of {epoch_data.loss}') logging.info(f'{res.client.name} had a epoch data of {epoch_data}') - logging.info(f'[TIMING FUT]\t{res.client.name} had a epoch duration of {res.duration()}') - fed_duration = fed_stop_time - res.end_time - logging.info(f'[TIMING LOCAL]\t{res.client.name} had a epoch duration of {fed_duration}') - durations.append((res.client.name, res.duration(), fed_duration)) - + res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'epoch_time_round_trip', res.duration())) res.client.tb_writer.add_scalar('training loss', epoch_data.loss_train, # for every 1000 minibatches @@ -155,40 +153,36 @@ def remote_run_epoch(self, epochs): responses = [] for client in self.clients: - response = time_remote_async_call(client, Client.update_nn_parameters, client.ref, new_params=updated_model) + response = timed_remote_async_call(client, Client.update_nn_parameters, client.ref, new_params=updated_model) responses.append(response) for res in responses: func_duration = res.future.wait() - print(f'[Client:: {res.client.name}] internal weights copied in {func_duration}') - print(f'[Client:: {res.client.name}] model transfer time: {res.duration()}') + res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_inner', func_duration)) + res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'update_param_round_trip', res.duration())) logging.info('Weights are updated') - print('Duration timing') - for name, fut_time, fed_time in durations: - print(f'Client: {name} has these timings:') - print(f'FUT:\t{fut_time}') - print(f'Fed:\t{fed_time}') - def update_client_data_sizes(self): responses = [] for client in self.clients: - responses.append((client, _remote_method_async(Client.get_client_datasize, client.ref))) + response = timed_remote_async_call(client, Client.get_client_datasize, client.ref) + responses.append(response) for 
res in responses: - res[0].data_size = res[1].wait() - logging.info(f'{res[0]} had a result of datasize={res[0].data_size}') + res.client.data_size = res.future.wait() + logging.info(f'{res.client.name} had a result of datasize={res.client.data_size}') def remote_test_sync(self): responses = [] for client in self.clients: - responses.append((client, _remote_method_async(Client.test, client.ref))) + response = timed_remote_async_call(client, Client.test, client.ref) + responses.append(response) for res in responses: - accuracy, loss, class_precision, class_recall = res[1].wait() - logging.info(f'{res[0]} had a result of accuracy={accuracy}') + accuracy, loss, class_precision, class_recall = res.future.wait() + logging.info(f'{res.client.name} had a result of accuracy={accuracy}') def save_epoch_data(self): - file_output = f'./{self.config.output_location}' + file_output = f'./{self.config.output_location}/{self.config.experiment_prefix}_data' self.ensure_path_exists(file_output) for key in self.client_data: filename = f'{file_output}/{key}_epochs.csv' @@ -197,6 +191,18 @@ def save_epoch_data(self): w = DataclassWriter(f, self.client_data[key], EpochData) w.write() + def save_profiling_data(self): + file_output = f'./{self.config.output_location}/{self.config.experiment_prefix}_data' + filename = f'{file_output}/profiling_data.csv' + self.ensure_path_exists(file_output) + with open(filename, "w") as f: + for client in self.clients: + for record in client.timing_data: + w = DataclassWriter(f, [record], TimingRecord) + w.write() + + + def ensure_path_exists(self, path): Path(path).mkdir(parents=True, exist_ok=True) @@ -216,13 +222,23 @@ def run(self): epoch_to_run = self.config.epochs epoch_size = self.config.epochs_per_cycle for epoch in range(epoch_to_run): - print(f'Running epoch {epoch}') + logging.info(f'Running epoch {epoch}') self.remote_run_epoch(epoch_size) addition += 1 - logging.info('Printing client data') - print(self.client_data) + logging.info('Available clients with data') + logging.info(self.client_data.keys()) - logging.info(f'Saving data') + logging.info('Saving data') self.save_epoch_data() + + logging.info('Printing all clients timing data') + for client in self.clients: + logging.info(f"Timing data for client {client}") + for record in client.timing_data: + logging.info(f'{record}') + + logging.info('Saving profiling data') + self.save_profiling_data() + logging.info(f'Federator is stopping') diff --git a/fltk/util/base_config.py b/fltk/util/base_config.py index c814965f..5f87840a 100644 --- a/fltk/util/base_config.py +++ b/fltk/util/base_config.py @@ -1,3 +1,5 @@ +from datetime import datetime + import torch import json @@ -109,6 +111,8 @@ def merge_yaml(self, cfg = {}): self.dataset_name = cfg['dataset'] if 'experiment_prefix' in cfg: self.experiment_prefix = cfg['experiment_prefix'] + else: + self.experiment_prefix = f'{datetime.now()}' if 'output_location' in cfg: self.output_location = cfg['output_location'] if 'tensor_board_active' in cfg: diff --git a/fltk/util/remote.py b/fltk/util/remote.py index 9e5475c7..0202f92f 100644 --- a/fltk/util/remote.py +++ b/fltk/util/remote.py @@ -1,7 +1,8 @@ import time +from typing import Any, List from torch.distributed import rpc -from dataclasses import dataclass +from dataclasses import dataclass, field from torch.futures import Future def _call_method(method, rref, *args, **kwargs): @@ -15,16 +16,27 @@ def _remote_method_async(method, rref, *args, **kwargs): args = [method, rref] + list(args) return rpc.rpc_async(rref.owner(), 
_call_method, args=args, kwargs=kwargs) +@dataclass +class TimingRecord: + client_id: str + metric: str + value: Any + epoch: int = None + timestamp: float = field(default_factory=time.time) + + class ClientRef: ref = None name = "" data_size = 0 tb_writer = None + timing_data: List[TimingRecord] = [] def __init__(self, name, ref, tensorboard_writer): self.name = name self.ref = ref self.tb_writer = tensorboard_writer + self.timing_data = [] def __repr__(self): return self.name @@ -46,7 +58,7 @@ def callback(fut): response_obj.end_time = stop_time response_obj.future.then(callback) -def time_remote_async_call(client, method, rref, *args, **kwargs): +def timed_remote_async_call(client, method, rref, *args, **kwargs): start_time = time.time() fut = _remote_method_async(method, rref, *args, **kwargs) response = AsyncCall(fut, client, start_time=start_time) From ff03acd142be630ab28f58f1951b2c7866a691d9 Mon Sep 17 00:00:00 2001 From: Bart Cox Date: Sat, 8 May 2021 16:24:09 +0200 Subject: [PATCH 3/6] remove debug log lines --- fltk/federator.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/fltk/federator.py b/fltk/federator.py index e073b720..de36c6bd 100644 --- a/fltk/federator.py +++ b/fltk/federator.py @@ -228,15 +228,9 @@ def run(self): logging.info('Available clients with data') logging.info(self.client_data.keys()) + # Save experiment data logging.info('Saving data') self.save_epoch_data() - - logging.info('Printing all clients timing data') - for client in self.clients: - logging.info(f"Timing data for client {client}") - for record in client.timing_data: - logging.info(f'{record}') - logging.info('Saving profiling data') self.save_profiling_data() From 4e75481abc6b4a6c43c85a7f2f8a7dda9358fbb5 Mon Sep 17 00:00:00 2001 From: Bart Cox Date: Sat, 8 May 2021 21:26:27 +0200 Subject: [PATCH 4/6] Fix profiling data export --- fltk/client.py | 20 +++---------- fltk/federator.py | 72 ++++++++++++++++++++++++++++------------------- 2 files changed, 47 insertions(+), 45 deletions(-) diff --git a/fltk/client.py b/fltk/client.py index ffcb1534..12ad0b56 100644 --- a/fltk/client.py +++ b/fltk/client.py @@ -93,9 +93,6 @@ def remote_log(self, message): def local_log(self, message): logging.info(f'[{self.id}: {time.time()}]: {message}') - def set_configuration(self, config: str): - yaml_config = yaml.safe_load(config) - def init(self): pass @@ -115,11 +112,6 @@ def set_net(self, net): self.net = net self.net.to(self.device) - def load_model_from_file(self, model_file_path): - model_class = self.args.get_net() - default_model_path = os.path.join(self.args.get_default_model_folder_path(), model_class.__name__ + ".model") - return self.load_model_from_file(default_model_path) - def get_nn_parameters(self): """ Return the NN's parameters. 
@@ -184,8 +176,6 @@ def train(self, epoch): :param epoch: Current epoch # :type epoch: int """ - # self.net.train() - # save model if self.args.should_save_model(epoch): self.save_model(epoch, self.args.get_epoch_save_start_suffix()) @@ -214,8 +204,6 @@ def train(self, epoch): final_running_loss = running_loss / self.args.get_log_interval() running_loss = 0.0 - break - self.scheduler.step() # save model @@ -254,10 +242,10 @@ def test(self): self.args.get_logger().debug('Test set: Accuracy: {}/{} ({:.0f}%)'.format(correct, total, accuracy)) self.args.get_logger().debug('Test set: Loss: {}'.format(loss)) - # self.args.get_logger().debug("Classification Report:\n" + classification_report(targets_, pred_)) - # self.args.get_logger().debug("Confusion Matrix:\n" + str(confusion_mat)) - # self.args.get_logger().debug("Class precision: {}".format(str(class_precision))) - # self.args.get_logger().debug("Class recall: {}".format(str(class_recall))) + self.args.get_logger().debug("Classification Report:\n" + classification_report(targets_, pred_)) + self.args.get_logger().debug("Confusion Matrix:\n" + str(confusion_mat)) + self.args.get_logger().debug("Class precision: {}".format(str(class_precision))) + self.args.get_logger().debug("Class recall: {}".format(str(class_recall))) return accuracy, loss, class_precision, class_recall diff --git a/fltk/federator.py b/fltk/federator.py index de36c6bd..47acf0af 100644 --- a/fltk/federator.py +++ b/fltk/federator.py @@ -1,6 +1,9 @@ +import copy +import os import time from typing import List +import torch from dataclass_csv import DataclassWriter from torch.distributed import rpc @@ -52,7 +55,6 @@ def __init__(self, client_id_triple, num_epochs = 3, config=None): self.test_data.init_dataloader() config.data_sampler = copy_sampler - def create_clients(self, client_id_triple): for id, rank, world_size in client_id_triple: client = rpc.remote(id, Client, kwargs=dict(id=id, log_rref=self.log_rref, rank=rank, world_size=world_size, config=self.config)) @@ -106,42 +108,68 @@ def clients_ready(self): logging.info('All clients are ready') def remote_run_epoch(self, epochs): + """ + Federated Learning steps: + 1. Client selection + 2. Selected clients download model + 3. Local training + 4. Model aggregation + Repeat + :param epochs: + :return: + """ + + # 1. Client selection + selected_clients = self.select_clients(self.config.clients_per_round) + + # 2. Selected clients download model + responses = [] + for client in selected_clients: + response = timed_remote_async_call(client, Client.update_nn_parameters, client.ref, + new_params=self.test_data.get_nn_parameters()) + responses.append(response) + + for res in responses: + func_duration = res.future.wait() + res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_inner', func_duration)) + res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'update_param_round_trip', res.duration())) + logging.info('Weights are updated') + # 3. 
Local training responses: List[AsyncCall] = [] client_weights = [] - selected_clients = self.select_clients(self.config.clients_per_round) for client in selected_clients: response = timed_remote_async_call(client, Client.run_epochs, client.ref, num_epoch=epochs) responses.append(response) self.epoch_counter += epochs - durations = [] for res in responses: res.future.wait() epoch_data, weights = res.future.wait() - fed_stop_time = time.time() self.client_data[epoch_data.client_id].append(epoch_data) logging.info(f'{res.client.name} had a loss of {epoch_data.loss}') logging.info(f'{res.client.name} had a epoch data of {epoch_data}') res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'epoch_time_round_trip', res.duration())) res.client.tb_writer.add_scalar('training loss', - epoch_data.loss_train, # for every 1000 minibatches - self.epoch_counter * res.client.data_size) + epoch_data.loss_train, # for every 1000 minibatches + self.epoch_counter * res.client.data_size) res.client.tb_writer.add_scalar('accuracy', - epoch_data.accuracy, # for every 1000 minibatches - self.epoch_counter * res.client.data_size) + epoch_data.accuracy, # for every 1000 minibatches + self.epoch_counter * res.client.data_size) res.client.tb_writer.add_scalar('training loss per epoch', - epoch_data.loss_train, # for every 1000 minibatches - self.epoch_counter) + epoch_data.loss_train, # for every 1000 minibatches + self.epoch_counter) res.client.tb_writer.add_scalar('accuracy per epoch', - epoch_data.accuracy, # for every 1000 minibatches - self.epoch_counter) + epoch_data.accuracy, # for every 1000 minibatches + self.epoch_counter) client_weights.append(weights) + + # 3. Model aggregation updated_model = average_nn_parameters(client_weights) # test global model @@ -151,17 +179,6 @@ def remote_run_epoch(self, epochs): self.tb_writer.add_scalar('accuracy', accuracy, self.epoch_counter * self.test_data.get_client_datasize()) self.tb_writer.add_scalar('accuracy per epoch', accuracy, self.epoch_counter) - responses = [] - for client in self.clients: - response = timed_remote_async_call(client, Client.update_nn_parameters, client.ref, new_params=updated_model) - responses.append(response) - - for res in responses: - func_duration = res.future.wait() - res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_inner', func_duration)) - res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'update_param_round_trip', res.duration())) - logging.info('Weights are updated') - def update_client_data_sizes(self): responses = [] for client in self.clients: @@ -195,13 +212,10 @@ def save_profiling_data(self): file_output = f'./{self.config.output_location}/{self.config.experiment_prefix}_data' filename = f'{file_output}/profiling_data.csv' self.ensure_path_exists(file_output) + records = [data for client in self.clients for data in client.timing_data] with open(filename, "w") as f: - for client in self.clients: - for record in client.timing_data: - w = DataclassWriter(f, [record], TimingRecord) - w.write() - - + w = DataclassWriter(f, records, TimingRecord) + w.write() def ensure_path_exists(self, path): Path(path).mkdir(parents=True, exist_ok=True) From 4e1416caaa32045a1c893a17adec29c10f3d7b9e Mon Sep 17 00:00:00 2001 From: Bart Cox Date: Sat, 8 May 2021 21:27:53 +0200 Subject: [PATCH 5/6] Version bump --- fltk/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fltk/__init__.py b/fltk/__init__.py index aa6e546b..d1eb1a0f 100644 --- a/fltk/__init__.py +++ 
b/fltk/__init__.py @@ -1,2 +1,2 @@ -__version__ = '0.3.1' \ No newline at end of file +__version__ = '0.3.2' \ No newline at end of file From e458dfc240d2fb07b03866950e14b900917e7f4d Mon Sep 17 00:00:00 2001 From: Bart Cox Date: Sat, 8 May 2021 22:08:16 +0200 Subject: [PATCH 6/6] Extend profiling data --- fltk/client.py | 15 +++++++-------- fltk/federator.py | 37 ++++++++++++++++++++++++------------- fltk/util/results.py | 5 +++-- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/fltk/client.py b/fltk/client.py index 12ad0b56..90acc8ed 100644 --- a/fltk/client.py +++ b/fltk/client.py @@ -250,25 +250,24 @@ def test(self): return accuracy, loss, class_precision, class_recall def run_epochs(self, num_epoch): - start_time_train = datetime.datetime.now() + start_time_train = time.time() self.dataset.get_train_sampler().set_epoch_size(num_epoch) loss, weights = self.train(self.epoch_counter) self.epoch_counter += num_epoch - elapsed_time_train = datetime.datetime.now() - start_time_train - train_time_ms = int(elapsed_time_train.total_seconds()*1000) + elapsed_train_time = time.time() - start_time_train - start_time_test = datetime.datetime.now() + start_time_test = time.time() accuracy, test_loss, class_precision, class_recall = self.test() - elapsed_time_test = datetime.datetime.now() - start_time_test - test_time_ms = int(elapsed_time_test.total_seconds()*1000) + elapsed_test_time = time.time() - start_time_test - data = EpochData(self.epoch_counter, train_time_ms, test_time_ms, loss, accuracy, test_loss, class_precision, class_recall, client_id=self.id) + data = EpochData(self.epoch_counter, num_epoch, elapsed_train_time, elapsed_test_time, loss, accuracy, test_loss, class_precision, class_recall, client_id=self.id) self.epoch_results.append(data) # Copy GPU tensors to CPU for k, v in weights.items(): weights[k] = v.cpu() - return data, weights + end_func_time = time.time() - start_time_train + return data, weights, end_func_time def save_model(self, epoch, suffix): """ diff --git a/fltk/federator.py b/fltk/federator.py index 47acf0af..fe41363f 100644 --- a/fltk/federator.py +++ b/fltk/federator.py @@ -107,7 +107,7 @@ def clients_ready(self): time.sleep(2) logging.info('All clients are ready') - def remote_run_epoch(self, epochs): + def remote_run_epoch(self, epochs_subset): """ Federated Learning steps: 1. Client selection @@ -131,25 +131,33 @@ def remote_run_epoch(self, epochs): for res in responses: func_duration = res.future.wait() - res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_inner', func_duration)) - res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'update_param_round_trip', res.duration())) + res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_inner', func_duration, epochs_subset[0])) + res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_round_trip', res.duration(), epochs_subset[0])) + communication_duration_2way = res.duration() - func_duration + res.client.timing_data.append( + TimingRecord(res.client.name, 'communication_2way', communication_duration_2way, epochs_subset[0])) logging.info('Weights are updated') # 3. 
Local training responses: List[AsyncCall] = [] client_weights = [] for client in selected_clients: - response = timed_remote_async_call(client, Client.run_epochs, client.ref, num_epoch=epochs) + response = timed_remote_async_call(client, Client.run_epochs, client.ref, num_epoch=len(epochs_subset)) responses.append(response) - self.epoch_counter += epochs + self.epoch_counter += len(epochs_subset) for res in responses: res.future.wait() - epoch_data, weights = res.future.wait() + epoch_data, weights, func_duration = res.future.wait() self.client_data[epoch_data.client_id].append(epoch_data) logging.info(f'{res.client.name} had a loss of {epoch_data.loss}') logging.info(f'{res.client.name} had a epoch data of {epoch_data}') - res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'epoch_time_round_trip', res.duration())) + res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_inner', func_duration, epochs_subset[0])) + res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_train', epoch_data.duration_train, epochs_subset[0])) + res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_test', epoch_data.duration_test, epochs_subset[0])) + res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_round_trip', res.duration(), epochs_subset[0])) + communication_duration_2way = res.duration() - func_duration + res.client.timing_data.append(TimingRecord(res.client.name, 'communication_2way', communication_duration_2way, epochs_subset[0])) res.client.tb_writer.add_scalar('training loss', epoch_data.loss_train, # for every 1000 minibatches @@ -231,14 +239,17 @@ def run(self): self.clients_ready() self.update_client_data_sizes() - epoch_to_run = self.num_epoch - addition = 0 + + + # Get total epoch to run epoch_to_run = self.config.epochs epoch_size = self.config.epochs_per_cycle - for epoch in range(epoch_to_run): - logging.info(f'Running epoch {epoch}') - self.remote_run_epoch(epoch_size) - addition += 1 + + epochs = list(range(1, epoch_to_run + 1)) + epoch_chunks = [epochs[x:x + epoch_size] for x in range(0, len(epochs), epoch_size)] + for epoch_subset in epoch_chunks: + logging.info(f'Running epochs {epoch_subset}') + self.remote_run_epoch(epoch_subset) logging.info('Available clients with data') logging.info(self.client_data.keys()) diff --git a/fltk/util/results.py b/fltk/util/results.py index af560479..cf762b8a 100644 --- a/fltk/util/results.py +++ b/fltk/util/results.py @@ -4,8 +4,9 @@ @dataclass class EpochData: epoch_id: int - duration_train: int - duration_test: int + num_epochs: int + duration_train: float + duration_test: float loss_train: float accuracy: float loss: float
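
The pattern this series builds on is the AsyncCall wrapper in fltk/util/remote.py: timed_remote_async_call() takes a start timestamp before rpc_async fires, and bind_timing_cb() attaches a .then() callback that stamps the end time the moment the future resolves, so the measured round trip does not depend on when the federator eventually calls wait(). Below is a minimal standalone sketch of that pattern, outside the patches themselves; it completes a torch.futures.Future locally via set_result() in place of a real rpc_async call, and the name 'client0' is a placeholder.

import time
from dataclasses import dataclass

from torch.futures import Future


@dataclass
class AsyncCall:
    future: Future
    client_name: str
    start_time: float = 0.0
    end_time: float = 0.0

    def duration(self) -> float:
        return self.end_time - self.start_time


def bind_timing_cb(call: AsyncCall) -> None:
    # Stamp the end time as soon as the future resolves, not when the
    # caller eventually wait()s on it.
    def callback(fut: Future) -> None:
        call.end_time = time.time()
    call.future.then(callback)


fut = Future()
call = AsyncCall(fut, 'client0', start_time=time.time())
bind_timing_cb(call)
time.sleep(0.1)     # stand-in for remote work
fut.set_result(42)  # stand-in for the RPC reply arriving
print(call.future.wait(), f'round trip: {call.duration():.3f}s')

Because update_nn_parameters() (since PATCH 1/6) and run_epochs() (since PATCH 6/6) also return their internal compute time, the federator can subtract that from the measured round trip to produce the communication_2way records.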
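
PATCH 6/6 also swaps the plain for-epoch-in-range loop for explicit chunks of epoch numbers, so every TimingRecord can be tagged with the first epoch of its round (epochs_subset[0]). The slicing is easy to sanity-check in isolation; the 10/3 values below are illustrative, not taken from any config:

epoch_to_run, epoch_size = 10, 3  # illustrative values
epochs = list(range(1, epoch_to_run + 1))
epoch_chunks = [epochs[x:x + epoch_size] for x in range(0, len(epochs), epoch_size)]
print(epoch_chunks)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]

Note that the final chunk may be shorter than epochs_per_cycle; remote_run_epoch tolerates this because it passes num_epoch=len(epochs_subset) rather than a fixed cycle length.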
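
Finally, the save_profiling_data() rework in PATCH 4/6 (one DataclassWriter over a flattened record list, rather than one writer per record) likely matters because dataclass-csv's write() emits a header row by default, so the per-record writers from PATCH 2/6 would interleave headers with data rows. A self-contained sketch of the corrected shape, using hypothetical records:

import time
from dataclasses import dataclass, field
from typing import Any, Optional

from dataclass_csv import DataclassWriter  # third-party package: dataclass-csv


@dataclass
class TimingRecord:
    client_id: str
    metric: str
    value: Any
    epoch: Optional[int] = None
    timestamp: float = field(default_factory=time.time)


records = [  # hypothetical data
    TimingRecord('client0', 'ping', 1.2, epoch=1),
    TimingRecord('client1', 'ping', 0.9, epoch=1),
]
with open('profiling_data.csv', 'w') as f:
    # One writer over the flattened list -> exactly one header row.
    DataclassWriter(f, records, TimingRecord).write()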