Skip to content

Commit

Permalink
Make log level configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
bacox committed Feb 21, 2022
1 parent 8ec6104 commit d434239
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 22 deletions.
16 changes: 15 additions & 1 deletion fltk/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def train(self, num_epochs: int):

return final_running_loss, self.get_nn_parameters(),

def set_tau_eff(self, total):
    """Compute this client's effective tau and hand it to the optimizer.

    The client's share of the data (its own data size divided by ``total``)
    scales the truncated step count ``int(epochs * data_size / 16)``. The
    value is forwarded only when the optimizer exposes a ``set_tau_eff``
    method; otherwise nothing is done with it.

    :param total: data size used as the denominator of this client's share.
    """
    batch_size = 16  # hard-coded here rather than read from the config
    share = self.get_client_datasize() / total
    local_size = self.get_client_datasize()
    epochs = self.config.epochs
    local_steps = int(epochs * local_size / batch_size)
    effective_tau = local_steps * share
    if not hasattr(self.optimizer, 'set_tau_eff'):
        return
    self.optimizer.set_tau_eff(effective_tau)

def test(self):
start_time = time.time()
correct = 0
Expand All @@ -80,7 +89,7 @@ def test(self):

loss += self.loss_function(outputs, labels).item()
loss /= len(self.dataset.get_test_loader().dataset)
accuracy = 100 * correct / total
accuracy = 100.0 * correct / total
# confusion_mat = confusion_matrix(targets_, pred_)
# accuracy_per_class = confusion_mat.diagonal() / confusion_mat.sum(1)
#
Expand All @@ -105,6 +114,11 @@ def exec_round(self, num_epochs: int) -> Tuple[float, Any, float, float]:
end = time.time()
duration = end - start
# self.logger.info(f'Round duration is {duration} seconds')

if hasattr(self.optimizer, 'pre_communicate'): # aka fednova or fedprox
self.optimizer.pre_communicate()
for k, v in weights.items():
weights[k] = v.cpu()
return loss, weights, accuracy, test_loss

def __del__(self):
Expand Down
51 changes: 38 additions & 13 deletions fltk/core/federator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,22 @@
from fltk.datasets.loader_util import get_dataset
from fltk.strategy import FedAvg, random_selection, average_nn_parameters, average_nn_parameters_simple
from fltk.util.config import Config

from dataclasses import dataclass

NodeReference = Union[Node, str]
@dataclass
class LocalClient:
    """Federator-side record describing one client."""
    # Name under which the client was created or registered.
    name: str
    # Handle passed as the target of ``self.message``: a Node object or a
    # node name, per the NodeReference alias.
    ref: NodeReference
    # Initialised to 0 when the record is created; filled in later from the
    # client's own ``get_client_datasize`` answer.
    data_size: int


class Federator(Node):
clients: List[NodeReference] = []
clients: List[LocalClient] = []
# clients: List[NodeReference] = []
num_rounds: int


def __init__(self, id: int, rank: int, world_size: int, config: Config):
super().__init__(id, rank, world_size, config)
self.loss_function = self.config.get_loss_function()()
Expand All @@ -32,12 +39,13 @@ def create_clients(self):
world_size = self.config.num_clients + 1
for client_id in range(1, self.config.num_clients+ 1):
client_name = f'client{client_id}'
self.clients.append(Client(client_name, client_id, world_size, copy.deepcopy(self.config)))
client = Client(client_name, client_id, world_size, copy.deepcopy(self.config))
self.clients.append(LocalClient(client_name, client, 0))

def register_client(self, client_name, rank):
if self.config.single_machine:
self.logger.warning('This function should not be called when in single machine mode!')
self.clients.append(client_name)
self.clients.append(LocalClient(client_name, client_name, 0))

def _num_clients_online(self) -> int:
return len(self.clients)
Expand All @@ -55,21 +63,29 @@ def clients_ready(self):
responses = []
all_ready = True
for client in self.clients:
resp = self.message(client, Client.is_ready)
resp = self.message(client.ref, Client.is_ready)
if resp:
self.logger.info(f'Client {client} is ready')
else:
self.logger.info(f'Waiting for client {client}')
all_ready = False
time.sleep(2)

def get_client_data_sizes(self):
    """Ask every known client for its data size and store it on its record."""
    for record in self.clients:
        reported_size = self.message(record.ref, Client.get_client_datasize)
        record.data_size = reported_size

def run(self):
self.init_dataloader()
# Load dataset with world size 2 to load the whole dataset.
# Caused by the fact that the dataloader subtracts 1 from the world size to exclude the federator by default.
self.init_dataloader(world_size=2)

self.create_clients()
while not self._all_clients_online():
self.logger.info(f'Waiting for all clients to come online. Waiting for {self.world_size - 1 -self._num_clients_online()} clients')
time.sleep(2)
self.client_load_data()
self.get_client_data_sizes()
self.clients_ready()

for communications_round in range(self.config.rounds):
Expand All @@ -79,7 +95,15 @@ def run(self):

def client_load_data(self):
for client in self.clients:
self.message(client, Client.init_dataloader)
self.message(client.ref, Client.init_dataloader)

def set_tau_eff(self):
    """Tell every client the combined data size so it can set its tau_eff.

    The total is the sum of the ``data_size`` stored on each client record,
    so those sizes must have been collected before this is called. Clients
    are messaged one after another.
    """
    total = sum(client.data_size for client in self.clients)
    for client in self.clients:
        # The target reference is the first argument of ``message``; after
        # the method only its own arguments follow. ``Client.set_tau_eff``
        # takes just ``total``, so the reference must not be repeated here.
        self.message(client.ref, Client.set_tau_eff, total)

