Resolve gradient update later epoch issues
JMGaljaard committed Jun 8, 2021
1 parent f26a8fc commit f486c26
Showing 5 changed files with 77 additions and 44 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
@@ -29,6 +29,8 @@ RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -r require
ADD configs configs

ADD fltk fltk

# Update relevant runtime configuration for experiment
COPY cloud_configs/cloud_experiment.yaml configs/cloud_config.yaml
# Install newest version of library
RUN python3 -m setup install
31 changes: 31 additions & 0 deletions cloud_configs/cloud_experiment.yaml
@@ -0,0 +1,31 @@
# Experiment configuration
total_epochs: 130
epochs_per_cycle: 1
wait_for_clients: true
net: Cifar10CNN
dataset: cifar10
sampler: "uniform"
sampler_args:
- 0.5 # p degree
- 42 # random seed
# Use cuda if available; setting to false will force CPU
cuda: false
experiment_prefix: 'experiment_single_machine'
output_location: 'output'
tensor_board_active: true
clients_per_round: 50
system:
federator:
# Use the SERVICE provided by the fl-server to connect
hostname: 'fl-server.test.svc.cluster.local'
# Default NIC is eth0
nic: 'eth0'
clients:
amount: 50
poison:
seed: 420
ratio: 0.2
attack:
type: "flip"
config:
- 5: 3
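
The file above is plain YAML, so it can be inspected outside the framework. Below is a minimal sketch of reading it with PyYAML; load_experiment_config is a hypothetical helper shown for illustration, not the project's BareConfig loader, and only top-level fields visible above are accessed. Note that the attack entry '- 5: 3' parses as a one-element list containing the mapping {5: 3}, i.e. flip label 5 to label 3.

import yaml  # assumes PyYAML is installed

def load_experiment_config(path: str) -> dict:
    # Hypothetical helper, for illustration only.
    with open(path) as f:
        return yaml.safe_load(f)

cfg = load_experiment_config('configs/cloud_config.yaml')
print(cfg['total_epochs'])          # 130
print(cfg['net'], cfg['dataset'])   # Cifar10CNN cifar10
print(cfg['sampler_args'])          # [0.5, 42]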
6 changes: 3 additions & 3 deletions configs/local_experiment.yaml
@@ -1,5 +1,5 @@
# Experiment configuration
total_epochs: 1
total_epochs: 3
epochs_per_cycle: 1
wait_for_clients: true
net: Cifar10CNN
@@ -13,15 +13,15 @@ cuda: false
experiment_prefix: 'experiment_single_machine'
output_location: 'output'
tensor_board_active: true
clients_per_round: 1
clients_per_round: 2
system:
federator:
# Use the SERVICE provided by the fl-server to connect
hostname: '172.18.0.2'
# Default NIC is eth0
nic: 'eth0'
clients:
amount: 1
amount: 2
# A simple example config is provided in configs/poison.example.yaml
poison:
seed: 420
53 changes: 28 additions & 25 deletions fltk/client.py
@@ -56,7 +56,7 @@ class Client:

def __init__(self, id, log_rref, rank, world_size, config: BareConfig = None):
logging.info(f'Welcome to client {id}')
self.net = None
self.net: torch.nn.Module = None
self.id = id
self.log_rref = log_rref
self.rank = rank
@@ -89,26 +89,28 @@ def reset_model(self):
@return: None
@rtype: None
"""
# Set loss function for gradient calculation
self.loss_function = self.args.get_loss_function()()
# Create optimizer (default is SGD): TODO: Move to AdamW?
self.optimizer = torch.optim.SGD(self.net.parameters(),
lr=self.args.get_learning_rate(),
momentum=self.args.get_momentum())
self.scheduler = MinCapableStepLR(self.args.get_logger(), self.optimizer,
self.args.get_scheduler_step_size(),
self.args.get_scheduler_gamma(),
self.args.get_min_lr())
# Reset logger
self.args.init_logger(logging)
# Reset the epoch counter
self.epoch_counter = 0
self.finished_init = False
# Dataset will be re-initialized
# Dataset will be re-initialized to save memory
del self.dataset
# This will be set afterwards, but we delete possible gradient information.
del self.net
self.set_net(self.load_default_model())
self.net.requires_grad_(True)

# Set loss function for gradient calculation
self.loss_function = self.args.get_loss_function()()

self.optimizer = torch.optim.SGD(self.net.parameters(),
lr=self.args.get_learning_rate(),
momentum=self.args.get_momentum())
self.scheduler = MinCapableStepLR(self.args.get_logger(), self.optimizer,
self.args.get_scheduler_step_size(),
self.args.get_scheduler_gamma(),
self.args.get_min_lr())
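
A plausible reading of the reordering above: reset_model now deletes the old network, loads a fresh default model, and only then constructs the loss function, optimizer, and scheduler. An SGD optimizer created before the reload would keep references to the discarded parameter tensors (and their momentum buffers), so later optimizer.step() calls would no longer update the live model, which matches the gradient-update symptoms this commit targets. A standalone sketch of the rebinding pattern, not fltk code:

import torch  # standalone sketch, assumes only PyTorch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Replacing the model invalidates the optimizer's parameter references,
# so the optimizer (and any scheduler wrapping it) must be rebuilt.
model = torch.nn.Linear(4, 2)  # e.g. a freshly loaded default model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

x, y = torch.randn(8, 4), torch.randn(8, 2)
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()  # updates the parameters of the current, live model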

def ping(self):
"""
@@ -243,7 +245,7 @@ def train(self, epoch, pill: PoisonPill = None):
if self.args.distributed:
self.dataset.train_sampler.set_epoch(epoch)


self.net.train()
for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
inputs, labels = inputs.to(self.device), labels.to(self.device)
# TODO: check if these parameters are correct, labels or outputs?
@@ -252,7 +254,7 @@
inputs, labels = pill.poison_output(inputs, labels)

# zero the parameter gradients
self.optimizer.zero_grad(set_to_none=True)
self.optimizer.zero_grad()

# forward + backward + optimize

@@ -285,20 +287,21 @@ def test(self):
targets_ = []
pred_ = []
loss = 0.0
with torch.no_grad():
for (images, labels) in self.dataset.get_test_loader():
images, labels = images.to(self.device), labels.to(self.device)
self.net.eval()

for (images, labels) in self.dataset.get_test_loader():
images, labels = images.to(self.device), labels.to(self.device)

outputs = self.net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
# TODO: Log the information regarding the poisoned accuracy
correct += (predicted == labels).sum().item()
outputs = self.net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
# TODO: Log the information regarding the poisoned accuracy
correct += (predicted == labels).sum().item()

targets_.extend(labels.cpu().view_as(predicted).numpy())
pred_.extend(predicted.cpu().numpy())
targets_.extend(labels.cpu().view_as(predicted).numpy())
pred_.extend(predicted.cpu().numpy())

loss += self.loss_function(outputs, labels).item()
loss += self.loss_function(outputs, labels).item()

accuracy = 100 * correct / total
confusion_mat = confusion_matrix(targets_, pred_)
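In the test() hunk above, self.net.eval() switches the network into evaluation mode (dropout disabled, batch-norm layers using running statistics), replacing the earlier torch.no_grad() context. The two are independent and commonly combined: eval() changes layer behaviour, while no_grad() stops autograd from recording the forward pass and saves memory. A generic evaluation-loop sketch under those assumptions, not the fltk implementation:

import torch  # generic sketch

def evaluate(net, loader, loss_function, device='cpu'):
    net.eval()  # evaluation-mode behaviour for dropout / batch norm
    correct, total, total_loss = 0, 0, 0.0
    with torch.no_grad():  # optional, but avoids building autograd graphs
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            total_loss += loss_function(outputs, labels).item()
    return 100.0 * correct / total, total_loss
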
29 changes: 13 additions & 16 deletions fltk/federator.py
@@ -1,4 +1,5 @@
import logging
import math
import pathlib
import time
from pathlib import Path
@@ -184,7 +185,7 @@ def clients_ready(self):
time.sleep(2)
logging.info('All clients are ready')

def remote_run_epoch(self, epochs, cur_model: torch.nn.Module, ratio = None, store_grad=False):
def remote_run_epoch(self, epochs, ratio = None, store_grad=False):
responses = []
client_weights = []
selected_clients = self.select_clients(self.config.clients_per_round)
@@ -200,17 +201,18 @@ def remote_run_epoch(self, epochs, cur_model: torch.nn.Module, ratio = None, sto
responses.append((client, _remote_method_async(Client.run_epochs, client.ref, num_epoch=epochs, pill=pill)))
self.epoch_counter += epochs

accuracy, loss, class_precision, class_recall = self.test_data.test()
# self.tb_writer.add_scalar('training loss', loss, self.epoch_counter * self.test_data.get_client_datasize()) # does not seem to work :( )
self.tb_writer.add_scalar('accuracy', accuracy, self.epoch_counter * self.test_data.get_client_datasize())
self.tb_writer.add_scalar('accuracy per epoch', accuracy, self.epoch_counter)
try:
# Test the model before waiting for the model.
self.test_model()
except Exception as e:
print(e)

flat_current = None

# Test the model before waiting for the model.
self.test_model()


if store_grad:
flat_current = flatten_params(cur_model.state_dict())
flat_current = flatten_params(self.test_data.net.state_dict())
for res in responses:
epoch_data, weights = res[1].wait()
if store_grad:
@@ -238,7 +240,7 @@ def remote_run_epoch(self, epochs, cur_model: torch.nn.Module, ratio = None, sto

client_weights.append(weights)
updated_model = average_nn_parameters(client_weights)

self.test_data.net.load_state_dict(updated_model)
# test global model
logging.info("Testing on global test set")
self.test_data.update_nn_parameters(updated_model)
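
The hunk above makes the federator keep the global model on self.test_data.net: the averaged client weights are loaded into it with load_state_dict before the global test run. A hypothetical sketch of what FedAvg-style averaging over client state_dicts can look like; the project's average_nn_parameters may differ (for example by weighting clients by data size):

import torch  # hypothetical sketch; not the fltk implementation

def average_state_dicts(client_states):
    # Unweighted mean over clients; assumes floating-point entries of equal shape.
    averaged = {}
    for key in client_states[0]:
        averaged[key] = torch.stack([state[key].float() for state in client_states]).mean(dim=0)
    return averaged

# usage (names are illustrative):
# global_net.load_state_dict(average_state_dicts(client_weights))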
@@ -289,11 +291,7 @@ def run(self, ratios = [0.0, 0.06, 0.12, 0.18]):
poison_pill = None
save_path = self.config
for rat in ratios:
# Get model to calculate gradient updates, default is shared between all.
model = initialize_default_model(self.config, self.config.get_net())
# Re-use the functionality to update
self.distribute_new_model(model.state_dict())

self.test_data.net = initialize_default_model(self.config, self.config.get_net())
# Update the clients to point to the newer version.
self.update_clients(rat)
if self.attack:
@@ -306,7 +304,6 @@
self.ping_all()
self.clients_ready()
self.update_client_data_sizes()

addition = 0
epoch_to_run = self.config.epochs
epoch_size = self.config.epochs_per_cycle
@@ -315,7 +312,7 @@
print(f'Running epoch {epoch}')
# Get new model during run, update iteratively. The model is needed to calculate the
# gradient by the federator.
model = self.remote_run_epoch(epoch_size, model, rat)
self.remote_run_epoch(epoch_size, rat)
addition += 1
logging.info('Printing client data')
print(self.client_data)
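With these changes the global model is owned by the federator (self.test_data.net) rather than threaded through remote_run_epoch as a cur_model argument, and the store_grad path flattens the current global state_dict before the client responses are applied. A hypothetical sketch of parameter flattening for such a pseudo-gradient comparison; the project's flatten_params may be implemented differently:

import torch  # hypothetical sketch

def flatten_state_dict(state_dict):
    # Concatenate all entries into one 1-D vector, in state_dict order.
    return torch.cat([tensor.detach().float().reshape(-1) for tensor in state_dict.values()])

# A client update's pseudo-gradient relative to the current global model (illustrative):
# pseudo_grad = flatten_state_dict(client_state) - flatten_state_dict(global_net.state_dict())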
