Improve debugging for Federated experiments
JMGaljaard committed Sep 18, 2022
1 parent 0b5ec3a commit 4830779
Showing 3 changed files with 46 additions and 19 deletions.
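
Note: the central change is to stop running training rounds directly on the RPC thread. The client now exposes request_round, which enqueues the work and waits, while the client's own run loop pops the request and executes exec_round synchronously. A minimal, self-contained sketch of that handoff pattern (illustrative names only, not the repository's code):

import queue
import threading

request_queue: queue.Queue = queue.Queue()
result_queue: queue.Queue = queue.Queue()
running = threading.Event()

def run() -> None:
    # Client main loop: execute queued requests one at a time, synchronously.
    while running.is_set():
        try:
            num_epochs, round_id = request_queue.get(timeout=1)
        except queue.Empty:
            continue
        result_queue.put(f"round {round_id}: trained {num_epochs} epoch(s)")

def request_round(num_epochs: int, round_id: int) -> str:
    # Called from the RPC thread: enqueue the round, block until run() replies.
    request_queue.put((num_epochs, round_id))
    return result_queue.get()

running.set()
worker = threading.Thread(target=run, daemon=True)
worker.start()
print(request_round(num_epochs=1, round_id=0))
running.clear()

The commit itself polls with multiprocessing.Event().wait(...) and empty() checks instead of blocking get() calls, but the producer/consumer structure is the same.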
46 changes: 37 additions & 9 deletions fltk/core/client.py
@@ -1,4 +1,8 @@
from __future__ import annotations

import gc
import multiprocessing
import queue
from typing import Tuple, Any

import numpy as np
@@ -21,6 +25,8 @@ class Client(Node):
Federated experiment client.
"""
running = False
request_queue = queue.Queue()
result_queue = queue.Queue()

def __init__(self, identifier: str, rank: int, world_size: int, config: FedLearnerConfig):
super().__init__(identifier, rank, world_size, config)
@@ -44,7 +50,7 @@ def remote_registration(self):
self.message('federator', 'ping', 'new_sender')
self.message('federator', 'register_client', self.id, self.rank)
self.running = True
self._event_loop()
# self._event_loop()

def stop_client(self):
"""
@@ -57,9 +63,12 @@ def stop_client(self):
self.running = False

def _event_loop(self):
return
self.logger.info('Starting event loop')
while self.running:
time.sleep(0.1)
if not self.request_queue.empty():
self.result_queue.put(self.exec_round(*self.request_queue.get()))
time.sleep(5)
self.logger.info('Exiting node')

def train(self, num_epochs: int, round_id: int):
@@ -89,28 +98,27 @@ def train(self, num_epochs: int, round_id: int):
training_cardinality = len(self.dataset.get_train_loader())
self.logger.info(f'{progress}{self.id}: Number of training samples: {training_cardinality}')
for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
inputs, labels = inputs.to(self.device), labels.to(self.device)

# zero the parameter gradients
self.optimizer.zero_grad()

outputs = self.net(inputs)
loss = self.loss_function(outputs, labels)

running_loss += loss.detach().item()
loss.backward()
self.optimizer.step()
running_loss += loss.item()
# Mark logging update step
if i % self.config.log_interval == 0:
self.logger.info(
f'[{self.id}] [{local_epoch}/{num_epochs:d}, {i:5d}] loss: {running_loss / self.config.log_interval:.3f}')
final_running_loss = running_loss / self.config.log_interval
running_loss = 0.0
del loss, inputs, labels

end_time = time.time()
duration = end_time - start_time
self.logger.info(f'{progress} Train duration is {duration} seconds')

return final_running_loss, self.get_nn_parameters(),
return final_running_loss, self.get_nn_parameters()

def set_tau_eff(self, total):
client_weight = self.get_client_datasize() / total
@@ -141,12 +149,12 @@ def test(self) -> Tuple[float, float, np.array]:

_, predicted = torch.max(outputs.data, 1) # pylint: disable=no-member
total += labels.size(0)
correct += (predicted == labels).sum().item()
correct += (predicted == labels).sum().detach().item()

targets_.extend(labels.cpu().view_as(predicted).numpy())
pred_.extend(predicted.cpu().numpy())

loss += self.loss_function(outputs, labels).item()
loss += self.loss_function(outputs, labels).detach().item()
# Calculate learning statistics
loss /= len(self.dataset.get_test_loader().dataset)
accuracy = 100.0 * correct / total
@@ -156,11 +164,30 @@ def test(self) -> Tuple[float, float, np.array]:
end_time = time.time()
duration = end_time - start_time
self.logger.info(f'Test duration is {duration} seconds')
del targets_, pred_
return accuracy, loss, confusion_mat

def get_client_datasize(self): # pylint: disable=missing-function-docstring
return len(self.dataset.get_train_sampler())

def run(self):
event = multiprocessing.Event()
while self.running:
if not self.request_queue.empty():
self.logger.info("Got request, running synchronously")
request = self.request_queue.get()
self.result_queue.put(self.exec_round(*request))
event.wait(1)

def request_round(self, num_epochs: int, round_id: int):
event = multiprocessing.Event()
self.request_queue.put([num_epochs, round_id])

while self.result_queue.empty():
event.wait(5)
self.logger.info("Finished request!")
return self.result_queue.get()

def exec_round(self, num_epochs: int, round_id: int) -> Tuple[Any, Any, Any, Any, float, float, float, np.array]:
"""
Function as access point for the Federator Node to kick off a remote learning round on a client.
@@ -186,6 +213,7 @@ def exec_round(self, num_epochs: int, round_id: int) -> Tuple[Any, Any, Any, Any, float, float, float, np.array]:
self.optimizer.pre_communicate()
for k, value in weights.items():
weights[k] = value.cpu()
gc.collect()
return loss, weights, accuracy, test_loss, round_duration, train_duration, test_duration, test_conf_matrix

def __del__(self):
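
Note: the remaining client.py changes are memory hygiene: scalars are taken via detach().item(), loop-local tensors are deleted, and gc.collect() runs after a round so stale references do not pin (GPU) memory across rounds. A sketch of the pattern with a generic model and loader (assumed names, not fltk's API):

import gc
import torch

def train_one_epoch(net: torch.nn.Module, loader, optimizer, loss_fn,
                    device: torch.device) -> float:
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(net(inputs), labels)
        # detach() before item(): accumulate a plain Python float, never a
        # tensor reference that could keep the autograd graph alive.
        running_loss += loss.detach().item()
        loss.backward()
        optimizer.step()
        # Drop loop locals eagerly, as the diff's `del loss, inputs, labels` does.
        del loss, inputs, labels
    gc.collect()  # encourage release of freed objects between rounds
    return running_loss / max(len(loader), 1)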
2 changes: 1 addition & 1 deletion fltk/core/federator.py
@@ -288,7 +288,7 @@ def training_cb(fut: torch.Future, client_ref: LocalClient, client_weights, clie
client_ref.exp_data.append(c_record)

for client in selected_clients:
future = self.message_async(client.ref, Client.exec_round, num_epochs, com_round_id)
future = self.message_async(client.ref, Client.request_round, num_epochs, com_round_id)
cb_factory(future, training_cb, client, client_weights, client_sizes, num_epochs)
self.logger.info(f'Request sent to client {client.name}')
training_futures.append(future)
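
Note: the federator now dispatches Client.request_round instead of Client.exec_round through message_async, which in this codebase appears to wrap torch.distributed.rpc and return a torch.futures.Future that a callback is chained onto. A minimal sketch of attaching a completion callback to such a future (the dispatch itself is stubbed out here):

import torch.futures

def dispatch_round(num_epochs: int, round_id: int) -> torch.futures.Future:
    # Stand-in for message_async(client.ref, Client.request_round, ...):
    # a real deployment would return the future of an rpc_async call.
    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result((0.42, {"example.weight": None}))  # placeholder (loss, weights)
    return fut

def training_cb(fut: torch.futures.Future) -> None:
    loss, weights = fut.wait()  # future is already complete inside .then()
    print(f"round finished, loss={loss}, {len(weights)} weight tensor(s)")

future = dispatch_round(num_epochs=1, round_id=0)
future.then(training_cb)  # fires once the client's result arrives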
17 changes: 8 additions & 9 deletions fltk/core/node.py
@@ -2,10 +2,10 @@

import abc
import copy
import gc
import os
from typing import Callable, Any, Union

import deprecate
import torch
from torch.distributed import rpc
from fltk.datasets.federated import get_fed_dataset
@@ -56,6 +57,7 @@ def _config(self, config: FedLearnerConfig):
self.device = self.init_device()
self.distributed = config.distributed
self.net = get_net(self.config.net_name)()
self.net.to(self.device)

def init_dataloader(self, world_size: int = None):
"""
@@ -125,7 +126,7 @@ def set_net(self, net):
:param net:
"""
self.net = net
# self.net.to(self.device)
self.net.to(self.device)

def get_nn_parameters(self):
"""
@@ -153,19 +154,17 @@ def load_model_from_file(self, model_file_path):
self.logger.warning(f"Could not find model: {model_file_path}")
return model

def update_nn_parameters(self, new_params, is_offloaded_model = False):
def update_nn_parameters(self, new_params):
"""
Update the NN's parameters by parameters provided by Federator.
:param new_params: New weights for the neural network
:type new_params: dict
"""
if is_offloaded_model:
pass
# self.offloaded_net.load_state_dict(copy.deepcopy(new_params), strict=True)
else:
self.logger.info("Updating parameters")
self.net.load_state_dict(copy.deepcopy(new_params), strict=True)
self.logger.info("Updating parameters")
self.net.load_state_dict(copy.deepcopy(new_params), strict=True)
del new_params
gc.collect()

def message(self, other_node: str, method: Union[Callable, str], *args, **kwargs) -> torch.Future: # pylint: disable=no-member
"""
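
Note: on the node side, the model is now moved to its device as soon as it is constructed or replaced (net.to(self.device)), and update_nn_parameters loses its dead is_offloaded_model branch: it deep-copies incoming weights before load_state_dict, then deletes the payload and collects. A compact sketch of that update path (toy model, assumed shapes):

import copy
import gc
import torch

def update_nn_parameters(net: torch.nn.Module, new_params: dict) -> None:
    # Deep-copy so the network owns its tensors instead of aliasing the RPC
    # payload, then drop the payload and collect, mirroring the diff.
    net.load_state_dict(copy.deepcopy(new_params), strict=True)
    del new_params
    gc.collect()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torch.nn.Linear(4, 2).to(device)  # .to(device) right after construction
update_nn_parameters(net, {k: v.cpu() for k, v in net.state_dict().items()})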
