diff --git a/fltk/client.py b/fltk/client.py
index 12ad0b56..90acc8ed 100644
--- a/fltk/client.py
+++ b/fltk/client.py
@@ -250,25 +250,24 @@ def test(self):
         return accuracy, loss, class_precision, class_recall

     def run_epochs(self, num_epoch):
-        start_time_train = datetime.datetime.now()
+        start_time_train = time.time()
         self.dataset.get_train_sampler().set_epoch_size(num_epoch)
         loss, weights = self.train(self.epoch_counter)
         self.epoch_counter += num_epoch
-        elapsed_time_train = datetime.datetime.now() - start_time_train
-        train_time_ms = int(elapsed_time_train.total_seconds()*1000)
+        elapsed_train_time = time.time() - start_time_train

-        start_time_test = datetime.datetime.now()
+        start_time_test = time.time()
         accuracy, test_loss, class_precision, class_recall = self.test()
-        elapsed_time_test = datetime.datetime.now() - start_time_test
-        test_time_ms = int(elapsed_time_test.total_seconds()*1000)
+        elapsed_test_time = time.time() - start_time_test

-        data = EpochData(self.epoch_counter, train_time_ms, test_time_ms, loss, accuracy, test_loss, class_precision, class_recall, client_id=self.id)
+        data = EpochData(self.epoch_counter, num_epoch, elapsed_train_time, elapsed_test_time, loss, accuracy, test_loss, class_precision, class_recall, client_id=self.id)
         self.epoch_results.append(data)

         # Copy GPU tensors to CPU
         for k, v in weights.items():
             weights[k] = v.cpu()
-        return data, weights
+        end_func_time = time.time() - start_time_train
+        return data, weights, end_func_time

     def save_model(self, epoch, suffix):
         """
diff --git a/fltk/federator.py b/fltk/federator.py
index 47acf0af..fe41363f 100644
--- a/fltk/federator.py
+++ b/fltk/federator.py
@@ -107,7 +107,7 @@ def clients_ready(self):
             time.sleep(2)
         logging.info('All clients are ready')

-    def remote_run_epoch(self, epochs):
+    def remote_run_epoch(self, epochs_subset):
        """
        Federated Learning steps:
        1. Client selection
@@ -131,25 +131,33 @@ def remote_run_epoch(self, epochs):

         for res in responses:
             func_duration = res.future.wait()
-            res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_inner', func_duration))
-            res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'update_param_round_trip', res.duration()))
+            res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_inner', func_duration, epochs_subset[0]))
+            res.client.timing_data.append(TimingRecord(res.client.name, 'update_param_round_trip', res.duration(), epochs_subset[0]))
+            communication_duration_2way = res.duration() - func_duration
+            res.client.timing_data.append(
+                TimingRecord(res.client.name, 'communication_2way', communication_duration_2way, epochs_subset[0]))
         logging.info('Weights are updated')

         # 3. Local training
         responses: List[AsyncCall] = []
         client_weights = []
         for client in selected_clients:
-            response = timed_remote_async_call(client, Client.run_epochs, client.ref, num_epoch=epochs)
+            response = timed_remote_async_call(client, Client.run_epochs, client.ref, num_epoch=len(epochs_subset))
             responses.append(response)

-        self.epoch_counter += epochs
+        self.epoch_counter += len(epochs_subset)
         for res in responses:
             res.future.wait()
-            epoch_data, weights = res.future.wait()
+            epoch_data, weights, func_duration = res.future.wait()
             self.client_data[epoch_data.client_id].append(epoch_data)
             logging.info(f'{res.client.name} had a loss of {epoch_data.loss}')
             logging.info(f'{res.client.name} had a epoch data of {epoch_data}')
-            res.client.timing_data.append(TimingRecord(f'{res.client.name}', 'epoch_time_round_trip', res.duration()))
+            res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_inner', func_duration, epochs_subset[0]))
+            res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_train', epoch_data.duration_train, epochs_subset[0]))
+            res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_test', epoch_data.duration_test, epochs_subset[0]))
+            res.client.timing_data.append(TimingRecord(res.client.name, 'epoch_time_round_trip', res.duration(), epochs_subset[0]))
+            communication_duration_2way = res.duration() - func_duration
+            res.client.timing_data.append(TimingRecord(res.client.name, 'communication_2way', communication_duration_2way, epochs_subset[0]))

             res.client.tb_writer.add_scalar('training loss',
                                             epoch_data.loss_train,  # for every 1000 minibatches
@@ -231,14 +239,17 @@ def run(self):
         self.clients_ready()
         self.update_client_data_sizes()

-        epoch_to_run = self.num_epoch
-        addition = 0
+
+
+        # Get total epoch to run
         epoch_to_run = self.config.epochs
         epoch_size = self.config.epochs_per_cycle
-        for epoch in range(epoch_to_run):
-            logging.info(f'Running epoch {epoch}')
-            self.remote_run_epoch(epoch_size)
-            addition += 1
+
+        epochs = list(range(1, epoch_to_run + 1))
+        epoch_chunks = [epochs[x:x + epoch_size] for x in range(0, len(epochs), epoch_size)]
+        for epoch_subset in epoch_chunks:
+            logging.info(f'Running epochs {epoch_subset}')
+            self.remote_run_epoch(epoch_subset)

         logging.info('Available clients with data')
         logging.info(self.client_data.keys())
diff --git a/fltk/util/results.py b/fltk/util/results.py
index af560479..cf762b8a 100644
--- a/fltk/util/results.py
+++ b/fltk/util/results.py
@@ -4,8 +4,9 @@
 @dataclass
 class EpochData:
     epoch_id: int
-    duration_train: int
-    duration_test: int
+    num_epochs: int
+    duration_train: float
+    duration_test: float
     loss_train: float
     accuracy: float
     loss: float
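
For reference, a minimal standalone sketch of the two behavioral changes in this patch: the epoch chunking added to Federator.run() and the two-way communication estimate taken around the timed RPCs. This is not code from the patch; the durations are hypothetical, and epoch_to_run and epoch_size stand in for self.config.epochs and self.config.epochs_per_cycle.

    # Chunk 1-based epoch numbers into cycles of `epoch_size`, mirroring the
    # list comprehension added to run(). Example: 10 epochs, 4 per cycle.
    epoch_to_run = 10
    epoch_size = 4
    epochs = list(range(1, epoch_to_run + 1))
    epoch_chunks = [epochs[x:x + epoch_size] for x in range(0, len(epochs), epoch_size)]
    print(epoch_chunks)  # [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]

    # Each chunk is passed whole to remote_run_epoch(); clients only consume
    # len(epochs_subset), while epochs_subset[0] tags every TimingRecord with
    # the first epoch of its cycle so timings can later be grouped per cycle.

    # The communication estimate works the same way in both response loops
    # (hypothetical durations, in seconds): the round trip is timed at the
    # federator, the inner duration inside the remote function, and their
    # difference approximates serialization plus network transfer in both
    # directions.
    round_trip = 12.8     # what res.duration() would report
    func_duration = 11.5  # what the remote call returns
    communication_2way = round_trip - func_duration
    print(f'communication_2way: {communication_2way:.1f}s')  # 1.3s

Because both endpoints measure only local time.time() deltas, the subtraction needs no clock synchronization between federator and clients, which fits the patch's move from datetime arithmetic to float seconds in EpochData.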