def test(self, net):
start_time = time.time()
Expand All @@ -103,7 +127,7 @@ def test(self, net):

loss += self.loss_function(outputs, labels).item()
loss /= len(self.dataset.get_test_loader().dataset)
accuracy = 100 * correct / total
accuracy = 100.0 * correct / total
# confusion_mat = confusion_matrix(targets_, pred_)
# accuracy_per_class = confusion_mat.diagonal() / confusion_mat.sum(1)
#
Expand All @@ -119,20 +143,21 @@ def exec_round(self):
num_epochs = self.config.epochs

# Client selection
selected_clients: List[LocalClient]
selected_clients = random_selection(self.clients, self.config.clients_per_round)

last_model = self.get_nn_parameters()
for client in selected_clients:
self.message(client, Client.update_nn_parameters, last_model)
self.message(client.ref, Client.update_nn_parameters, last_model)

# Actual training calls
client_weights = {}
client_sizes = {}
for client in selected_clients:
train_loss, weights, accuracy, test_loss = self.message(client, Client.exec_round, num_epochs)
client_weights[client] = weights
client_data_size = self.message(client, Client.get_client_datasize)
client_sizes[client] = client_data_size
train_loss, weights, accuracy, test_loss = self.message(client.ref, Client.exec_round, num_epochs)
client_weights[client.name] = weights
client_data_size = self.message(client.ref, Client.get_client_datasize)
client_sizes[client.name] = client_data_size
self.logger.info(f'Client {client} has a accuracy of {accuracy}, train loss={train_loss}, test loss={test_loss},datasize={client_data_size}')

# updated_model = FedAvg(client_weights, client_sizes)
Expand Down
10 changes: 7 additions & 3 deletions fltk/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,20 @@ def __init__(self, id: int, rank: int, world_size: int, config: Config):
self._config(config)

def _config(self, config: Config):
self.logger.setLevel(config.log_level.value)
self.config.rank = self.rank
self.config.world_size = self.world_size
self.cuda = config.cuda
self.init_device()
self.distributed = config.distributed
self.set_net(self.load_default_model())

def init_dataloader(self):
self.logger.info(f'world size = {self.config.world_size} with rank={self.config.rank}')
self.dataset = get_dataset(self.config.dataset_name)(self.config)
def init_dataloader(self, world_size: int = None):
config = copy.deepcopy(self.config)
if world_size:
config.world_size = world_size
self.logger.info(f'world size = {config.world_size} with rank={config.rank}')
self.dataset = get_dataset(config.dataset_name)(config)
self.finished_init = True
self.logger.info('Done with init')

Expand Down
2 changes: 1 addition & 1 deletion fltk/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def get_sampler(dataset, args):
sampler = None
if args.get_distributed():
method = args.get_sampler()
logger.info(
logger.debug(
"Using {} sampler method, with args: {}".format(method, args.get_sampler_args()))

if method == DataSampler.uniform:
Expand Down
2 changes: 1 addition & 1 deletion fltk/samplers/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ class UniformSampler(DistributedSamplerWrapper):
def __init__(self, dataset, num_replicas=None, rank=None, seed=0):
super().__init__(dataset, num_replicas=num_replicas, rank=rank, seed=seed)
indices = list(range(len(self.dataset)))
self.indices = indices[self.rank:self.total_size:self.num_replicas]
self.indices = indices[self.rank:self.total_size:self.n_clients]
4 changes: 3 additions & 1 deletion fltk/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from fltk.util.definitions import Dataset, Nets, DataSampler, Optimizations
from fltk.util.definitions import Dataset, Nets, DataSampler, Optimizations, LogLevel


@dataclass
Expand All @@ -26,6 +26,8 @@ class Config:
}
loss_function = torch.nn.CrossEntropyLoss

log_level: LogLevel = LogLevel.DEBUG

num_clients: int = 10
clients_per_round: int = 2
distributed: bool = True
Expand Down
11 changes: 11 additions & 0 deletions fltk/util/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ class Dataset(Enum):
mnist = 'mnist'


class LogLevel(Enum):
    """Log severities carrying the numeric values of the ``logging`` levels.

    ``FATAL`` and ``WARN`` share the values of ``CRITICAL`` and ``WARNING``,
    which makes them enum aliases rather than separate members.
    """
    CRITICAL = 50
    FATAL = 50  # alias of CRITICAL
    ERROR = 40
    WARNING = 30
    WARN = 30  # alias of WARNING
    INFO = 20
    DEBUG = 10
    NOTSET = 0


class Aggregations(Enum):
avg = 'Avg'
fed_avg = 'FedAvg'
Expand Down
7 changes: 5 additions & 2 deletions fltk/util/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

from torch.distributed import rpc

from fltk.util.definitions import LogLevel


class FLLogger:
    """Holder for the ``log`` callback that writes a node's log line locally."""
    @staticmethod
    @rpc.functions.async_execution
    def log(arg1, node_id, log_line, report_time):
        # Emits the line at INFO level on the root logger, prefixed with the
        # reporting node's id and the reported time. ``arg1`` is not used.
        # NOTE(review): functions wrapped in rpc.functions.async_execution are
        # documented to return a torch Future, while this one returns None;
        # confirm this is intended.
        logging.info(f'[{node_id}: {report_time}]: {log_line}')


def getLogger(module_name):
def getLogger(module_name, level: LogLevel = LogLevel.INFO):
logging.basicConfig(
level=logging.DEBUG,
level=level.value,
format='%(asctime)s %(levelname)s %(module)s - %(funcName)s: %(message)s',
)
return logging.getLogger(module_name)

0 comments on commit d434239

Please sign in to comment.