Skip to content

Commit

Permalink
Create MXNet Example (#614)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
Co-authored-by: Taner Topal <[email protected]>
  • Loading branch information
3 people authored Mar 10, 2021
1 parent 715be72 commit 98de50c
Show file tree
Hide file tree
Showing 9 changed files with 512 additions and 0 deletions.
66 changes: 66 additions & 0 deletions examples/mxnet_from_centralized_to_federated/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# MXNet: From Centralized To Federated

This example demonstrates how an already existing centralized MXNet-based machine learning project can be federated with Flower.

This introductory example for Flower uses MXNet, but you're not required to be a MXNet expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on an existing MXNet project.

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/mxnet_from_centralized_to_federated . && rm -rf flower && cd mxnet_from_centralized_to_federated
```

This will create a new directory called `mxnet_from_centralized_to_federated` containing the following files:

```shell
-- pyproject.toml
-- mxnet_mnist.py
-- client.py
-- server.py
-- README.md
```

Project dependencies (such as `mxnet` and `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
python3 -c "import flwr"
```

If you don't see any errors you're good to go!

## Run MXNet Federated

This MXNet example is based on the [Handwritten Digit Recognition](https://mxnet.apache.org/versions/1.7.0/api/python/docs/tutorials/packages/gluon/image/mnist.html) tutorial and uses the MNIST dataset (hand-written digits with 28x28 pixels in greyscale with 10 classes). Feel free to consult the tutorial if you want to get a better understanding of MXNet. The file `mxnet_mnist.py` contains all the steps that are described in the tutorial. It loads the dataset and a sequential model, trains the model with the training set, and evaluates the trained model on the test set.

The only things we need are a simple Flower server (in `server.py`) and a Flower client (in `client.py`). The Flower client basically takes model and training code tells Flower how to call it.

Start the server in a terminal as follows:

```shell
python3 server.py
```

Now that the server is running and waiting for clients, we can start two clients that will participate in the federated learning process. To do so simply open two more terminal windows and run the following commands.

Start client 1 in the first terminal:

```shell
python3 client.py
```

Start client 2 in the second terminal:

```shell
python3 client.py
```

You are now training a MXNet-based classifier on MNIST, federated across two clients. The setup is of course simplified since both clients hold the same dataset, but you can now continue with your own explorations. How about changing from a sequential model to a CNN? How about adding more clients?
80 changes: 80 additions & 0 deletions examples/mxnet_from_centralized_to_federated/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Flower client example using MXNet for MNIST classification."""

from typing import Dict, List, Tuple

import flwr as fl
import numpy as np
import mxnet as mx
from mxnet import nd

import mxnet_mnist


# Flower Client
class MNISTClient(fl.client.NumPyClient):
"""Flower client implementing MNIST classification using MXNet."""

def __init__(
self,
model: mxnet_mnist.model(),
train_data: mx.io.NDArrayIter,
val_data: mx.io.NDArrayIter,
device: mx.context,
) -> None:
self.model = model
self.train_data = train_data
self.val_data = val_data
self.device = device

def get_parameters(self) -> List[np.ndarray]:
# Return model parameters as a list of NumPy Arrays
param = []
for val in self.model.collect_params(".*weight").values():
p = val.data()
# convert parameters from NDArray to Numpy Array required by Flower Numpy Client
param.append(p.asnumpy())
return param

def set_parameters(self, parameters: List[np.ndarray]) -> None:
# Collect model parameters and set new weight values
params = zip(self.model.collect_params(".*weight").keys(), parameters)
for key, value in params:
self.model.collect_params().setattr(key, value)

def fit(
self, parameters: List[np.ndarray], config: Dict
) -> Tuple[List[np.ndarray], int, Dict]:
# Set model parameters, train model, return updated model parameters
self.set_parameters(parameters)
mxnet_mnist.train(self.model, self.train_data, epoch=1, device=self.device)
return self.get_parameters(), self.train_data.batch_size, {}

def evaluate(
self, parameters: List[np.ndarray], config: Dict
) -> Tuple[int, float, Dict]:
# Set model parameters, evaluate model on local test dataset, return result
self.set_parameters(parameters)
loss, accuracy = mxnet_mnist.test(self.model, self.val_data, device=self.device)
return float(loss), self.val_data.batch_size, {"accuracy": float(accuracy)}


def main() -> None:
"""Load data, start MNISTClient."""

# Set context to GPU or - if not available - to CPU
DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]
# Load data
train_data, val_data = mxnet_mnist.load_data()
# Load model (from centralized training)
model = mxnet_mnist.model()
# Do one forward propagation to initialize parameters
init = nd.random.uniform(shape=(2, 784))
model(init)

# Start Flower client
client = MNISTClient(model, train_data, val_data, DEVICE)
fl.client.start_numpy_client("0.0.0.0:8080", client)


if __name__ == "__main__":
main()
135 changes: 135 additions & 0 deletions examples/mxnet_from_centralized_to_federated/mxnet_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""MXNet MNIST image classification.
The code is generally adapted from:
https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html
"""

from __future__ import print_function
from typing import Tuple
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
import mxnet.ndarray as F
from mxnet import nd

# Fixing the random seed
mx.random.seed(42)


def load_data() -> Tuple[mx.io.NDArrayIter, mx.io.NDArrayIter]:
print("Download Dataset")
# Download MNIST data
mnist = mx.test_utils.get_mnist()
batch_size = 100
train_data = mx.io.NDArrayIter(
mnist["train_data"], mnist["train_label"], batch_size, shuffle=True
)
val_data = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
return train_data, val_data


def model():
# Define simple Sequential model
net = nn.Sequential()
net.add(nn.Dense(256, activation="relu"))
net.add(nn.Dense(10))
net.collect_params().initialize()
return net


def train(
net: mx.gluon.nn, train_data: mx.io.NDArrayIter, epoch: int, device: mx.context
) -> None:
trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.03})
# Use Accuracy as the evaluation metric.
metric = mx.metric.Accuracy()
softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss()
for i in range(epoch):
# Reset the train data iterator.
train_data.reset()
# Loop over the train data iterator.
for batch in train_data:
# Splits train data into multiple slices along batch_axis
# and copy each slice into a context.
data = gluon.utils.split_and_load(
batch.data[0], ctx_list=device, batch_axis=0
)
# Splits train labels into multiple slices along batch_axis
# and copy each slice into a context.
label = gluon.utils.split_and_load(
batch.label[0], ctx_list=device, batch_axis=0
)
outputs = []
# Inside training scope
with ag.record():
for x, y in zip(data, label):
z = net(x)
# Computes softmax cross entropy loss.
loss = softmax_cross_entropy_loss(z, y)
# Backpropogate the error for one iteration.
loss.backward()
outputs.append(z)
# Updates internal evaluation
metric.update(label, outputs)
# Make one step of parameter update. Trainer needs to know the
# batch size of data to normalize the gradient by 1/batch_size.
trainer.step(batch.data[0].shape[0])
# Gets the evaluation result.
name, acc = metric.get()
# name_loss, running_loss = loss_metric.get()
# Reset evaluation result to initial state.
metric.reset()
print("training acc at epoch %d: %s=%f" % (i, name, acc))


def test(
net: mx.gluon.nn, val_data: mx.io.NDArrayIter, device: mx.context
) -> Tuple[float, float]:
# Use Accuracy as the evaluation metric.
metric = mx.metric.Accuracy()
loss_metric = mx.metric.Loss()
loss = 0.0
# Reset the validation data iterator.
val_data.reset()
# Loop over the validation data iterator.
for batch in val_data:
# Splits validation data into multiple slices along batch_axis
# and copy each slice into a context.
data = gluon.utils.split_and_load(batch.data[0], ctx_list=device, batch_axis=0)
# Splits validation label into multiple slices along batch_axis
# and copy each slice into a context.
label = gluon.utils.split_and_load(
batch.label[0], ctx_list=device, batch_axis=0
)
outputs = []
for x in data:
outputs.append(net(x))
loss_metric.update(label, outputs)
loss += loss_metric.get()[1]
# Updates internal evaluation
metric.update(label, outputs)
accuracy = metric.get()[1]
return loss, accuracy


def main():
# Set context to GPU or - if not available - to CPU
DEVICE = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]
# Load train and validation data
train_data, val_data = load_data()
# Define sequential model
net = model()
init = nd.random.uniform(shape=(2, 784))
net(init)
# Start model training based on training set
train(net=net, train_data=train_data, epoch=5, device=DEVICE)
# Evaluate model using loss and accuracy
loss, acc = test(net=net, val_data=val_data, device=DEVICE)
print("Loss: ", loss)
print("Accuracy: ", acc)


if __name__ == "__main__":
main()
14 changes: 14 additions & 0 deletions examples/mxnet_from_centralized_to_federated/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[tool.poetry]
name = "mxnet_example"
version = "0.1.0"
description = "MXNet example with MNIST and CNN"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.6.1"
flwr = "^0.14.0" # For development: { path = "../../", develop = true }
mxnet = "^1.7.0"

[build-system]
requires = ["poetry-core==1.1.4"]
build-backend = "poetry.core.masonry.api"
6 changes: 6 additions & 0 deletions examples/mxnet_from_centralized_to_federated/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Flower server example."""

import flwr as fl

if __name__ == "__main__":
fl.server.start_server("0.0.0.0:8080", config={"num_rounds": 3})
63 changes: 63 additions & 0 deletions examples/quickstart_mxnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Flower Example using MXNet

This example demonstrates how to run a MXNet machine learning project federated with Flower.

This introductory example for Flower uses MXNet, but you're not required to be a MXNet expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on an existing MXNet projects.

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart_mxnet . && rm -rf flower && cd quickstart_mxnet
```

This will create a new directory called `quickstart_mxnet` containing the following files:

```shell
-- pyproject.toml
-- client.py
-- server.py
-- README.md
```

Project dependencies (such as `mxnet` and `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
python3 -c "import flwr"
```

If you don't see any errors you're good to go!

## Run MXNet Federated

This MXNet example is based on the [Handwritten Digit Recognition](https://mxnet.apache.org/versions/1.7.0/api/python/docs/tutorials/packages/gluon/image/mnist.html) tutorial and uses the MNIST dataset (hand-written digits with 28x28 pixels in greyscale with 10 classes). Feel free to consult the tutorial if you want to get a better understanding of MXNet. The file `client.py` contains all the steps that are described in the tutorial. It loads the dataset and a sequential model, trains the model with the training set, and evaluates the trained model on the test set.

You are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:

```shell
python3 server.py
```

Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminal windows and run the following commands.

Start client 1 in the first terminal:

```shell
python3 client.py
```

Start client 2 in the second terminal:

```shell
python3 client.py
```

You are now training a MXNet-based classifier on MNIST, federated across two clients. The setup is of course simplified since both clients hold the same dataset, but you can now continue with your own explorations. How about changing from a sequential model to a CNN? How about adding more clients?
Loading

0 comments on commit 98de50c

Please sign in to comment.