diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 58bb6cbf4254..030f7618f4b2 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -2,6 +2,8 @@ ## Unreleased +- **General updates to Flower Examples** ([#2381](https://github.com/adap/flower/pull/2381)) + - **Update Flower Baselines** - HFedXGBoost [#2226](https://github.com/adap/flower/pull/2226) diff --git a/examples/simulation-pytorch/README.md b/examples/simulation-pytorch/README.md index 2fe8366cbc04..11b7a3364376 100644 --- a/examples/simulation-pytorch/README.md +++ b/examples/simulation-pytorch/README.md @@ -1,6 +1,6 @@ # Flower Simulation example using PyTorch -This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on either a single machine or a cluster of machines. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive on how Flower simulation works. +This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. ## Running the example (via Jupyter Notebook) @@ -41,7 +41,7 @@ poetry shell Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: ```shell -poetry run python3 -c "import flwr" +poetry run python -c "import flwr" ``` If you don't see any errors you're good to go! 
@@ -58,7 +58,7 @@ pip install -r requirements.txt ```bash # You can run the example without activating your environemnt -poetry run python3 sim.py +poetry run python sim.py # Or by first activating it poetry shell diff --git a/examples/simulation-pytorch/pyproject.toml b/examples/simulation-pytorch/pyproject.toml index 3b1cacf230f8..07918c0cd17c 100644 --- a/examples/simulation-pytorch/pyproject.toml +++ b/examples/simulation-pytorch/pyproject.toml @@ -11,5 +11,10 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } -torch = "1.13.1" -torchvision = "0.14.1" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } +torch = "2.1.1" +torchvision = "0.16.1" + +[tool.poetry.group.dev.dependencies] +ipykernel = "^6.27.0" + diff --git a/examples/simulation-pytorch/requirements.txt b/examples/simulation-pytorch/requirements.txt index 78ac83101d5b..4dbecab3e546 100644 --- a/examples/simulation-pytorch/requirements.txt +++ b/examples/simulation-pytorch/requirements.txt @@ -1,3 +1,4 @@ flwr[simulation]>=1.0, <2.0 -torch==1.13.1 -torchvision==0.14.1 \ No newline at end of file +torch==2.1.1 +torchvision==0.16.1 +flwr-datasets[vision]>=0.0.2, <1.0.0 \ No newline at end of file diff --git a/examples/simulation-pytorch/sim.ipynb b/examples/simulation-pytorch/sim.ipynb index e708aa36542d..508630cf9422 100644 --- a/examples/simulation-pytorch/sim.ipynb +++ b/examples/simulation-pytorch/sim.ipynb @@ -21,7 +21,8 @@ "outputs": [], "source": [ "# depending on your shell, you might need to add `\\` before `[` and `]`.\n", - "!pip install -q flwr[simulation]" + "!pip install -q flwr[simulation]\n", + "!pip install flwr_datasets[vision]" ] }, { @@ -29,7 +30,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. 
This is achieved via the [Virtual Client Engine]() in Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs os even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." + "We will be using the _simulation_ mode in Flower, which allows you to run a large number of clients without the overheads of manually managing devices. This is achieved via the [Virtual Client Engine](https://flower.dev/docs/framework/how-to-run-simulations.html) in Flower. With simulation, you can dynamically scale your experiments whether you run the code on your laptop, a machine with a single GPU, a server with multiple GPUs or even on a cluster with multiple servers. The `Virtual Client Engine` handles everything transparently and it allows you to specify how many resources (e.g. CPU cores, GPU VRAM) should be assigned to each virtual client." ] }, { @@ -40,22 +41,7 @@ "\n", "Flower is agnostic to your choice of ML Framework. Flower works with `PyTorch`, `Tensorflow`, `NumPy`, `🤗 Transformers`, `MXNet`, `JAX`, `scikit-learn`, `fastai`, `Pandas`. Flower also supports all major platforms: `iOS`, `Android` and plain `C++`. You can find a _quickstart-_ example for each of the above in the [Flower Repository](https://github.com/adap/flower/tree/main/examples) inside the `examples/` directory.\n", "\n", - "In this tutorial we are going to use PyTorch, so let's install a recent version. 
In this tutorial we'll use a small model so using CPU only training will suffice (this will also prevent Colab from abruptly terminating your experiment if resource limits are exceeded)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "7192138a-8c87-4d9a-f726-af1038ad264c" - }, - "outputs": [], - "source": [ - "# Install Pytorch with CPU support. Please adjust this command for your platform or if you want to use a GPU\n", - "!pip install torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu" + "In this tutorial we are going to use PyTorch. It comes pre-installed in your Colab runtime, so there is no need to install it again. If you would like to install another version, you can still do that in the same way other packages are installed via `!pip`" ] }, { @@ -63,7 +49,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We are going to install some other dependencies you are likely familiar with. We'll use these to make plots." + "We are going to install some other dependencies you are likely familiar with. Let's install `matplotlib` to plot our results at the end." ] }, { @@ -80,187 +66,6 @@ "!pip install matplotlib" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Centralised training: the old way of doing ML" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's begin by creating a simple (but complete) training loop as it is commonly done in centralised setups. 
Starting our tutorial in this way will allow us to very clearly identify which parts of a typical ML pipeline are common to both centralised and federated training.\n", - "\n", - "For this tutorial we'll design a image classification pipeline for [MNIST digits](https://en.wikipedia.org/wiki/MNIST_database) and using a simple CNN model as the network to train. The MNIST dataset is comprised of `28x28` greyscale images with digits from 0 to 9 (i.e. 10 classes in total)\n", - "\n", - "\n", - "## A dataset\n", - "\n", - "Let's begin by constructing the dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# we naturally first need to import torch and torchvision\n", - "import torch\n", - "from torch.utils.data import DataLoader\n", - "from torchvision.transforms import ToTensor, Normalize, Compose\n", - "from torchvision.datasets import MNIST\n", - "\n", - "\n", - "def get_mnist(data_path: str = \"./data\"):\n", - " \"\"\"This function downloads the MNIST dataset into the `data_path`\n", - " directory if it is not there already. We construct the train/test\n", - " split by converting the images into tensors and normalizing them\"\"\"\n", - "\n", - " # transformation to convert images to tensors and apply normalization\n", - " tr = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])\n", - "\n", - " # prepare train and test set\n", - " trainset = MNIST(data_path, train=True, download=True, transform=tr)\n", - " testset = MNIST(data_path, train=False, download=True, transform=tr)\n", - "\n", - " return trainset, testset" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's run the code above and do some visualisations to understand better the data we are working with !" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainset, testset = get_mnist()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can have a quick overview of our datasets by just typing the object on the command line. For instance, below you can see that the `trainset` has 60k training examples and will use the transformation rule we defined above in `get_mnist()`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "f10b649f-3cee-4e86-c7ff-94bd1fd3e082" - }, - "outputs": [], - "source": [ - "trainset" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's create a more insightful visualisation. First let's see the distribution over the labels by constructing a histogram. Then, let's visualise some training examples !" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 490 - }, - "outputId": "c8d0f4c0-60cd-4c58-bc91-3b061dae8046" - }, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "\n", - "# construct histogram\n", - "all_labels = trainset.targets\n", - "num_possible_labels = len(\n", - " set(all_labels.numpy().tolist())\n", - ") # this counts unique labels (so it should be = 10)\n", - "plt.hist(all_labels, bins=num_possible_labels)\n", - "\n", - "# plot formatting\n", - "plt.xticks(range(num_possible_labels))\n", - "plt.grid()\n", - "plt.xlabel(\"Label\")\n", - "plt.ylabel(\"Number of images\")\n", - "plt.title(\"Class labels distribution for MNIST\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import random\n", - "import numpy as np\n", - "\n", - "\n", - "def visualise_n_random_examples(trainset_, n: int, 
verbose: bool = True):\n", - " # take n examples at random\n", - " idx = list(range(len(trainset_.data)))\n", - " random.shuffle(idx)\n", - " idx = idx[:n]\n", - " if verbose:\n", - " print(f\"will display images with idx: {idx}\")\n", - "\n", - " # construct canvas\n", - " num_cols = 8\n", - " num_rows = int(np.ceil(len(idx) / num_cols))\n", - " fig, axs = plt.subplots(figsize=(16, num_rows * 2), nrows=num_rows, ncols=num_cols)\n", - "\n", - " # display images on canvas\n", - " for c_i, i in enumerate(idx):\n", - " axs.flat[c_i].imshow(trainset_.data[i], cmap=\"gray\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's visualise 32 images from the dataset\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 715 - }, - "outputId": "4e0988a8-388d-4acf-882b-089e4ea887bf" - }, - "outputs": [], - "source": [ - "# it is likely that the plot this function will generate looks familiar to other plots you might have generated before\n", - "# or you might have encountered in other tutorials. 
So far, we aren't doing anything new, Federated Learning will start soon!\n", - "visualise_n_random_examples(trainset, n=32)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -278,8 +83,10 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader\n", "\n", "\n", "class Net(nn.Module):\n", @@ -319,28 +126,27 @@ "metadata": {}, "outputs": [], "source": [ - "def train(net, trainloader, optimizer, epochs, device):\n", + "def train(net, trainloader, optim, epochs, device: str):\n", " \"\"\"Train the network on the training set.\"\"\"\n", " criterion = torch.nn.CrossEntropyLoss()\n", " net.train()\n", " for _ in range(epochs):\n", - " for images, labels in trainloader:\n", - " images, labels = images.to(device), labels.to(device)\n", - " optimizer.zero_grad()\n", + " for batch in trainloader:\n", + " images, labels = batch[\"image\"].to(device), batch[\"label\"].to(device)\n", + " optim.zero_grad()\n", " loss = criterion(net(images), labels)\n", " loss.backward()\n", - " optimizer.step()\n", - " return net\n", + " optim.step()\n", "\n", "\n", - "def test(net, testloader, device):\n", + "def test(net, testloader, device: str):\n", " \"\"\"Validate the network on the entire test set.\"\"\"\n", " criterion = torch.nn.CrossEntropyLoss()\n", " correct, loss = 0, 0.0\n", " net.eval()\n", " with torch.no_grad():\n", - " for images, labels in testloader:\n", - " images, labels = images.to(device), labels.to(device)\n", + " for data in testloader:\n", + " images, labels = data[\"image\"].to(device), data[\"label\"].to(device)\n", " outputs = net(images)\n", " loss += criterion(outputs, labels).item()\n", " _, predicted = torch.max(outputs.data, 1)\n", @@ -370,7 +176,9 @@ "source": [ "## One Client, One Data Partition\n", "\n", - "To start designing a Federated Learning pipeline we need to meet one of the key properties in FL: each client has its own data 
partition. To accomplish this with the MNIST dataset, we are going to generate N random partitions, where N is the total number of clients in our FL system." + "To start designing a Federated Learning pipeline we need to meet one of the key properties in FL: each client has its own data partition. To accomplish this with the MNIST dataset, we are going to generate N random partitions, where N is the total number of clients in our FL system.\n", + "\n", + "We can use [Flower Datasets](https://flower.dev/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." ] }, { @@ -379,94 +187,42 @@ "metadata": {}, "outputs": [], "source": [ - "from torch.utils.data import random_split\n", - "\n", - "\n", - "def prepare_dataset(num_partitions: int, batch_size: int, val_ratio: float = 0.1):\n", - " \"\"\"This function partitions the training set into N disjoint\n", - " subsets, each will become the local dataset of a client. This\n", - " function also subsequently partitions each training set partition\n", - " into train and validation. 
The test set is left intact and will\n", - " be used by the central server to asses the performance of the\n", - " global model.\"\"\"\n", - "\n", - " # get the MNIST datatset\n", - " trainset, testset = get_mnist()\n", - "\n", - " # split trainset into `num_partitions` trainsets\n", - " num_images = len(trainset) // num_partitions\n", - "\n", - " partition_len = [num_images] * num_partitions\n", + "from datasets import Dataset\n", + "from flwr_datasets import FederatedDataset\n", + "from datasets.utils.logging import disable_progress_bar\n", "\n", - " trainsets = random_split(\n", - " trainset, partition_len, torch.Generator().manual_seed(2023)\n", - " )\n", - "\n", - " # create dataloaders with train+val support\n", - " trainloaders = []\n", - " valloaders = []\n", - " for trainset_ in trainsets:\n", - " num_total = len(trainset_)\n", - " num_val = int(val_ratio * num_total)\n", - " num_train = num_total - num_val\n", - "\n", - " for_train, for_val = random_split(\n", - " trainset_, [num_train, num_val], torch.Generator().manual_seed(2023)\n", - " )\n", - "\n", - " trainloaders.append(\n", - " DataLoader(for_train, batch_size=batch_size, shuffle=True, num_workers=2)\n", - " )\n", - " valloaders.append(\n", - " DataLoader(for_val, batch_size=batch_size, shuffle=False, num_workers=2)\n", - " )\n", - "\n", - " # create dataloader for the test set\n", - " testloader = DataLoader(testset, batch_size=128)\n", + "# Let's set a simulation involving a total of 100 clients\n", + "NUM_CLIENTS = 100\n", "\n", - " return trainloaders, valloaders, testloader" + "# Download MNIST dataset and partition the \"train\" partition (so one can be assigned to each client)\n", + "mnist_fds = FederatedDataset(dataset=\"mnist\", partitioners={\"train\": NUM_CLIENTS})\n", + "# Let's keep the test set as is, and use it to evaluate the global model on the server\n", + "centralized_testset = mnist_fds.load_full(\"test\")" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, 
"source": [ - "Let's create 100 partitions and extract some statistics from one partition\n" + "Let's create a function that returns a set of transforms to apply to our images" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 508 - }, - "outputId": "0f53ca81-cb55-46ef-c8e0-4e19a4f060b2" - }, + "metadata": {}, "outputs": [], "source": [ - "NUM_CLIENTS = 100\n", - "\n", - "trainloaders, valloaders, testloader = prepare_dataset(\n", - " num_partitions=NUM_CLIENTS, batch_size=32\n", - ")\n", + "from torchvision.transforms import ToTensor, Normalize, Compose\n", "\n", - "# first partition\n", - "train_partition = trainloaders[0].dataset\n", "\n", - "# count data points\n", - "partition_indices = train_partition.indices\n", - "print(f\"number of images: {len(partition_indices)}\")\n", + "def apply_transforms(batch):\n", + " \"\"\"Get transformation for MNIST dataset\"\"\"\n", "\n", - "# visualise histogram\n", - "plt.hist(train_partition.dataset.dataset.targets[partition_indices], bins=10)\n", - "plt.grid()\n", - "plt.xticks(range(10))\n", - "plt.xlabel(\"Label\")\n", - "plt.ylabel(\"Number of images\")\n", - "plt.title(\"Class labels distribution for MNIST\")" + " # transformation to convert images to tensors and apply normalization\n", + " transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])\n", + " batch[\"image\"] = [transforms(img) for img in batch[\"image\"]]\n", + " return batch" ] }, { @@ -474,9 +230,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "As you can see, the histogram of this partition is a bit different from the one we obtained at the beginning where we took the entire dataset into consideration. 
Because our data partitions are artificially constructed by sampling the MNIST dataset in an IID fashion, our Federated Learning example will not face sever _data heterogeneity_ issues (which is a fairly [active research topic](https://arxiv.org/abs/1912.04977)).\n", - "\n", - "Let's next define how our FL clients will behave\n", + "Let's next define how our FL clients will behave.\n", "\n", "## Defining a Flower Client\n", "\n", @@ -521,16 +275,15 @@ "from collections import OrderedDict\n", "from typing import Dict, List, Tuple\n", "\n", - "import torch\n", "from flwr.common import NDArrays, Scalar\n", "\n", "\n", "class FlowerClient(fl.client.NumPyClient):\n", - " def __init__(self, trainloader, vallodaer) -> None:\n", + " def __init__(self, trainloader, valloader) -> None:\n", " super().__init__()\n", "\n", " self.trainloader = trainloader\n", - " self.valloader = vallodaer\n", + " self.valloader = valloader\n", " self.model = Net(num_classes=10)\n", " # Determine device\n", " self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", @@ -613,7 +366,7 @@ "metadata": {}, "outputs": [], "source": [ - "def get_evaluate_fn(testloader):\n", + "def get_evaluate_fn(centralized_testset: Dataset):\n", " \"\"\"This is a function that returns a function. The returned\n", " function (i.e. 
`evaluate_fn`) will be executed by the strategy\n", " at the end of each round to evaluate the stat of the global\n", @@ -636,20 +389,15 @@ " state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})\n", " model.load_state_dict(state_dict, strict=True)\n", "\n", + " # Apply transform to dataset\n", + " testset = centralized_testset.with_transform(apply_transforms)\n", + "\n", + " testloader = DataLoader(testset, batch_size=50)\n", " # call test\n", " loss, accuracy = test(model, testloader, device)\n", " return loss, {\"accuracy\": accuracy}\n", "\n", - " return evaluate_fn\n", - "\n", - "\n", - "# now we can define the strategy\n", - "# strategy = fl.server.strategy.FedAvg(\n", - "# fraction_fit=0.1,\n", - "# fraction_evaluate=0.1,\n", - "# min_available_clients=100,\n", - "# evaluate_fn=get_evaluate_fn(testloader), # Even this is not required\n", - "# )" + " return evaluate_fn" ] }, { @@ -707,14 +455,9 @@ "strategy = fl.server.strategy.FedAvg(\n", " fraction_fit=0.1, # Sample 10% of available clients for training\n", " fraction_evaluate=0.05, # Sample 5% of available clients for evaluation\n", - " min_fit_clients=10, # Never sample less than 10 clients for training\n", - " min_evaluate_clients=5, # Never sample less than 5 clients for evaluation\n", - " min_available_clients=int(\n", - " NUM_CLIENTS * 0.75\n", - " ), # Wait until at least 75 clients are available\n", " on_fit_config_fn=fit_config,\n", " evaluate_metrics_aggregation_fn=weighted_average, # aggregates federated metrics\n", - " evaluate_fn=get_evaluate_fn(testloader), # global evaluation function\n", + " evaluate_fn=get_evaluate_fn(centralized_testset), # global evaluation function\n", ")" ] }, @@ -737,18 +480,41 @@ "metadata": {}, "outputs": [], "source": [ - "def generate_client_fn(trainloaders, valloaders):\n", - " def client_fn(cid: str):\n", - " \"\"\"Returns a FlowerClient containing the cid-th data partition\"\"\"\n", + "from torch.utils.data import DataLoader\n", + "\n", + "\n", 
+ "def get_client_fn(dataset: FederatedDataset):\n", + " \"\"\"Return a function to construct a client.\n", + "\n", + " The VirtualClientEngine will execute this function whenever a client is sampled by\n", + " the strategy to participate.\n", + " \"\"\"\n", "\n", - " return FlowerClient(\n", - " trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)]\n", + " def client_fn(cid: str) -> fl.client.Client:\n", + " \"\"\"Construct a FlowerClient with its own dataset partition.\"\"\"\n", + "\n", + " # Let's get the partition corresponding to the i-th client\n", + " client_dataset = dataset.load_partition(int(cid), \"train\")\n", + "\n", + " # Now let's split it into train (90%) and validation (10%)\n", + " client_dataset_splits = client_dataset.train_test_split(test_size=0.1)\n", + "\n", + " trainset = client_dataset_splits[\"train\"]\n", + " valset = client_dataset_splits[\"test\"]\n", + "\n", + " # Now we apply the transform to each batch.\n", + " trainloader = DataLoader(\n", + " trainset.with_transform(apply_transforms), batch_size=32, shuffle=True\n", " )\n", + " valloader = DataLoader(valset.with_transform(apply_transforms), batch_size=32)\n", + "\n", + " # Create and return client\n", + " return FlowerClient(trainloader, valloader)\n", "\n", " return client_fn\n", "\n", "\n", - "client_fn_callback = generate_client_fn(trainloaders, valloaders)" + "client_fn_callback = get_client_fn(mnist_fds)" ] }, { @@ -774,6 +540,8 @@ "# client needs exclusive access to these many resources in order to run\n", "client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n", "\n", + "# Let's disable tqdm progress bar in the main thread (used by the server)\n", + "disable_progress_bar()\n", "\n", "history = fl.simulation.start_simulation(\n", " client_fn=client_fn_callback, # a callback to construct a client\n", @@ -781,6 +549,9 @@ " config=fl.server.ServerConfig(num_rounds=10), # let's run for 10 rounds\n", " strategy=strategy, # the strategy that will orchestrate the 
whole FL pipeline\n", " client_resources=client_resources,\n", + " actor_kwargs={\n", + " \"on_actor_init_fn\": disable_progress_bar # disable tqdm on each actor/process spawning virtual clients\n", + " },\n", ")" ] }, @@ -806,6 +577,8 @@ }, "outputs": [], "source": [ + "import matplotlib.pyplot as plt\n", + "\n", "print(f\"{history.metrics_centralized = }\")\n", "\n", "global_accuracy_centralised = history.metrics_centralized[\"accuracy\"]\n", @@ -817,6 +590,27 @@ "plt.xlabel(\"Round\")\n", "plt.title(\"MNIST - IID - 100 clients with 10 clients per round\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Congratulations! With that, you built a Flower client, customized its instantiation through the `client_fn`, customized the server-side execution through a `FedAvg` strategy configured for this workload, and started a simulation with 100 clients (each holding their own individual partition of the MNIST dataset).\n", + "\n", + "Next, you can continue to explore more advanced Flower topics:\n", + "\n", + "- Deploy server and clients on different machines using `start_server` and `start_client`\n", + "- Customize the server-side execution through custom strategies\n", + "- Customize the client-side execution through `config` dictionaries\n", + "\n", + "Get all resources you need!\n", + "\n", + "* **[DOCS]** Our complete documentation: https://flower.dev/docs/\n", + "* **[Examples]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[VIDEO]** Our YouTube channel: https://www.youtube.com/@flowerlabs\n", + "\n", + "Don't forget to join our Slack channel: https://flower.dev/join-slack/\n" + ] } ], "metadata": { @@ -825,10 +619,11 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", + "language": "python", "name": "python3" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/examples/simulation-pytorch/sim.py 
b/examples/simulation-pytorch/sim.py index 5adfca744591..68d9426e83ab 100644 --- a/examples/simulation-pytorch/sim.py +++ b/examples/simulation-pytorch/sim.py @@ -3,15 +3,17 @@ from typing import Dict, Tuple, List import torch -import torchvision -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader import flwr as fl from flwr.common import Metrics from flwr.common.typing import Scalar -from utils import Net, train, test, get_mnist +from datasets import Dataset +from datasets.utils.logging import disable_progress_bar +from flwr_datasets import FederatedDataset +from utils import Net, train, test, apply_transforms parser = argparse.ArgumentParser(description="Flower Simulation with PyTorch") @@ -78,18 +80,28 @@ def evaluate(self, parameters, config): return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)} -def get_client_fn(train_partitions, val_partitions): +def get_client_fn(dataset: FederatedDataset): """Return a function to construct a client. - The VirtualClientEngine will exectue this function whenever a client is sampled by + The VirtualClientEngine will execute this function whenever a client is sampled by the strategy to participate. """ def client_fn(cid: str) -> fl.client.Client: """Construct a FlowerClient with its own dataset partition.""" - # Extract partition for client with id = cid - trainset, valset = train_partitions[int(cid)], val_partitions[int(cid)] + # Let's get the partition corresponding to the i-th client + client_dataset = dataset.load_partition(int(cid), "train") + + # Now let's split it into train (90%) and validation (10%) + client_dataset_splits = client_dataset.train_test_split(test_size=0.1) + + trainset = client_dataset_splits["train"] + valset = client_dataset_splits["test"] + + # Now we apply the transform to each batch. 
+ trainset = trainset.with_transform(apply_transforms) + valset = valset.with_transform(apply_transforms) # Create and return client return FlowerClient(trainset, valset) @@ -113,40 +125,6 @@ def set_params(model: torch.nn.ModuleList, params: List[fl.common.NDArrays]): model.load_state_dict(state_dict, strict=True) -def prepare_dataset(): - """Download and partitions the MNIST dataset.""" - - # Get the MNIST dataset - trainset, testset = get_mnist() - - # Split trainset into `num_partitions` trainsets - num_images = len(trainset) // NUM_CLIENTS - partition_len = [num_images] * NUM_CLIENTS - - trainsets = random_split( - trainset, partition_len, torch.Generator().manual_seed(2023) - ) - - val_ratio = 0.1 - - # Create dataloaders with train+val support - train_partitions = [] - val_partitions = [] - for trainset_ in trainsets: - num_total = len(trainset_) - num_val = int(val_ratio * num_total) - num_train = num_total - num_val - - for_train, for_val = random_split( - trainset_, [num_train, num_val], torch.Generator().manual_seed(2023) - ) - - train_partitions.append(for_train) - val_partitions.append(for_val) - - return train_partitions, val_partitions, testset - - def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: """Aggregation function for (federated) evaluation metrics, i.e. 
those returned by the client's evaluate() method.""" @@ -159,7 +137,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: def get_evaluate_fn( - testset: torchvision.datasets.CIFAR10, + centralized_testset: Dataset, ): """Return an evaluation function for centralized evaluation.""" @@ -175,6 +153,12 @@ def evaluate( set_params(model, parameters) model.to(device) + # Apply transform to dataset + testset = centralized_testset.with_transform(apply_transforms) + + # Disable tqdm for dataset preprocessing + disable_progress_bar() + testloader = DataLoader(testset, batch_size=50) loss, accuracy = test(model, testloader, device=device) @@ -187,8 +171,9 @@ def main(): # Parse input arguments args = parser.parse_args() - # Download CIFAR-10 dataset and partition it - trainsets, valsets, testset = prepare_dataset() + # Download MNIST dataset and partition it + mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) + centralized_testset = mnist_fds.load_full("test") # Configure the strategy strategy = fl.server.strategy.FedAvg( @@ -201,7 +186,7 @@ def main(): ), # Wait until at least 75 clients are available on_fit_config_fn=fit_config, evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics - evaluate_fn=get_evaluate_fn(testset), # Global evaluation function + evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function ) # Resources to be assigned to each virtual client @@ -212,11 +197,14 @@ def main(): # Start simulation fl.simulation.start_simulation( - client_fn=get_client_fn(trainsets, valsets), + client_fn=get_client_fn(mnist_fds), num_clients=NUM_CLIENTS, client_resources=client_resources, config=fl.server.ServerConfig(num_rounds=args.num_rounds), strategy=strategy, + actor_kwargs={ + "on_actor_init_fn": disable_progress_bar # disable tqdm on each actor/process spawning virtual clients + }, ) diff --git a/examples/simulation-pytorch/utils.py b/examples/simulation-pytorch/utils.py 
index fff6bb490930..01f63cc94ba3 100644 --- a/examples/simulation-pytorch/utils.py +++ b/examples/simulation-pytorch/utils.py @@ -3,7 +3,13 @@ import torch.nn.functional as F from torchvision.transforms import ToTensor, Normalize, Compose -from torchvision.datasets import MNIST + + +# transformation to convert images to tensors and apply normalization +def apply_transforms(batch): + transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) + batch["image"] = [transforms(img) for img in batch["image"]] + return batch # Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz') @@ -33,8 +39,8 @@ def train(net, trainloader, optim, epochs, device: str): criterion = torch.nn.CrossEntropyLoss() net.train() for _ in range(epochs): - for images, labels in trainloader: - images, labels = images.to(device), labels.to(device) + for batch in trainloader: + images, labels = batch["image"].to(device), batch["label"].to(device) optim.zero_grad() loss = criterion(net(images), labels) loss.backward() @@ -49,23 +55,10 @@ def test(net, testloader, device: str): net.eval() with torch.no_grad(): for data in testloader: - images, labels = data[0].to(device), data[1].to(device) + images, labels = data["image"].to(device), data["label"].to(device) outputs = net(images) loss += criterion(outputs, labels).item() _, predicted = torch.max(outputs.data, 1) correct += (predicted == labels).sum().item() accuracy = correct / len(testloader.dataset) return loss, accuracy - - -def get_mnist(data_path: str = "./data"): - """Download MNIST and apply transform.""" - - # transformation to convert images to tensors and apply normalization - tr = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) - - # prepare train and test set - trainset = MNIST(data_path, train=True, download=True, transform=tr) - testset = MNIST(data_path, train=False, download=True, transform=tr) - - return trainset, testset diff --git a/examples/simulation-tensorflow/README.md 
b/examples/simulation-tensorflow/README.md index 61a6749a6bdf..f0d94f343d37 100644 --- a/examples/simulation-tensorflow/README.md +++ b/examples/simulation-tensorflow/README.md @@ -1,6 +1,6 @@ # Flower Simulation example using TensorFlow/Keras -This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on either a single machine of a cluster of machines. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive on how Flower simulation works. +This introductory example uses the simulation capabilities of Flower to simulate a large number of clients on a single machine. Take a look at the [Documentation](https://flower.dev/docs/framework/how-to-run-simulations.html) for a deep dive into how Flower simulation works. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the MNIST dataset. This examples uses 100 clients by default. ## Running the example (via Jupyter Notebook) @@ -40,7 +40,7 @@ poetry shell Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: ```shell -poetry run python3 -c "import flwr" +poetry run python -c "import flwr" ``` If you don't see any errors you're good to go! 
@@ -57,7 +57,7 @@ pip install -r requirements.txt ```bash # You can run the example without activating your environemnt -poetry run python3 sim.py +poetry run python sim.py # Or by first activating it poetry shell diff --git a/examples/simulation-tensorflow/pyproject.toml b/examples/simulation-tensorflow/pyproject.toml index 4016c3da0da0..f2e7bd3006c0 100644 --- a/examples/simulation-tensorflow/pyproject.toml +++ b/examples/simulation-tensorflow/pyproject.toml @@ -11,5 +11,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } tensorflow = {version = "^2.9.1, !=2.11.1", markers="platform_machine == 'x86_64'"} tensorflow-macos = {version = "^2.9.1, !=2.11.1", markers="sys_platform == 'darwin' and platform_machine == 'arm64'"} diff --git a/examples/simulation-tensorflow/requirements.txt b/examples/simulation-tensorflow/requirements.txt index 76e77f5ff9b8..bb69a87be1b4 100644 --- a/examples/simulation-tensorflow/requirements.txt +++ b/examples/simulation-tensorflow/requirements.txt @@ -1,3 +1,4 @@ flwr[simulation]>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64" diff --git a/examples/simulation-tensorflow/sim.ipynb b/examples/simulation-tensorflow/sim.ipynb index 559dcf3170a3..575b437018f3 100644 --- a/examples/simulation-tensorflow/sim.ipynb +++ b/examples/simulation-tensorflow/sim.ipynb @@ -17,7 +17,8 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q flwr[\"simulation\"] tensorflow" + "!pip install -q flwr[\"simulation\"] tensorflow\n", + "!pip install -q flwr_datasets[\"vision\"]" ] }, { @@ -49,7 +50,6 @@ "metadata": {}, "outputs": [], "source": [ - "import math\n", "from typing import Dict, List, Tuple\n", "\n", "import tensorflow as 
tf\n", @@ -58,6 +58,9 @@ "from flwr.common import Metrics\n", "from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth\n", "\n", + "from datasets import Dataset\n", + "from flwr_datasets import FederatedDataset\n", + "\n", "VERBOSE = 0\n", "NUM_CLIENTS = 100" ] @@ -113,28 +116,24 @@ "outputs": [], "source": [ "class FlowerClient(fl.client.NumPyClient):\n", - " def __init__(self, x_train, y_train, x_val, y_val) -> None:\n", + " def __init__(self, trainset, valset) -> None:\n", " # Create model\n", " self.model = get_model()\n", - " self.x_train, self.y_train = x_train, y_train\n", - " self.x_val, self.y_val = x_val, y_val\n", + " self.trainset = trainset\n", + " self.valset = valset\n", "\n", " def get_parameters(self, config):\n", " return self.model.get_weights()\n", "\n", " def fit(self, parameters, config):\n", " self.model.set_weights(parameters)\n", - " self.model.fit(\n", - " self.x_train, self.y_train, epochs=1, batch_size=32, verbose=VERBOSE\n", - " )\n", - " return self.model.get_weights(), len(self.x_train), {}\n", + " self.model.fit(self.trainset, epochs=1, verbose=VERBOSE)\n", + " return self.model.get_weights(), len(self.trainset), {}\n", "\n", " def evaluate(self, parameters, config):\n", " self.model.set_weights(parameters)\n", - " loss, acc = self.model.evaluate(\n", - " self.x_val, self.y_val, batch_size=64, verbose=VERBOSE\n", - " )\n", - " return loss, len(self.x_val), {\"accuracy\": acc}" + " loss, acc = self.model.evaluate(self.valset, verbose=VERBOSE)\n", + " return loss, len(self.valset), {\"accuracy\": acc}" ] }, { @@ -155,8 +154,6 @@ "We now define four auxiliary functions for this example (note the last two are entirely optional):\n", "* `get_client_fn()`: Is a function that returns another function. The returned `client_fn` will be executed by Flower's VirtualClientEngine each time a new _virtual_ client (i.e. a client that is simulated in a Python process) needs to be spawn. When are virtual clients spawned? 
Each time the strategy samples them to do either `fit()` (i.e. train the global model on the local data of a particular client) or `evaluate()` (i.e. evaluate the global model on the validation set of a given client).\n", "\n", - "* `partition_mnist()`: A utility function that downloads the MNIST dataset and partitions it into `NUM_CLIENT` disjoint sets. The resulting list of dataset partitions will be passed to `get_client_fn()` so a client can be constructed by passing it its corresponding dataset partition. There are multiple ways of partitioning a dataset, but in this example we keep things simple. For larger dataset, you might want to pre-partition your dataset before running your Flower experiment and, potentially, store these partition into your files system or a database. In this way, your `FlowerClient` objects can retrieve their data directly when doing either `fit()` or `evaluate()`.\n", - "\n", "* `weighted_average()`: This is an optional function to pass to the strategy. It will be executed after an evaluation round (i.e. when client run `evaluate()`) and will aggregate the metrics clients return. In this example, we use this function to compute the weighted average accuracy of clients doing `evaluate()`.\n", "\n", "* `get_evaluate_fn()`: This is again a function that returns another function. The returned function will be executed by the strategy at the end of a `fit()` round and after a new global model has been obtained after aggregation. This is an optional argument for Flower strategies. In this example, we use the whole MNIST test set to perform this server-side evaluation." 
@@ -168,42 +165,35 @@ "metadata": {}, "outputs": [], "source": [ - "def get_client_fn(dataset_partitions):\n", - " \"\"\"Return a function to be executed by the VirtualClientEngine in order to construct\n", - " a client.\"\"\"\n", + "def get_client_fn(dataset: FederatedDataset):\n", + " \"\"\"Return a function to construct a client.\n", + "\n", + " The VirtualClientEngine will execute this function whenever a client is sampled by\n", + " the strategy to participate.\n", + " \"\"\"\n", "\n", " def client_fn(cid: str) -> fl.client.Client:\n", " \"\"\"Construct a FlowerClient with its own dataset partition.\"\"\"\n", "\n", " # Extract partition for client with id = cid\n", - " x_train, y_train = dataset_partitions[int(cid)]\n", - " # Use 10% of the client's training data for validation\n", - " split_idx = math.floor(len(x_train) * 0.9)\n", - " x_train_cid, y_train_cid = (\n", - " x_train[:split_idx],\n", - " y_train[:split_idx],\n", + " client_dataset = dataset.load_partition(int(cid), \"train\")\n", + "\n", + " # Now let's split it into train (90%) and validation (10%)\n", + " client_dataset_splits = client_dataset.train_test_split(test_size=0.1)\n", + "\n", + " trainset = client_dataset_splits[\"train\"].to_tf_dataset(\n", + " columns=\"image\", label_cols=\"label\", batch_size=32\n", + " )\n", + " valset = client_dataset_splits[\"test\"].to_tf_dataset(\n", + " columns=\"image\", label_cols=\"label\", batch_size=64\n", " )\n", - " x_val_cid, y_val_cid = x_train[split_idx:], y_train[split_idx:]\n", "\n", " # Create and return client\n", - " return FlowerClient(x_train_cid, y_train_cid, x_val_cid, y_val_cid)\n", + " return FlowerClient(trainset, valset)\n", "\n", " return client_fn\n", "\n", "\n", - "def partition_mnist():\n", - " \"\"\"Download and partitions the MNIST dataset.\"\"\"\n", - " (x_train, y_train), testset = tf.keras.datasets.mnist.load_data()\n", - " partitions = []\n", - " # We keep all partitions equal-sized in this example\n", - " partition_size = 
math.floor(len(x_train) / NUM_CLIENTS)\n", - " for cid in range(NUM_CLIENTS):\n", - " # Split dataset into non-overlapping NUM_CLIENT partitions\n", - " idx_from, idx_to = int(cid) * partition_size, (int(cid) + 1) * partition_size\n", - " partitions.append((x_train[idx_from:idx_to] / 255.0, y_train[idx_from:idx_to]))\n", - " return partitions, testset\n", - "\n", - "\n", "def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:\n", " \"\"\"Aggregation function for (federated) evaluation metrics, i.e. those returned by\n", " the client's evaluate() method.\"\"\"\n", @@ -215,9 +205,8 @@ " return {\"accuracy\": sum(accuracies) / sum(examples)}\n", "\n", "\n", - "def get_evaluate_fn(testset):\n", - " \"\"\"Return an evaluation function for server-side (i.e. centralized) evaluation.\"\"\"\n", - " x_test, y_test = testset\n", + "def get_evaluate_fn(testset: Dataset):\n", + " \"\"\"Return an evaluation function for server-side (i.e. centralised) evaluation.\"\"\"\n", "\n", " # The `evaluate` function will be called after every round by the strategy\n", " def evaluate(\n", @@ -227,7 +216,7 @@ " ):\n", " model = get_model() # Construct the model\n", " model.set_weights(parameters) # Update model with the latest parameters\n", - " loss, accuracy = model.evaluate(x_test, y_test, verbose=VERBOSE)\n", + " loss, accuracy = model.evaluate(testset, verbose=VERBOSE)\n", " return loss, {\"accuracy\": accuracy}\n", "\n", " return evaluate" @@ -241,7 +230,9 @@ "\n", "The function `start_simulation` accepts a number of arguments, amongst them the `client_fn` used to create `FlowerClient` instances, the number of clients to simulate `num_clients`, the number of rounds `num_rounds`, and the strategy. 
The strategy encapsulates the federated learning approach/algorithm, for example, *Federated Averaging* (FedAvg).\n", "\n", - "Flower comes with a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. The last step is the actual call to `start_simulation` which - you guessed it - actually starts the simulation." + "Flower comes with a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. The last step is the actual call to `start_simulation` which - you guessed it - actually starts the simulation.\n", + "\n", + "We can use [Flower Datasets](https://flower.dev/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or partition one that isn't pre-partitioned. Let's choose MNIST." 
] }, { @@ -253,8 +244,13 @@ "# Enable GPU growth in your main process\n", "enable_tf_gpu_growth()\n", "\n", - "# Create dataset partitions (needed if your dataset is not pre-partitioned)\n", - "partitions, testset = partition_mnist()\n", + "# Download MNIST dataset and partition it\n", + "mnist_fds = FederatedDataset(dataset=\"mnist\", partitioners={\"train\": NUM_CLIENTS})\n", + "# Get the whole test set for centralised evaluation\n", + "centralized_testset = mnist_fds.load_full(\"test\").to_tf_dataset(\n", + " columns=\"image\", label_cols=\"label\", batch_size=64\n", + ")\n", + "\n", "\n", "# Create FedAvg strategy\n", "strategy = fl.server.strategy.FedAvg(\n", @@ -266,7 +262,7 @@ " NUM_CLIENTS * 0.75\n", " ), # Wait until at least 75 clients are available\n", " evaluate_metrics_aggregation_fn=weighted_average, # aggregates federated metrics\n", - " evaluate_fn=get_evaluate_fn(testset), # global evaluation function\n", + " evaluate_fn=get_evaluate_fn(centralized_testset), # global evaluation function\n", ")\n", "\n", "# With a dictionary, you tell Flower's VirtualClientEngine that each\n", @@ -275,7 +271,7 @@ "\n", "# Start simulation\n", "history = fl.simulation.start_simulation(\n", - " client_fn=get_client_fn(partitions),\n", + " client_fn=get_client_fn(mnist_fds),\n", " num_clients=NUM_CLIENTS,\n", " config=fl.server.ServerConfig(num_rounds=10),\n", " strategy=strategy,\n", @@ -290,7 +286,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can then use the resturned History object to either save the results to disk or do some visualisation (or both of course, or neither if you like chaos). Below you can see how you can plot the centralised accuracy obtainined at the end of each round (including at the very beginning of the experiment) for the global model. This is want the function evaluate_fn() that we passed to the strategy reports." 
+ "You can then use the returned History object to either save the results to disk or do some visualisation (or both of course, or neither if you like chaos). Below you can see how you can plot the centralised accuracy obtained at the end of each round (including at the very beginning of the experiment) for the global model. This is what the function `evaluate_fn()` that we passed to the strategy reports." ] }, { @@ -323,7 +319,15 @@ "\n", "- Deploy server and clients on different machines using `start_server` and `start_client`\n", "- Customize the server-side execution through custom strategies\n", - "- Customize the client-side execution through `config` dictionaries" + "- Customize the client-side execution through `config` dictionaries\n", + "\n", + "Get all resources you need!\n", + "\n", + "* **[DOCS]** Our complete documentation: https://flower.dev/docs/\n", + "* **[Examples]** All Flower examples: https://flower.dev/docs/examples/\n", + "* **[VIDEO]** Our YouTube channel: https://www.youtube.com/@flowerlabs\n", + "\n", + "Don't forget to join our Slack channel: https://flower.dev/join-slack/" ] } ], @@ -333,11 +337,11 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3.8.12 ('.venv': poetry)", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/examples/simulation-tensorflow/sim.py b/examples/simulation-tensorflow/sim.py index 15f7097dc439..490e25fe8c8d 100644 --- a/examples/simulation-tensorflow/sim.py +++ b/examples/simulation-tensorflow/sim.py @@ -9,6 +9,8 @@ from flwr.common import Metrics from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth +from datasets import Dataset +from flwr_datasets import FederatedDataset # Make TensorFlow logs less verbose os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -34,28 +36,24 @@ class FlowerClient(fl.client.NumPyClient): - def __init__(self, x_train, y_train, x_val, y_val) -> None: 
+ def __init__(self, trainset, valset) -> None: # Create model self.model = get_model() - self.x_train, self.y_train = x_train, y_train - self.x_val, self.y_val = x_val, y_val + self.trainset = trainset + self.valset = valset def get_parameters(self, config): return self.model.get_weights() def fit(self, parameters, config): self.model.set_weights(parameters) - self.model.fit( - self.x_train, self.y_train, epochs=1, batch_size=32, verbose=VERBOSE - ) - return self.model.get_weights(), len(self.x_train), {} + self.model.fit(self.trainset, epochs=1, verbose=VERBOSE) + return self.model.get_weights(), len(self.trainset), {} def evaluate(self, parameters, config): self.model.set_weights(parameters) - loss, acc = self.model.evaluate( - self.x_val, self.y_val, batch_size=64, verbose=VERBOSE - ) - return loss, len(self.x_val), {"accuracy": acc} + loss, acc = self.model.evaluate(self.valset, verbose=VERBOSE) + return loss, len(self.valset), {"accuracy": acc} def get_model(): @@ -72,10 +70,10 @@ def get_model(): return model -def get_client_fn(dataset_partitions): - """Return a function to construc a client. +def get_client_fn(dataset: FederatedDataset): + """Return a function to construct a client. - The VirtualClientEngine will exectue this function whenever a client is sampled by + The VirtualClientEngine will execute this function whenever a client is sampled by the strategy to participate. 
""" @@ -83,34 +81,24 @@ def client_fn(cid: str) -> fl.client.Client: """Construct a FlowerClient with its own dataset partition.""" # Extract partition for client with id = cid - x_train, y_train = dataset_partitions[int(cid)] - # Use 10% of the client's training data for validation - split_idx = math.floor(len(x_train) * 0.9) - x_train_cid, y_train_cid = ( - x_train[:split_idx], - y_train[:split_idx], + client_dataset = dataset.load_partition(int(cid), "train") + + # Now let's split it into train (90%) and validation (10%) + client_dataset_splits = client_dataset.train_test_split(test_size=0.1) + + trainset = client_dataset_splits["train"].to_tf_dataset( + columns="image", label_cols="label", batch_size=32 + ) + valset = client_dataset_splits["test"].to_tf_dataset( + columns="image", label_cols="label", batch_size=64 ) - x_val_cid, y_val_cid = x_train[split_idx:], y_train[split_idx:] # Create and return client - return FlowerClient(x_train_cid, y_train_cid, x_val_cid, y_val_cid) + return FlowerClient(trainset, valset) return client_fn -def partition_mnist(): - """Download and partitions the MNIST dataset.""" - (x_train, y_train), testset = tf.keras.datasets.mnist.load_data() - partitions = [] - # We keep all partitions equal-sized in this example - partition_size = math.floor(len(x_train) / NUM_CLIENTS) - for cid in range(NUM_CLIENTS): - # Split dataset into non-overlapping NUM_CLIENT partitions - idx_from, idx_to = int(cid) * partition_size, (int(cid) + 1) * partition_size - partitions.append((x_train[idx_from:idx_to] / 255.0, y_train[idx_from:idx_to])) - return partitions, testset - - def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: """Aggregation function for (federated) evaluation metrics. 
@@ -124,9 +112,8 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: return {"accuracy": sum(accuracies) / sum(examples)} -def get_evaluate_fn(testset): +def get_evaluate_fn(testset: Dataset): """Return an evaluation function for server-side (i.e. centralised) evaluation.""" - x_test, y_test = testset # The `evaluate` function will be called after every round by the strategy def evaluate( @@ -136,7 +123,7 @@ def evaluate( ): model = get_model() # Construct the model model.set_weights(parameters) # Update model with the latest parameters - loss, accuracy = model.evaluate(x_test, y_test, verbose=VERBOSE) + loss, accuracy = model.evaluate(testset, verbose=VERBOSE) return loss, {"accuracy": accuracy} return evaluate @@ -146,8 +133,12 @@ def main() -> None: # Parse input arguments args = parser.parse_args() - # Create dataset partitions (needed if your dataset is not pre-partitioned) - partitions, testset = partition_mnist() + # Download MNIST dataset and partition it + mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) + # Get the whole test set for centralised evaluation + centralized_testset = mnist_fds.load_full("test").to_tf_dataset( + columns="image", label_cols="label", batch_size=64 + ) # Create FedAvg strategy strategy = fl.server.strategy.FedAvg( @@ -159,7 +150,7 @@ def main() -> None: NUM_CLIENTS * 0.75 ), # Wait until at least 75 clients are available evaluate_metrics_aggregation_fn=weighted_average, # aggregates federated metrics - evaluate_fn=get_evaluate_fn(testset), # global evaluation function + evaluate_fn=get_evaluate_fn(centralized_testset), # global evaluation function ) # With a dictionary, you tell Flower's VirtualClientEngine that each @@ -171,7 +162,7 @@ def main() -> None: # Start simulation fl.simulation.start_simulation( - client_fn=get_client_fn(partitions), + client_fn=get_client_fn(mnist_fds), num_clients=NUM_CLIENTS, config=fl.server.ServerConfig(num_rounds=args.num_rounds), 
strategy=strategy,