diff --git a/fltk/client.py b/fltk/client.py index bde8cb45..f5ae6c2f 100644 --- a/fltk/client.py +++ b/fltk/client.py @@ -194,12 +194,15 @@ def train(self, epoch): for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0): # TODO: Implement swap based on received attack. inputs, labels = inputs.to(self.device), labels.to(self.device) + # TODO: call Pill attack with input & output # zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize + outputs = self.net(inputs) + loss = self.loss_function(outputs, labels) loss.backward() self.optimizer.step() @@ -234,6 +237,7 @@ def test(self): outputs = self.net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) + # TODO: Log the information regarding the poisoned accuracy correct += (predicted == labels).sum().item() targets_.extend(labels.cpu().view_as(predicted).numpy())