Commit
Update Multi-Krum
JMGaljaard committed Jun 11, 2021
1 parent 14a27c2 commit 4a54061
Showing 3 changed files with 62 additions and 88 deletions.
2 changes: 1 addition & 1 deletion fltk/federator.py
@@ -236,7 +236,7 @@ def remote_run_epoch(self, epochs, ratio=None, store_grad=False):
                                          self.epoch_counter)

            client_weights.append(weights)
-        updated_model = self.antidote.process_gradients(client_weights)
+        updated_model = self.antidote.process_gradients(client_weights, epoch=self.epoch_counter)
        self.test_data.net.load_state_dict(updated_model)
        # test global model
        logging.info("Testing on global test set")
89 changes: 54 additions & 35 deletions fltk/strategy/antidote.py
@@ -2,6 +2,7 @@
import numpy as np
import torch

+from fltk.client import Client
from fltk.nets.util.utils import flatten_params
from fltk.strategy.util.antidote import calc_krum_scores
from fltk.util.base_config import BareConfig
@@ -17,7 +18,7 @@ def __init__(self):
        pass

    @abstractmethod
-    def process_gradients(self, gradients):
+    def process_gradients(self, gradients, **kwargs):
        pass

class DummyAntidote(Antidote):
@@ -26,7 +27,7 @@ def __init__(self, cfg: BareConfig):
        Antidote.__init__(self)
        pass

-    def process_gradients(self, gradients):
+    def process_gradients(self, gradients, **kwargs):
        return average_nn_parameters(gradients)

class MultiKrumAntidote(Antidote):
@@ -36,78 +37,96 @@ def __init__(self, cfg: BareConfig, **kwargs):
        self.f = cfg.get_antidote_f_value()
        self.k = cfg.get_antidote_k_value()

-    def process_gradients(self, gradients):
+    def process_gradients(self, gradients, **kwargs):
        """
        Function which returns the average of the k gradients with the lowest Krum scores.
        """
-        krum_scores = calc_krum_scores(gradients)
+        krum_scores = calc_krum_scores(gradients, self.f)

        # Now take the k closest entries
        sorted_indices = np.argsort(krum_scores)[:self.k]
        top_gradients = [gradients[top_k_index] for top_k_index in sorted_indices]
        return average_nn_parameters(top_gradients)
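For reference, calc_krum_scores is imported from fltk.strategy.util.antidote. A minimal sketch of the standard (Multi-)Krum score it is expected to compute, assuming one flattened update per client (the repo's exact signature may differ):

import numpy as np

def krum_scores_sketch(flat_grads, f):
    # flat_grads: (n, d) array with one flattened client update per row.
    n = len(flat_grads)
    # Pairwise squared euclidean distances between all client updates.
    dists = np.linalg.norm(flat_grads[:, None, :] - flat_grads[None, :, :], axis=-1) ** 2
    scores = np.empty(n)
    for i in range(n):
        # Sum the distances to the n - f - 2 closest other updates:
        # a small score means the update sits among the majority.
        closest = np.sort(np.delete(dists[i], i))[: n - f - 2]
        scores[i] = closest.sum()
    return scores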

class ClusterAntidote(Antidote):

+    @staticmethod
+    def ema(s_t_prev, value, t, rho, bias_correction=True):
+        """Exponentially weighted moving average with optional (Adam-style) bias correction."""
+        s_t = rho * s_t_prev + (1 - rho) * value
+        if bias_correction:
+            return s_t / (1.0 - rho ** (t + 1))
+        return s_t

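A quick worked example of the bias correction (values arbitrary). Note the caller should carry the uncorrected average between rounds; the corrected value is only a read-out, not part of the recurrence:

s = 0.0                                      # zero-initialised running average
for t, value in enumerate([1.0, 1.0, 0.0]):
    s = 0.5 * s + (1 - 0.5) * value          # raw EMA recurrence with rho = 0.5
    s_hat = s / (1.0 - 0.5 ** (t + 1))       # what ema(..., bias_correction=True) returns
# At t = 0: s = 0.5 but s_hat = 1.0; the correction removes the cold-start bias.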
    def __init__(self, cfg: BareConfig, **kwargs):
        Antidote.__init__(self)
        self.f = cfg.get_antidote_f_value()
        self.k = cfg.get_antidote_k_value()
        self.past_gradients = np.array([])

+        # TODO: Do not hardcode this for cifar10
+        self.class_targeted = np.zeros((10, cfg.epochs))
+        # Rho for this round poisoned
+        self.rho_1 = 0.5
+        # Rho for this class poisoned
+        self.rho_2 = 0.75
+        self.max_epoch = 130
+        self.num_classes = 10

-    def process_gradients(self, gradients):
+    def process_gradients(self, gradients, **kwargs):
        """
        Epoch-gated defense: collect the clients' final-layer gradients each
        round and, once enough rounds are available, filter out suspected
        poisoned updates before averaging.
        """
-        krum_scores = calc_krum_scores(gradients)
-        most_likely_good = np.argmax(krum_scores)
-        # Note gradients is a list of ordered dicts (pytorch state dicts)
-        new_connected_grads = [next(reversed(gradient.values())).numpy() for gradient in gradients]
-        self.past_gradients = np.stack([self.past_gradients] + new_connected_grads)
-
-        # TODO: Decide when to allow for performing the analysis.
-        # TODO: Decide how many runs you want to collect.
-
-        # TODO: Decide on how to get the number of classes
-        classes_ = 10
-        for cls in range(classes_):
+        epoch_indx = kwargs['epoch']
+        # During the first 10 epochs we effectively do nothing
+        if epoch_indx > 10:
+            # Note: gradients is a list of ordered dicts (pytorch state dicts);
+            # keep only the final-layer gradients, one row per output class.
+            new_connected_grads = [next(reversed(gradient.values())).numpy() for gradient in gradients]
+            stacked = np.vstack(new_connected_grads)
+            self.past_gradients = stacked if self.past_gradients.size == 0 \
+                else np.vstack([self.past_gradients, stacked])
+            # Once enough data has been collected, start filtering
+            if epoch_indx > 20:
+                trusty_indices = self.target_malicious(gradients, epoch_indx)
+                return average_nn_parameters([gradients[indx] for indx in trusty_indices])
+        return average_nn_parameters(gradients)

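A hypothetical driver loop to make the gating concrete; gather_client_updates is a placeholder rather than a repo function, and the thresholds mirror the constants above:

antidote = ClusterAntidote(cfg)
for epoch in range(1, cfg.epochs + 1):
    client_weights = gather_client_updates()   # placeholder for the federator's collection step
    # Epochs 1..10: plain averaging; 11..20: averaging while the defense
    # collects final-layer gradients; from epoch 21 on: filtered averaging.
    new_state = antidote.process_gradients(client_weights, epoch=epoch)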
+    def target_malicious(self, gradients, epoch_indx):
+        truthy_gradient = np.zeros((self.num_classes, len(gradients)), dtype=bool)
+        for cls in range(self.num_classes):
            # Slice to get only the rows corresponding to this class's output node.
-            sub_sample = self.past_gradients[cls::classes_]
+            sub_sample = self.past_gradients[cls::self.num_classes]
            clf = KMeans(2)
            scaler = StandardScaler()
            fitter = PCA(n_components=2)
-            scaled_param_diff = scaler.fit_transform(self.past_gradients)
+            scaled_param_diff = scaler.fit_transform(sub_sample)
            dim_reduced_gradients = fitter.fit_transform(scaled_param_diff)
            # fit_predict yields a 0/1 cluster label per row (fit_transform
            # would return distances to the cluster centres instead).
            classified = clf.fit_predict(dim_reduced_gradients)

-            # Get the label assigned to the 'krum' vector.
-            estimated_cluster = classified[-(len(gradients) - most_likely_good)]
-
-            this_epoch = classified[-classes_:]
-            # TODO: decide how to
-            sum(this_epoch)
-
-            if flagged_updates:
-
-                # TODO: Do clustering
-
-            # T
+            # If the split is roughly 50/50 then the class is unlikely to be
+            # poisoned; otherwise it likely is.
+            cluster_split = np.average(classified)
+            if 0.4 < cluster_split < 0.6:
+                # Roughly 50/50 divided, so we assume valid updates.
+                # As such, we don't need to perform KRUM: the distribution over
+                # the two clusters is arbitrary, so we cannot distill much
+                # information from the assignment to either cluster.
+                truthy_gradient[cls] = True
+            else:
+                krum_scores = calc_krum_scores(gradients, self.f)
+                # Lowest Krum score = most likely an honest update.
+                most_likely_good = np.argmin(krum_scores)
+                # Get the 0/1 label assigned to that 'krum' vector.
+                estimated_cluster = classified[-(len(gradients) - most_likely_good)]
+                # Boolean row marking the updates that share its cluster.
+                truthy_gradient[cls] = classified[-len(gradients):] == estimated_cluster
+        # Only keep the gradients we suspect are unaffected: a column-wise AND,
+        # so a client is selected only if it looks clean for every class.
+        return np.argwhere(truthy_gradient.all(axis=0)).ravel()
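The final column-wise reduction is easiest to see on a toy case, say three classes and four clients:

import numpy as np

truthy = np.array([[True, True, False, True],
                   [True, True, False, True],
                   [True, False, False, True]])
np.argwhere(truthy.all(axis=0)).ravel()   # array([0, 3]): only clients 0 and 3 look clean for every class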


def create_antidote(cfg: BareConfig, **kwargs) -> Antidote:
    assert cfg is not None
    if cfg.antidote is None:
        return DummyAntidote(cfg)
-    antidote_mapper = {'dummy': DummyAntidote, 'multikrum': MultiKrumAntidote}
+    medicine_cabinet = {'dummy': DummyAntidote, 'multikrum': MultiKrumAntidote, 'cluster': ClusterAntidote}

-    antidote_class = antidote_mapper.get(cfg.get_antidote_type(), None)
+    antidote_class = medicine_cabinet.get(cfg.get_antidote_type(), None)

    if antidote_class is not None:
        antidote = antidote_class(cfg=cfg, **kwargs)
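End to end, a federator obtains and applies the defense roughly as follows; client_weights and epoch_counter stand in for the federator's own state, and the config is assumed to name the 'cluster' antidote type:

cfg = BareConfig()                      # assumed to be loaded with antidote type 'cluster'
antidote = create_antidote(cfg)
updated_model = antidote.process_gradients(client_weights, epoch=epoch_counter)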
59 changes: 7 additions & 52 deletions notebooks/gradient-PCA.ipynb
@@ -17,6 +17,7 @@
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"from fltk.client import Client\n",
+"from fltk.nets import Cifar10CNN\n",
"from fltk.util.base_config import BareConfig"
]
},
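With this import the notebook can instantiate the packaged model instead of redefining it inline (the duplicated class definition is dropped further down):

net = Cifar10CNN()   # same architecture the federated runs use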
@@ -185,58 +186,11 @@
},
"outputs": [],
"source": [
-"sharex, shareyimport\n",
-"torch\n",
-"import torch.nn as nn\n",
-"import torch.nn.functional as F\n",
-"\n",
-"\n",
-"class Cifar10CNN(nn.Module):\n",
-"\n",
-"    def __init__(self):\n",
-"        super(Cifar10CNN, self).__init__()\n",
-"\n",
-"        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)\n",
-"        self.bn1 = nn.BatchNorm2d(32)\n",
-"        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)\n",
-"        self.bn2 = nn.BatchNorm2d(32)\n",
-"        self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
-"\n",
-"        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)\n",
-"        self.bn3 = nn.BatchNorm2d(64)\n",
-"        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)\n",
-"        self.bn4 = nn.BatchNorm2d(64)\n",
-"        self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
-"\n",
-"        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)\n",
-"        self.bn5 = nn.BatchNorm2d(128)\n",
-"        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)\n",
-"        self.bn6 = nn.BatchNorm2d(128)\n",
-"        self.pool3 = nn.MaxPool2d(kernel_size=2)\n",
-"\n",
-"        self.fc1 = nn.Linear(128 * 4 * 4, 128)\n",
-"        self.fc2 = nn.Linear(128, 10)\n",
-"\n",
-"    def forward(self, x):\n",
-"        x = self.bn1(F.relu(self.conv1(x)))\n",
-"        x = self.bn2(F.relu(self.conv2(x)))\n",
-"        x = self.pool1(x)\n",
-"\n",
-"        x = self.bn3(F.relu(self.conv3(x)))\n",
-"        x = self.bn4(F.relu(self.conv4(x)))\n",
-"        x = self.pool2(x)\n",
-"\n",
-"        x = self.bn5(F.relu(self.conv5(x)))\n",
-"        x = self.bn6(F.relu(self.conv6(x)))\n",
-"        x = self.pool3(x)\n",
-"\n",
-"        x = x.view(-1, 128 * 4 * 4)\n",
-"\n",
-"        x = self.fc1(x)\n",
-"        x = F.softmax(self.fc2(x))\n",
-"\n",
-"        return x\n",
-"\n",
-"\n",
"def flatten_params(parameters):\n",
" \"\"\"\n",
@@ -643,11 +597,7 @@
"metadata": {},
"outputs": [],
"source": [
-"# Models the empty language\n",
-"regex_0 = re.compile(\"^$\")\n",
-"regex_1 = re.compile(\"client[9]\")\n",
-"regex_2 = re.compile(\"client[92]\")\n",
-"regex_3 = re.compile(\"client[926]\")\n",
"\n",
"\n",
"\n",
"def plot_dataset(directories, poisoned, ratio):\n",
@@ -679,6 +629,11 @@
"    ax[2].title.set_text('class 6')\n",
"    plt.savefig(f'{ratio}.pdf')\n",
"\n",
+"# Models the empty language\n",
+"regex_0 = re.compile(\"^$\")\n",
+"regex_1 = re.compile(\"client[9]\")\n",
+"regex_2 = re.compile(\"client[92]\")\n",
+"regex_3 = re.compile(\"client[926]\")\n",
"\n",
"first_ten_rounds = re.compile(\"\\/([0-9]|10)\\/\")\n",
"for regex, name in [(regex_1, \"0.1\")]:\n",
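Note that the bracketed patterns are character classes: regex_2 matches 'client9' or 'client2' anywhere in a path, not the literal string 'client92'. For instance:

import re

regex_2 = re.compile("client[92]")
[bool(regex_2.search(s)) for s in ("client9", "client2", "client3")]   # [True, True, False]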
