fix a bug in models/client.py #39

Open · wants to merge 5 commits into base: master
27 changes: 21 additions & 6 deletions models/client.py
@@ -1,10 +1,24 @@
 import random
 import warnings
+import importlib
+import copy
 
 
 class Client:
-    def __init__(self, client_id, group=None, train_data={'x' : [],'y' : []}, eval_data={'x' : [],'y' : []}, model=None):
-
+
+    def __init__(self, client_id, group=None, train_data={'x': [], 'y': []}, eval_data={'x': [], 'y': []},
+                 model_info=None):
+        model_path = model_info['model_path']
+        seed = model_info['seed']
+        model_params = model_info['model_params']
+        mod = importlib.import_module(model_path)
+        ClientModel = getattr(mod, 'ClientModel')
+        model = ClientModel(seed, *model_params)
+        # model_path = 'femnist.cnn'
+        # mod = importlib.import_module(model_path)
+        # ClientModel = getattr(mod, 'ClientModel')
+        # model = ClientModel(123, *(0.06, 62))
+
         self._model = model
         self.id = client_id
         self.group = group
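For context: the constructor now receives a `model_info` description instead of a prebuilt, shared model, and builds its own `ClientModel` by importing `model_path`. A minimal usage sketch of the new signature, with illustrative values taken from the commented-out lines above (`femnist.cnn`, seed 123, params `(0.06, 62)`):

```python
# Hypothetical caller; the dict keys match the PR, the values are illustrative.
model_info = {
    'model_path': 'femnist.cnn',   # module that defines a ClientModel class
    'seed': 123,                   # RNG seed passed to ClientModel
    'model_params': (0.06, 62),    # e.g. (learning rate, num_classes) for FEMNIST
}
client = Client('f_0001', group=None,
                train_data={'x': [], 'y': []},
                eval_data={'x': [], 'y': []},
                model_info=model_info)
```

One reviewer-level caveat: the mutable dict defaults in the signature are created once and shared by every call that omits those arguments, a classic Python pitfall; `None` defaults resolved inside the body would be safer.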
@@ -30,12 +44,13 @@ def train(self, num_epochs=1, batch_size=10, minibatch=None):
             comp, update = self.model.train(data, num_epochs, batch_size)
         else:
             frac = min(1.0, minibatch)
-            num_data = max(1, int(frac*len(self.train_data["x"])))
+            num_data = max(1, int(frac * len(self.train_data["x"])))
             xs, ys = zip(*random.sample(list(zip(self.train_data["x"], self.train_data["y"])), num_data))
-            data = {'x': xs, 'y': ys}
+            data = {'x': list(xs), 'y': list(ys)}
 
             # Minibatch trains for only 1 epoch - multiple local epochs don't make sense!
             num_epochs = 1
+            print(id(self.model))
             comp, update = self.model.train(data, num_epochs, num_data)
         num_train_samples = len(data['y'])
         return comp, num_train_samples, update
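The minibatch branch draws a random fraction of the client's data without replacement using the zip/sample/unzip idiom, and now materializes the tuples returned by `zip(*...)` as lists. A self-contained sketch of that idiom (toy data, fraction 0.5):

```python
import random

train_data = {'x': ['a', 'b', 'c', 'd'], 'y': [0, 1, 0, 1]}
frac = min(1.0, 0.5)                                 # clamp the requested fraction to 1.0
num_data = max(1, int(frac * len(train_data['x'])))  # always keep at least one sample
pairs = random.sample(list(zip(train_data['x'], train_data['y'])), num_data)
xs, ys = zip(*pairs)                                 # unzip back into parallel tuples
data = {'x': list(xs), 'y': list(ys)}                # convert tuples to lists, as the PR does
print(data)                                          # e.g. {'x': ['d', 'b'], 'y': [1, 1]}
```

The added `print(id(self.model))` looks like a leftover debug check that each client now owns a distinct model object (the apparent point of the `model_info` change); it probably should not survive review.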
@@ -88,8 +103,8 @@ def num_samples(self):
         train_size = 0
         if self.train_data is not None:
             train_size = len(self.train_data['y'])
 
-        test_size = 0 
-        if self.eval_data is not  None:
+        test_size = 0
+        if self.eval_data is not None:
             test_size = len(self.eval_data['y'])
         return train_size + test_size

100 changes: 83 additions & 17 deletions models/femnist/cnn.py
@@ -14,41 +14,107 @@ def __init__(self, seed, lr, num_classes):
 
     def create_model(self):
         """Model function for CNN."""
+        with tf.device('/gpu:0'):
+            features = tf.placeholder(
+                tf.float32, shape=[None, IMAGE_SIZE * IMAGE_SIZE], name='features')
+            labels = tf.placeholder(tf.int64, shape=[None], name='labels')
+            input_layer = tf.reshape(features, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
+            conv1 = tf.layers.conv2d(
+                inputs=input_layer,
+                filters=32,
+                kernel_size=[5, 5],
+                padding="same",
+                activation=tf.nn.relu)
+            pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
+            conv2 = tf.layers.conv2d(
+                inputs=pool1,
+                filters=64,
+                kernel_size=[5, 5],
+                padding="same",
+                activation=tf.nn.relu)
+            pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
+            pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
+            dense = tf.layers.dense(inputs=pool2_flat, units=2048, activation=tf.nn.relu)
+            logits = tf.layers.dense(inputs=dense, units=self.num_classes)
+            predictions = {
+                "classes": tf.argmax(input=logits, axis=1),
+                "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
+            }
+            loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
+            # TODO: Confirm that opt initialized once is ok?
+            train_op = self.optimizer.minimize(
+                loss=loss,
+                global_step=tf.train.get_global_step())
+            eval_metric_ops = tf.count_nonzero(tf.equal(labels, predictions["classes"]))
+            return features, labels, train_op, eval_metric_ops, loss
+
+    # todo fedsp
+    def create_fedsp_model(self):
+        """Model function for FedSP-CNN."""
         features = tf.placeholder(
             tf.float32, shape=[None, IMAGE_SIZE * IMAGE_SIZE], name='features')
         labels = tf.placeholder(tf.int64, shape=[None], name='labels')
         input_layer = tf.reshape(features, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
-        conv1 = tf.layers.conv2d(
-          inputs=input_layer,
-          filters=32,
-          kernel_size=[5, 5],
-          padding="same",
-          activation=tf.nn.relu)
-        pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
-        conv2 = tf.layers.conv2d(
-            inputs=pool1,
+        # global encoder
+        global_conv1 = tf.layers.conv2d(
+            inputs=input_layer,
+            filters=32,
+            kernel_size=[5, 5],
+            padding="same",
+            activation=tf.nn.relu, name='global_conv1')
+        global_pool1 = tf.layers.max_pooling2d(inputs=global_conv1, pool_size=[2, 2], strides=2)
+        global_conv2 = tf.layers.conv2d(
+            inputs=global_pool1,
+            filters=64,
+            kernel_size=[5, 5],
+            padding="same",
+            activation=tf.nn.relu, name='global_conv2')
+        global_pool2 = tf.layers.max_pooling2d(inputs=global_conv2, pool_size=[2, 2], strides=2)
+        global_pool2_flat = tf.reshape(global_pool2, [-1, 7 * 7 * 64])
+
+        # local encoder
+        local_conv1 = tf.layers.conv2d(
+            inputs=input_layer,
+            filters=32,
+            kernel_size=[5, 5],
+            padding="same",
+            activation=tf.nn.relu, name='local_conv1')
+        local_pool1 = tf.layers.max_pooling2d(inputs=local_conv1, pool_size=[2, 2], strides=2)
+        local_conv2 = tf.layers.conv2d(
+            inputs=local_pool1,
             filters=64,
             kernel_size=[5, 5],
             padding="same",
-            activation=tf.nn.relu)
-        pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
-        pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
-        dense = tf.layers.dense(inputs=pool2_flat, units=2048, activation=tf.nn.relu)
+            activation=tf.nn.relu, name='local_conv2')
+        local_pool2 = tf.layers.max_pooling2d(inputs=local_conv2, pool_size=[2, 2], strides=2)
+        local_pool2_flat = tf.reshape(local_pool2, [-1, 7 * 7 * 64])
+
+        concat_res = tf.concat([global_pool2_flat, local_pool2_flat], 1)
+
+        dense = tf.layers.dense(inputs=concat_res, units=2048, activation=tf.nn.relu)
+
         logits = tf.layers.dense(inputs=dense, units=self.num_classes)
+
         predictions = {
-          "classes": tf.argmax(input=logits, axis=1),
-          "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
+            "classes": tf.argmax(input=logits, axis=1),
+            "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
         }
+
         loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
+
         # TODO: Confirm that opt initialized once is ok?
         train_op = self.optimizer.minimize(
             loss=loss,
             global_step=tf.train.get_global_step())
+
         eval_metric_ops = tf.count_nonzero(tf.equal(labels, predictions["classes"]))
+
         return features, labels, train_op, eval_metric_ops, loss
 
-    def process_x(self, raw_x_batch):
+    @staticmethod
+    def process_x(raw_x_batch):
         return np.array(raw_x_batch)
 
-    def process_y(self, raw_y_batch):
+    @staticmethod
+    def process_y(raw_y_batch):
         return np.array(raw_y_batch)
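`create_fedsp_model` builds two parallel encoders over the same input: a `global_*` stack, presumably meant to be aggregated across clients, and a `local_*` stack, presumably kept on-device, and concatenates their flattened outputs before the shared dense head. With `IMAGE_SIZE = 28`, two 2x2 poolings reduce each encoder to 7 * 7 * 64 = 3136 features, so the concatenation feeds 6272 features into the 2048-unit dense layer. The diff does not show how the server separates the two parameter sets; a minimal sketch of how the layer name prefixes could support that, assuming the TF1-style API used above:

```python
import tensorflow as tf  # TF1-style API, matching the tf.layers code in this diff

def split_fedsp_variables():
    """Hypothetical helper: partition trainable variables by layer name prefix."""
    all_vars = tf.trainable_variables()
    global_vars = [v for v in all_vars if v.name.startswith('global_')]
    local_vars = [v for v in all_vars if v.name.startswith('local_')]
    # The dense head and logits layers are unnamed in the diff, so they fall
    # through to here; the PR would need an explicit policy for them.
    rest = [v for v in all_vars
            if not v.name.startswith(('global_', 'local_'))]
    return global_vars, local_vars, rest
```

This relies on `tf.layers.conv2d(..., name='global_conv1')` producing variables named like `global_conv1/kernel:0`, which is what the prefix test matches.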
65 changes: 42 additions & 23 deletions models/main.py
@@ -5,6 +5,7 @@
 import os
 import sys
 import random
+import copy
 import tensorflow as tf
 
 import metrics.writer as metrics_writer
@@ -17,11 +18,11 @@
 from utils.args import parse_args
 from utils.model_utils import read_data
 
-STAT_METRICS_PATH = 'metrics/stat_metrics.csv'
-SYS_METRICS_PATH = 'metrics/sys_metrics.csv'
+STAT_METRICS_PATH = 'metrics/metrics_stat.csv'
+SYS_METRICS_PATH = 'metrics/metrics_sys.csv'
 
-def main():
 
+def main():
     args = parse_args()
 
     # Set the random seed if provided (affects client sampling, and batching)
@@ -33,10 +34,14 @@ def main():
     if not os.path.exists(model_path):
         print('Please specify a valid dataset and a valid model.')
     model_path = '%s.%s' % (args.dataset, args.model)
 
     print('############################## %s ##############################' % model_path)
-    mod = importlib.import_module(model_path)
-    ClientModel = getattr(mod, 'ClientModel')
+    # todo tdye
+    model_info = {
+        'model_path': model_path
+    }
+    # mod = importlib.import_module(model_path)
+    # ClientModel = getattr(mod, 'ClientModel')
 
     tup = MAIN_PARAMS[args.dataset][args.t]
     num_rounds = args.num_rounds if args.num_rounds != -1 else tup[0]
@@ -47,21 +52,30 @@ def main():
     tf.logging.set_verbosity(tf.logging.WARN)
 
     # Create 2 models
+    # model_params = (0.0003, 62)
+    # default learning rate
     model_params = MODEL_PARAMS[model_path]
+    # reset the learning rate
+    # model params after the reset
     if args.lr != -1:
         model_params_list = list(model_params)
         model_params_list[0] = args.lr
         model_params = tuple(model_params_list)
 
     # Create client model, and share params with server model
+    # reset the global default graph
     tf.reset_default_graph()
-    client_model = ClientModel(args.seed, *model_params)
-
+    # model_params (0.06, 62)
+    # client_model = ClientModel(args.seed, *model_params)
+    model_info.update({
+        'seed': args.seed,
+        'model_params': model_params
+    })
     # Create server
-    server = Server(client_model)
+    server = Server(model_info)
 
     # Create clients
-    clients = setup_clients(args.dataset, client_model, args.use_val_set)
+    clients = setup_clients(args.dataset, model_info, args.use_val_set)
     client_ids, client_groups, client_num_samples = server.get_clients_info(clients)
     print('Clients in Total: %d' % len(clients))

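After the `update`, `model_info` carries everything a `Client` needs to rebuild the model locally. For a FEMNIST run it would look roughly like this (values illustrative, consistent with the commented examples elsewhere in this diff):

```python
model_info = {
    'model_path': 'femnist.cnn',   # '%s.%s' % (args.dataset, args.model)
    'seed': 123,                   # args.seed
    'model_params': (0.06, 62),    # MODEL_PARAMS[model_path], lr possibly overridden by args.lr
}
```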
@@ -80,16 +94,17 @@ def main():
         c_ids, c_groups, c_num_samples = server.get_clients_info(server.selected_clients)
 
         # Simulate server model training on selected clients' data
-        sys_metrics = server.train_model(num_epochs=args.num_epochs, batch_size=args.batch_size, minibatch=args.minibatch)
+        sys_metrics = server.train_model(num_epochs=args.num_epochs, batch_size=args.batch_size,
+                                         minibatch=args.minibatch)
         sys_writer_fn(i + 1, c_ids, sys_metrics, c_groups, c_num_samples)
 
         # Update server model
         server.update_model()
 
         # Test model
         if (i + 1) % eval_every == 0 or (i + 1) == num_rounds:
             print_stats(i + 1, server, clients, client_num_samples, args, stat_writer_fn, args.use_val_set)
 
     # Save server model
     ckpt_path = os.path.join('checkpoints', args.dataset)
     if not os.path.exists(ckpt_path):
@@ -100,19 +115,24 @@ def main():
     # Close models
     server.close_model()
 
+
 def online(clients):
     """We assume all users are always online."""
     return clients
 
 
-def create_clients(users, groups, train_data, test_data, model):
+def create_clients(users, groups, train_data, test_data, model_info):
     if len(groups) == 0:
         groups = [[] for _ in users]
-    clients = [Client(u, g, train_data[u], test_data[u], model) for u, g in zip(users, groups)]
+    clients = [Client(u, g, train_data[u], test_data[u], model_info) for u, g in zip(users, groups)]
+    # clients = []
+    # for u, g in zip(users, groups):
+    #     model = copy.deepcopy(model)
+    #     clients.append(Client(u, g, train_data[u], test_data[u], model))
     return clients
 
 
-def setup_clients(dataset, model=None, use_val_set=False):
+def setup_clients(dataset, model_info=None, use_val_set=False):
     """Instantiates clients based on given train and test data directories.
 
     Return:
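The commented-out loop in `create_clients` records the rejected alternative: deep-copying one prebuilt model per client. Passing `model_info` and letting each `Client` re-import and re-instantiate `ClientModel` avoids that, which matters because TF1 graph and session objects generally do not survive `copy.deepcopy`. As written, the loop also rebinds `model` to its own copy on every iteration, so each client would have received a copy of the previous copy rather than of the original.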
@@ -124,32 +144,31 @@ def setup_clients(dataset, model_info=None, use_val_set=False):
 
     users, groups, train_data, test_data = read_data(train_data_dir, test_data_dir)
 
-    clients = create_clients(users, groups, train_data, test_data, model)
+    clients = create_clients(users, groups, train_data, test_data, model_info)
 
     return clients
 
 
 def get_stat_writer_function(ids, groups, num_samples, args):
-
     def writer_fn(num_round, metrics, partition):
         metrics_writer.print_metrics(
-            num_round, ids, metrics, groups, num_samples, partition, args.metrics_dir, '{}_{}'.format(args.metrics_name, 'stat'))
+            num_round, ids, metrics, groups, num_samples, partition, args.metrics_dir,
+            '{}_{}'.format(args.metrics_name, 'stat-fedsp'))
 
     return writer_fn
 
 
 def get_sys_writer_function(args):
-
     def writer_fn(num_round, ids, metrics, groups, num_samples):
         metrics_writer.print_metrics(
-            num_round, ids, metrics, groups, num_samples, 'train', args.metrics_dir, '{}_{}'.format(args.metrics_name, 'sys'))
+            num_round, ids, metrics, groups, num_samples, 'train', args.metrics_dir,
+            '{}_{}'.format(args.metrics_name, 'sys-fedsp'))
 
     return writer_fn
 
 
 def print_stats(
-    num_round, server, clients, num_samples, args, writer, use_val_set):
-
+        num_round, server, clients, num_samples, args, writer, use_val_set):
     train_stat_metrics = server.test_model(clients, set_to_use='train')
     print_metrics(train_stat_metrics, num_samples, prefix='train_')
     writer(num_round, train_stat_metrics, 'train')