diff --git a/fltk/core/client.py b/fltk/core/client.py
index 011c8660..f0aaef2f 100644
--- a/fltk/core/client.py
+++ b/fltk/core/client.py
@@ -1,4 +1,8 @@
 from __future__ import annotations
+
+import gc
+import multiprocessing
+import queue
 from typing import Tuple, Any
 
 import numpy as np
@@ -21,6 +25,8 @@ class Client(Node):
     Federated experiment client.
     """
     running = False
+    request_queue = queue.Queue()
+    result_queue = queue.Queue()
 
     def __init__(self, identifier: str, rank: int, world_size: int, config: FedLearnerConfig):
         super().__init__(identifier, rank, world_size, config)
@@ -44,7 +50,7 @@ def remote_registration(self):
         self.message('federator', 'ping', 'new_sender')
         self.message('federator', 'register_client', self.id, self.rank)
         self.running = True
-        self._event_loop()
+        # self._event_loop()
 
     def stop_client(self):
         """
@@ -57,9 +63,12 @@ def stop_client(self):
         self.running = False
 
     def _event_loop(self):
+        return
         self.logger.info('Starting event loop')
         while self.running:
-            time.sleep(0.1)
+            if not self.request_queue.empty():
+                self.result_queue.put(self.exec_round(*self.request_queue.get()))
+            time.sleep(5)
         self.logger.info('Exiting node')
 
     def train(self, num_epochs: int, round_id: int):
@@ -89,28 +98,27 @@ def train(self, num_epochs: int, round_id: int):
             training_cardinality = len(self.dataset.get_train_loader())
             self.logger.info(f'{progress}{self.id}: Number of training samples: {training_cardinality}')
 
             for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
-                inputs, labels = inputs.to(self.device), labels.to(self.device)
-
                 # zero the parameter gradients
                 self.optimizer.zero_grad()
 
                 outputs = self.net(inputs)
                 loss = self.loss_function(outputs, labels)
-
+                running_loss += loss.detach().item()
                 loss.backward()
                 self.optimizer.step()
-                running_loss += loss.item()
 
                 # Mark logging update step
                 if i % self.config.log_interval == 0:
                     self.logger.info(
                             f'[{self.id}] [{local_epoch}/{num_epochs:d}, {i:5d}] loss: {running_loss / self.config.log_interval:.3f}')
                     final_running_loss = running_loss / self.config.log_interval
                     running_loss = 0.0
+                del loss, inputs, labels
+
         end_time = time.time()
         duration = end_time - start_time
         self.logger.info(f'{progress} Train duration is {duration} seconds')
-        return final_running_loss, self.get_nn_parameters(),
+        return final_running_loss, self.get_nn_parameters()
 
     def set_tau_eff(self, total):
         client_weight = self.get_client_datasize() / total
@@ -141,12 +149,12 @@ def test(self) -> Tuple[float, float, np.array]:
                 _, predicted = torch.max(outputs.data, 1)  # pylint: disable=no-member
                 total += labels.size(0)
-                correct += (predicted == labels).sum().item()
+                correct += (predicted == labels).sum().detach().item()
 
                 targets_.extend(labels.cpu().view_as(predicted).numpy())
                 pred_.extend(predicted.cpu().numpy())
 
-                loss += self.loss_function(outputs, labels).item()
+                loss += self.loss_function(outputs, labels).detach().item()
 
         # Calculate learning statistics
         loss /= len(self.dataset.get_test_loader().dataset)
         accuracy = 100.0 * correct / total
@@ -156,11 +164,30 @@ def test(self) -> Tuple[float, float, np.array]:
         end_time = time.time()
         duration = end_time - start_time
         self.logger.info(f'Test duration is {duration} seconds')
+        del targets_, pred_
         return accuracy, loss, confusion_mat
 
     def get_client_datasize(self):  # pylint: disable=missing-function-docstring
         return len(self.dataset.get_train_sampler())
 
+    def run(self):
+        event = multiprocessing.Event()
+        while self.running:
+            if not self.request_queue.empty():
+                self.logger.info("Got request, running synchronously")
+                request = self.request_queue.get()
+                self.result_queue.put(self.exec_round(*request))
+            event.wait(1)
+
+    def request_round(self, num_epochs: int, round_id: int):
+        event = multiprocessing.Event()
+        self.request_queue.put([num_epochs, round_id])
+
+        while self.result_queue.empty():
+            event.wait(5)
+        self.logger.info("Finished request!")
+        return self.result_queue.get()
+
     def exec_round(self, num_epochs: int, round_id: int) -> Tuple[Any, Any, Any, Any, float, float, float, np.array]:
         """
         Function as access point for the Federator Node to kick off a remote learning round on a client.
@@ -186,6 +213,7 @@ def exec_round(self, num_epochs: int, round_id: int) -> Tuple[Any, Any, Any, Any
             self.optimizer.pre_communicate()
         for k, value in weights.items():
             weights[k] = value.cpu()
+        gc.collect()
         return loss, weights, accuracy, test_loss, round_duration, train_duration, test_duration, test_conf_matrix
 
     def __del__(self):
diff --git a/fltk/core/federator.py b/fltk/core/federator.py
index af29d968..839fff23 100644
--- a/fltk/core/federator.py
+++ b/fltk/core/federator.py
@@ -288,7 +288,7 @@ def training_cb(fut: torch.Future, client_ref: LocalClient, client_weights, clie
             client_ref.exp_data.append(c_record)
 
         for client in selected_clients:
-            future = self.message_async(client.ref, Client.exec_round, num_epochs, com_round_id)
+            future = self.message_async(client.ref, Client.request_round, num_epochs, com_round_id)
             cb_factory(future, training_cb, client, client_weights, client_sizes, num_epochs)
             self.logger.info(f'Request sent to client {client.name}')
             training_futures.append(future)
diff --git a/fltk/core/node.py b/fltk/core/node.py
index eea1ede9..fd844f4a 100644
--- a/fltk/core/node.py
+++ b/fltk/core/node.py
@@ -2,10 +2,10 @@
 import abc
 import copy
+import gc
 import os
 from typing import Callable, Any, Union
 
-import deprecate
 import torch
 from torch.distributed import rpc
 
 from fltk.datasets.federated import get_fed_dataset
@@ -56,6 +56,7 @@ def _config(self, config: FedLearnerConfig):
         self.device = self.init_device()
         self.distributed = config.distributed
         self.net = get_net(self.config.net_name)()
+        self.net.to(self.device)
 
     def init_dataloader(self, world_size: int = None):
         """
@@ -125,7 +126,7 @@ def set_net(self, net):
         :param net:
         """
         self.net = net
-        # self.net.to(self.device)
+        self.net.to(self.device)
 
     def get_nn_parameters(self):
         """
@@ -153,19 +154,17 @@ def load_model_from_file(self, model_file_path):
             self.logger.warning(f"Could not find model: {model_file_path}")
         return model
 
-    def update_nn_parameters(self, new_params, is_offloaded_model = False):
+    def update_nn_parameters(self, new_params):
         """
         Update the NN's parameters by parameters provided by Federator.
 
         :param new_params: New weights for the neural network
         :type new_params: dict
         """
-        if is_offloaded_model:
-            pass
-            # self.offloaded_net.load_state_dict(copy.deepcopy(new_params), strict=True)
-        else:
-            self.logger.info("Updating parameters")
-            self.net.load_state_dict(copy.deepcopy(new_params), strict=True)
+        self.logger.info("Updating parameters")
+        self.net.load_state_dict(copy.deepcopy(new_params), strict=True)
+        del new_params
+        gc.collect()
 
     def message(self, other_node: str, method: Union[Callable, str], *args, **kwargs) -> torch.Future:  # pylint: disable=no-member
         """