diff --git a/openfl-tutorials/experimental/Workflow_Interface_104_Histology_with_fedcurv.ipynb b/openfl-tutorials/experimental/Workflow_Interface_104_Histology_with_fedcurv.ipynb new file mode 100644 index 0000000000..69b77fc50c --- /dev/null +++ b/openfl-tutorials/experimental/Workflow_Interface_104_Histology_with_fedcurv.ipynb @@ -0,0 +1,549 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "14821d97", + "metadata": {}, + "source": [ + "# Workflow Interface 104: Histology with Fedcurv implementation\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/openfl/blob/develop/openfl-tutorials/experimental/Workflow_Interface_104_Histology_with_fedcurv.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "a7989e72", + "metadata": {}, + "source": [ + "In this OpenFL workflow interface tutorial, we'll learn how to implement FedCurv aggregation algorithm using Histology dataset." + ] + }, + { + "cell_type": "markdown", + "id": "fc8e35da", + "metadata": {}, + "source": [ + "# Getting Started" + ] + }, + { + "cell_type": "markdown", + "id": "4dbb89b6", + "metadata": {}, + "source": [ + "First we start by installing the necessary dependencies for the workflow interface" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7f98600", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install git+https://github.com/intel/openfl.git\n", + "# !pip install -r https://raw.githubusercontent.com/intel/openfl/develop/openfl-tutorials/experimental/requirements_workflow_interface.txt\n", + "\n", + "# Uncomment this if running in Google Colab\n", + "#import os\n", + "#os.environ[\"USERNAME\"] = \"colab\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e85e030", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from copy import deepcopy\n", + "import torch\n", + "import torchvision\n", + "import numpy as np\n", + "\n", + "import 
torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "from pathlib import Path\n", + "from urllib.request import urlretrieve\n", + "from zipfile import ZipFile\n", + "from PIL import Image\n", + "\n", + "from openfl.utilities import tqdm_report_hook\n", + "from openfl.utilities import validate_file_hash\n", + "\n", + "batch_size_train = 64\n", + "batch_size_test = 64\n", + "learning_rate = 0.01\n", + "log_interval = 10\n", + "\n", + "np.random.seed(0)\n", + "torch.manual_seed(0)\n", + "\n", + "# Download data\n", + "\n", + "URL = ('https://zenodo.org/record/53169/files/Kather_'\n", + " 'texture_2016_image_tiles_5000.zip?download=1')\n", + "FILENAME = 'Kather_texture_2016_image_tiles_5000.zip'\n", + "ZIP_SHA384 = ('7d86abe1d04e68b77c055820c2a4c582a1d25d2983e38ab724e'\n", + " 'ac75affce8b7cb2cbf5ba68848dcfd9d84005d87d6790')\n", + "data_folder = Path('.') / 'data'\n", + "\n", + "\n", + "os.makedirs(data_folder, exist_ok=True)\n", + "filepath = data_folder / FILENAME\n", + "if not filepath.exists():\n", + " reporthook = tqdm_report_hook()\n", + " urlretrieve(URL, filepath, reporthook)\n", + " validate_file_hash(filepath, ZIP_SHA384)\n", + " with ZipFile(filepath, 'r') as f:\n", + " f.extractall(data_folder)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce39cf3a-cfd3-4c58-9f78-d1389ee9d8d4", + "metadata": {}, + "outputs": [], + "source": [ + "TRAIN_SPLIT_RATIO = 0.8\n", + "root = Path(data_folder) / 'Kather_texture_2016_image_tiles_5000'\n", + "classes = [d.name for d in root.iterdir() if d.is_dir()]\n", + "class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}\n", + "samples = []\n", + "root = root.expanduser()\n", + "for target_class in sorted(class_to_idx.keys()):\n", + " class_index = class_to_idx[target_class]\n", + " target_dir = os.path.join(root, target_class)\n", + " for class_root, _, fnames in 
sorted(os.walk(target_dir, followlinks=True)):\n", + " for fname in sorted(fnames):\n", + " path = os.path.join(class_root, fname)\n", + " item = path, class_index\n", + " samples.append(item)\n", + "idx_range = list(range(len(samples)))\n", + "idx_sep = int(len(idx_range) * TRAIN_SPLIT_RATIO)\n", + "\n", + "train_samples = samples[:idx_sep]\n", + "test_samples = samples[idx_sep:]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57a6eee1-87ce-4f06-a87f-f28b30f82364", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "MobileNetV2 model\n", + "\"\"\"\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + " conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}\n", + " self.conv1 = nn.Conv2d(3, 16, **conv_kwargs)\n", + " self.conv2 = nn.Conv2d(16, 32, **conv_kwargs)\n", + " self.conv3 = nn.Conv2d(32, 64, **conv_kwargs)\n", + " self.conv4 = nn.Conv2d(64, 128, **conv_kwargs)\n", + " self.conv5 = nn.Conv2d(128 + 32, 256, **conv_kwargs)\n", + " self.conv6 = nn.Conv2d(256, 512, **conv_kwargs)\n", + " self.conv7 = nn.Conv2d(512 + 128 + 32, 256, **conv_kwargs)\n", + " self.conv8 = nn.Conv2d(256, 512, **conv_kwargs)\n", + " self.fc1 = nn.Linear(1184 * 9 * 9, 128)\n", + " self.fc2 = nn.Linear(128, 8)\n", + "\n", + "\n", + " def forward(self, x):\n", + " torch.manual_seed(0)\n", + " x = F.relu(self.conv1(x))\n", + " x = F.relu(self.conv2(x))\n", + " maxpool = F.max_pool2d(x, 2, 2)\n", + "\n", + " x = F.relu(self.conv3(maxpool))\n", + " x = F.relu(self.conv4(x))\n", + " concat = torch.cat([maxpool, x], dim=1)\n", + " maxpool = F.max_pool2d(concat, 2, 2)\n", + "\n", + " x = F.relu(self.conv5(maxpool))\n", + " x = F.relu(self.conv6(x))\n", + " concat = torch.cat([maxpool, x], dim=1)\n", + " maxpool = F.max_pool2d(concat, 2, 2)\n", + "\n", + " x = F.relu(self.conv7(maxpool))\n", + " x = F.relu(self.conv8(x))\n", + " concat = torch.cat([maxpool, x], dim=1)\n", + " maxpool = 
F.max_pool2d(concat, 2, 2)\n", + "\n", + " x = maxpool.flatten(start_dim=1)\n", + " x = F.dropout(self.fc1(x), p=0.5)\n", + " x = self.fc2(x)\n", + " return x\n", + "\n", + " \n", + "def inference(network, test_loader, device):\n", + " network = network.to(device)\n", + " network.eval()\n", + " \n", + " test_score = 0\n", + " # total_samples = 0\n", + " test_loss = 0\n", + "\n", + " with torch.no_grad():\n", + " for data, target in test_loader:\n", + " data, target = torch.tensor(data).to(device), \\\n", + " torch.tensor(target).to(device, dtype=torch.int64)\n", + " output = network(data)\n", + " test_loss += F.cross_entropy(output, target)\n", + " pred = output.argmax(dim=1)\n", + " test_score += pred.eq(target).sum().cpu().numpy()\n", + " test_loss /= len(test_loader.dataset)\n", + " print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", + " test_loss, test_score, len(test_loader.dataset),\n", + " 100. * test_score / len(test_loader.dataset)))\n", + " accuracy = float(test_score / len(test_loader.dataset))\n", + " return accuracy" + ] + }, + { + "cell_type": "markdown", + "id": "cd268911", + "metadata": {}, + "source": [ + "Next we import the `FLSpec`, `LocalRuntime`, and placement decorators.\n", + "\n", + "- `FLSpec` – Defines the flow specification. User defined flows are subclasses of this.\n", + "- `Runtime` – Defines where the flow runs, infrastructure for task transitions (how information gets sent). The `LocalRuntime` runs the flow on a single node.\n", + "- `aggregator/collaborator` - placement decorators that define where the task will be assigned\n", + "\n", + "In addition to these, we also import the `FedCurv` class along with the `fedcurv_weighted_average` aggregation function." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "precise-studio", + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "from openfl.experimental.interface import FLSpec, Aggregator, Collaborator\n", + "from openfl.experimental.runtime import LocalRuntime\n", + "from openfl.experimental.placement import aggregator, collaborator\n", + "\n", + "from openfl.experimental.interface.aggregation_functions.fedcurv_weighted_average import fedcurv_weighted_average\n", + "from openfl.experimental.utilities.fedcurv import FedCurv" + ] + }, + { + "attachments": { + "image.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "id": "8e406db6", + "metadata": {}, + "source": [ + "Now we come to the flow definition. The OpenFL Workflow Interface adopts the conventions set by Metaflow, that every workflow begins with `start` and concludes with the `end` task. The aggregator begins with an optionally passed in model and optimizer. The aggregator begins the flow with the `start` task, where the list of collaborators is extracted from the runtime (`self.collaborators = self.runtime.collaborators`) and is then used as the list of participants to run the task listed in `self.next`, `aggregated_model_validation`. The model, optimizer, and anything that is not explicitly excluded from the next function will be passed from the `start` function on the aggregator to the `aggregated_model_validation` task on the collaborator. Where the tasks run is determined by the placement decorator that precedes each task definition (`@aggregator` or `@collaborator`). Once each of the collaborators (defined in the runtime) complete the `aggregated_model_validation` task, they pass their current state onto the `train` task, from `train` to `local_model_validation`, and then finally to `join` at the aggregator. 
It is in `join` that an average is taken of the model weights, and the next round can begin.\n", + "\n", + "![image.png](attachment:image.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "difficult-madrid", + "metadata": {}, + "outputs": [], + "source": [ + "class FederatedFlow(FLSpec):\n", + "\n", + " def __init__(self, model = None, optimizer = None, total_rounds = 10, top_model_accuracy=0, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.model = model\n", + " self.optimizer = optimizer\n", + " self.total_rounds = total_rounds\n", + " self.top_model_accuracy = top_model_accuracy\n", + " self.round = 0\n", + " self.agg_method = FedCurv(self.model, importance=1e7)\n", + " self.device = 'cpu'\n", + " if torch.cuda.is_available():\n", + " self.device = 'cuda:0'\n", + "\n", + " @aggregator\n", + " def start(self):\n", + " print(f'Performing initialization for model')\n", + " print(20*\"#\")\n", + " print(f\"Round {self.round}\")\n", + " print(20*\"#\")\n", + " self.collaborators = self.runtime.collaborators\n", + " self.private = 10\n", + " self.next(self.aggregated_model_validation,foreach='collaborators',exclude=['private'])\n", + "\n", + " @collaborator\n", + " def aggregated_model_validation(self):\n", + " print(f'Performing aggregated model validation for collaborator {self.input}')\n", + " self.agg_validation_score = inference(self.model,self.test_loader, self.device)\n", + " print(f'{self.input} value of {self.agg_validation_score}')\n", + " self.next(self.train)\n", + "\n", + " @collaborator\n", + " def train(self):\n", + " self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)\n", + " self.agg_method.on_train_begin(self.model)\n", + " self.model.train()\n", + " train_losses = []\n", + " for batch_idx, (data, target) in enumerate(self.train_loader):\n", + " data, target = torch.tensor(data).to(self.device), torch.tensor(\n", + " target).to(self.device) \n", + " self.optimizer.zero_grad()\n", + " output = 
self.model(data)\n", + " loss = F.cross_entropy(output, target) + self.agg_method.get_penalty(self.model, self.device)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " if batch_idx % log_interval == 0:\n", + " print('Train Epoch: 1 [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " batch_idx * len(data), len(self.train_loader.dataset),\n", + " 100. * batch_idx / len(self.train_loader), loss.item()))\n", + " train_losses.append(loss.item())\n", + " torch.save(self.model.state_dict(), 'model.pth')\n", + " torch.save(self.optimizer.state_dict(), 'optimizer.pth')\n", + " self.loss = np.mean(train_losses)\n", + " print(\"Train loss\", self.loss)\n", + " self.agg_method.on_train_end(self.model, self.train_loader, self.device, 'cross_entropy')\n", + " self.training_completed = True\n", + " self.next(self.local_model_validation)\n", + "\n", + " @collaborator\n", + " def local_model_validation(self):\n", + " self.local_validation_score = inference(self.model,self.test_loader, self.device)\n", + " print(f'Performing local model validation for collaborator {self.input}: {self.local_validation_score}')\n", + " self.next(self.join, exclude=['training_completed'])\n", + "\n", + " @aggregator\n", + " def join(self,inputs):\n", + " self.average_loss = sum(input.loss for input in inputs)/len(inputs)\n", + " self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs)/len(inputs)\n", + " self.local_model_accuracy = sum(input.local_validation_score for input in inputs)/len(inputs)\n", + " print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n", + " print(f'Average training loss = {self.average_loss}')\n", + " print(f'Average local model validation values = {self.local_model_accuracy}')\n", + " fedcurv_model_dict = fedcurv_weighted_average([input.model.state_dict() for input in inputs], [collaborators_weights_dict[col] for col in collaborators])\n", + " self.model.load_state_dict(fedcurv_model_dict)\n", + " 
self.next(self.check_round_completion)\n", + " \n", + " @aggregator\n", + " def check_round_completion(self):\n", + " if self.round != self.total_rounds:\n", + " if self.aggregated_model_accuracy > self.top_model_accuracy:\n", + " print(f'Accuracy improved to {self.aggregated_model_accuracy} for round {self.round}')\n", + " self.top_model_accuracy = self.aggregated_model_accuracy\n", + " \n", + " self.round += 1\n", + " print(20*\"#\")\n", + " print(f\"Round {self.round}\")\n", + " print(20*\"#\")\n", + " self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])\n", + " else:\n", + " self.next(self.end)\n", + "\n", + " @aggregator\n", + " def end(self):\n", + " print(f'This is the end of the flow')" + ] + }, + { + "cell_type": "markdown", + "id": "2aabf61e", + "metadata": {}, + "source": [ + "You'll notice in the `FederatedFlow` definition above that there were certain attributes that the flow was not initialized with, namely the `train_loader` and `test_loader` for each of the collaborators. These are **private_attributes** that are exposed only through the runtime. Each participant has its own set of private attributes: a dictionary where the key is the attribute name, and the value is the object that will be made accessible through that participant's task. \n", + "\n", + "Below, we segment shards of the Histology dataset for **four collaborators**: Portland, Seattle, Chandler, and Bangalore. Each has their own slice of the dataset that's accessible via the `train_loader` or `test_loader` attribute. Note that the private attributes are flexible, and you can choose to pass in a completely different type of object to any of the collaborators or aggregator (with an arbitrary name). These private attributes will always be filtered out of the current state when transferring from collaborator to aggregator, or vice versa. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c5a22ee-f422-423f-933e-3e961a66b9cd", + "metadata": {}, + "outputs": [], + "source": [ + "import torchvision\n", + "from torchvision import transforms as T\n", + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "normalize = T.Normalize(mean=[0.485, 0.456, 0.406],\n", + " std=[0.229, 0.224, 0.225])\n", + "\n", + "augmentation = T.RandomApply(\n", + " [T.RandomHorizontalFlip(),\n", + " T.RandomRotation(10),\n", + " T.RandomResizedCrop(64)], \n", + " p=.8\n", + ")\n", + "\n", + "training_transform = T.ToTensor()\n", + "test_transform = T.ToTensor()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28367afa-10c0-49af-9c1a-13171f4be159", + "metadata": {}, + "outputs": [], + "source": [ + "class TransformedDataset(Dataset):\n", + " \"\"\"Image Person ReID Dataset.\"\"\"\n", + "\n", + " def __init__(self, dataset, transform=None, target_transform=None):\n", + " \"\"\"Initialize Dataset.\"\"\"\n", + " self.dataset = dataset\n", + " self.transform = transform\n", + " self.target_transform = target_transform\n", + "\n", + " def __len__(self):\n", + " \"\"\"Length of dataset.\"\"\"\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, index):\n", + " path, label = self.dataset[index]\n", + " with open(path, 'rb') as f:\n", + " img = Image.open(f)\n", + " img = img.convert('RGB')\n", + " label = self.target_transform(label) if self.target_transform else label\n", + " img = self.transform(img) if self.transform else img\n", + " return img, label" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "forward-world", + "metadata": {}, + "outputs": [], + "source": [ + "# Setup participants\n", + "aggregator = Aggregator()\n", + "aggregator.private_attributes = {}\n", + "\n", + "# Setup collaborators with private attributes\n", + "collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore']\n", + "collaborators = 
[Collaborator(name=name) for name in collaborator_names]\n", + "# Keep a list of collaborator weights. The weights are decided by the number of samples for each collaborator\n", + "collaborators_weights_dict = {}\n", + "\n", + "\n", + "for idx, collaborator in enumerate(collaborators):\n", + " local_train = deepcopy(train_samples)\n", + " local_test = deepcopy(test_samples)\n", + " local_train = local_train[idx::len(collaborators)]\n", + " local_test = local_test[idx::len(collaborators)]\n", + " local_train = TransformedDataset(\n", + " local_train,\n", + " transform=training_transform\n", + " )\n", + " local_test = TransformedDataset(\n", + " local_test,\n", + " transform=test_transform\n", + " )\n", + " collaborator.private_attributes = {\n", + " 'train_loader': DataLoader(local_train,batch_size=batch_size_train, shuffle=True),\n", + " 'test_loader': DataLoader(local_test,batch_size=batch_size_train, shuffle=True)\n", + " }\n", + " collaborators_weights_dict[collaborator] = len(local_train)\n", + "\n", + "for col in collaborators_weights_dict:\n", + " collaborators_weights_dict[col] /= len(train_samples)\n", + "\n", + "if len(collaborators_weights_dict) != 0:\n", + " assert np.abs(1.0 - sum(collaborators_weights_dict.values())) < 0.01, (\n", + " f'Collaborator weights do not sum to 1.0: {collaborators_weights_dict}'\n", + " )\n", + "\n", + "local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators)\n", + "print(f'Local runtime collaborators = {local_runtime.collaborators}')" + ] + }, + { + "cell_type": "markdown", + "id": "278ad46b", + "metadata": {}, + "source": [ + "Now that we have our flow and runtime defined, let's run the experiment! 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16937a65", + "metadata": {}, + "outputs": [], + "source": [ + "model = Net()\n", + "best_model = Net()\n", + "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n", + "\n", + "top_model_accuracy = 0\n", + "total_rounds = 2\n", + "\n", + "flflow = FederatedFlow(model=model,\n", + " optimizer=optimizer,\n", + " total_rounds=total_rounds,\n", + " top_model_accuracy=top_model_accuracy)\n", + "\n", + "flflow.runtime = local_runtime\n", + "flflow.run()" + ] + }, + { + "cell_type": "markdown", + "id": "c32e0844", + "metadata": {}, + "source": [ + "Now that the flow has completed, let's get the final model and accuracy:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "863761fe", + "metadata": {}, + "outputs": [], + "source": [ + "print(f'Sample of the final model weights: {flflow.model.state_dict()[\"conv1.weight\"][0]}')\n", + "\n", + "print(f'\\nFinal aggregated model accuracy for {flflow.total_rounds} rounds of training: {flflow.aggregated_model_accuracy}')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/openfl-tutorials/experimental/Workflow_Interface_104_MNIST_with_fedcurv.ipynb b/openfl-tutorials/experimental/Workflow_Interface_104_MNIST_with_fedcurv.ipynb new file mode 100644 index 0000000000..8478da732d --- /dev/null +++ b/openfl-tutorials/experimental/Workflow_Interface_104_MNIST_with_fedcurv.ipynb @@ -0,0 +1,411 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "14821d97", + "metadata": {}, + "source": [ + "# Workflow Interface 104: MNIST with Fedcurv 
implementation\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/openfl/blob/develop/openfl-tutorials/experimental/Workflow_Interface_104_MNIST_with_fedcurv.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "a7989e72", + "metadata": {}, + "source": [ + "In this OpenFL workflow interface tutorial, we'll learn how to implement FedCurv aggregation algorithm using MNIST dataset." + ] + }, + { + "cell_type": "markdown", + "id": "fc8e35da", + "metadata": {}, + "source": [ + "# Getting Started" + ] + }, + { + "cell_type": "markdown", + "id": "4dbb89b6", + "metadata": {}, + "source": [ + "First we start by installing the necessary dependencies for the workflow interface" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7f98600", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install git+https://github.com/intel/openfl.git\n", + "# !pip install -r https://raw.githubusercontent.com/intel/openfl/develop/openfl-tutorials/experimental/requirements_workflow_interface.txt\n", + "\n", + "# Uncomment this if running in Google Colab\n", + "#import os\n", + "#os.environ[\"USERNAME\"] = \"colab\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e85e030", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from copy import deepcopy\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import torch\n", + "import torchvision\n", + "import numpy as np\n", + "\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "batch_size_train = 64\n", + "batch_size_test = 1000\n", + "learning_rate = 0.01\n", + "momentum = 0.5\n", + "log_interval = 10\n", + "\n", + "mnist_train = torchvision.datasets.MNIST('files/', train=True, download=True,\n", + " transform=torchvision.transforms.Compose([\n", + " torchvision.transforms.ToTensor(),\n", + " 
torchvision.transforms.Normalize(\n", + " (0.1307,), (0.3081,))\n", + " ]))\n", + "\n", + "mnist_test = torchvision.datasets.MNIST('files/', train=False, download=True,\n", + " transform=torchvision.transforms.Compose([\n", + " torchvision.transforms.ToTensor(),\n", + " torchvision.transforms.Normalize(\n", + " (0.1307,), (0.3081,))\n", + " ]))\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", + " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", + " self.conv2_drop = nn.Dropout2d()\n", + " self.fc1 = nn.Linear(320, 50)\n", + " self.fc2 = nn.Linear(50, 10)\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", + " x = x.view(-1, 320)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x)\n", + " \n", + "def inference(network, test_loader, device):\n", + " network = network.to(device)\n", + " network.eval()\n", + " test_loss = 0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for data, target in test_loader:\n", + " data = data.to(device)\n", + " target = target.to(device)\n", + " output = network(data)\n", + " test_loss += F.nll_loss(output, target, size_average=False).item()\n", + " pred = output.data.max(1, keepdim=True)[1]\n", + " correct += pred.eq(target.data.view_as(pred)).sum()\n", + " test_loss /= len(test_loader.dataset)\n", + " print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", + " test_loss, correct, len(test_loader.dataset),\n", + " 100. 
* correct / len(test_loader.dataset)))\n", + " accuracy = float(correct / len(test_loader.dataset))\n", + " return accuracy" + ] + }, + { + "cell_type": "markdown", + "id": "cd268911", + "metadata": {}, + "source": [ + "Next we import the `FLSpec`, `LocalRuntime`, and placement decorators.\n", + "\n", + "- `FLSpec` – Defines the flow specification. User defined flows are subclasses of this.\n", + "- `Runtime` – Defines where the flow runs, infrastructure for task transitions (how information gets sent). The `LocalRuntime` runs the flow on a single node.\n", + "- `aggregator/collaborator` - placement decorators that define where the task will be assigned\n", + "\n", + "In addition to these, we also import the `FedCurv` class along with the `fedcurv_weighted_average` aggregation function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "precise-studio", + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "from openfl.experimental.interface import FLSpec, Aggregator, Collaborator\n", + "from openfl.experimental.runtime import LocalRuntime\n", + "from openfl.experimental.placement import aggregator, collaborator\n", + "\n", + "from openfl.experimental.interface.aggregation_functions.fedcurv_weighted_average import fedcurv_weighted_average\n", + "from openfl.experimental.utilities.fedcurv import FedCurv" + ] + }, + { + "attachments": { + "image.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "id": "8e406db6", + "metadata": {}, + "source": [ + "Now we come to the flow definition. The OpenFL Workflow Interface adopts the conventions set by Metaflow, that every workflow begins with `start` and concludes with the `end` task. The aggregator begins with an optionally passed in model and optimizer. 
The aggregator begins the flow with the `start` task, where the list of collaborators is extracted from the runtime (`self.collaborators = self.runtime.collaborators`) and is then used as the list of participants to run the task listed in `self.next`, `aggregated_model_validation`. The model, optimizer, and anything that is not explicitly excluded from the next function will be passed from the `start` function on the aggregator to the `aggregated_model_validation` task on the collaborator. Where the tasks run is determined by the placement decorator that precedes each task definition (`@aggregator` or `@collaborator`). Once each of the collaborators (defined in the runtime) complete the `aggregated_model_validation` task, they pass their current state onto the `train` task, from `train` to `local_model_validation`, and then finally to `join` at the aggregator. It is in `join` that an average is taken of the model weights, and the next round can begin.\n", + "\n", + "![image.png](attachment:image.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "difficult-madrid", + "metadata": {}, + "outputs": [], + "source": [ + "class FederatedFlow(FLSpec):\n", + "\n", + " def __init__(self, model = None, optimizer = None, total_rounds = 10, top_model_accuracy=0, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.model = model\n", + " self.optimizer = optimizer\n", + " self.total_rounds = total_rounds\n", + " self.top_model_accuracy = top_model_accuracy\n", + " self.round = 0\n", + " self.agg_method = FedCurv(self.model, importance=1e4)\n", + " self.device = 'cpu'\n", + " if torch.cuda.is_available():\n", + " self.device = 'cuda:0'\n", + "\n", + " @aggregator\n", + " def start(self):\n", + " print(f'Performing initialization for model')\n", + " print(20*\"#\")\n", + " print(f\"Round {self.round}\")\n", + " print(20*\"#\")\n", + " self.collaborators = self.runtime.collaborators\n", + " self.private = 10\n", + " 
self.next(self.aggregated_model_validation,foreach='collaborators',exclude=['private'])\n", + "\n", + " @collaborator\n", + " def aggregated_model_validation(self):\n", + " print(f'Performing aggregated model validation for collaborator {self.input}')\n", + " self.agg_validation_score = inference(self.model,self.test_loader, self.device)\n", + " print(f'{self.input} value of {self.agg_validation_score}')\n", + " self.next(self.train)\n", + "\n", + " @collaborator\n", + " def train(self):\n", + " self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " self.agg_method.on_train_begin(self.model)\n", + " self.model.train()\n", + " train_losses = []\n", + " for batch_idx, (data, target) in enumerate(self.train_loader):\n", + " data = data.to(self.device)\n", + " target = target.to(self.device)\n", + " self.optimizer.zero_grad()\n", + " output = self.model(data)\n", + " loss = F.nll_loss(output, target) + self.agg_method.get_penalty(self.model, self.device)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " if batch_idx % log_interval == 0:\n", + " print('Train Epoch: 1 [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " batch_idx * len(data), len(self.train_loader.dataset),\n", + " 100. 
* batch_idx / len(self.train_loader), loss.item()))\n", + " train_losses.append(loss.item())\n", + " torch.save(self.model.state_dict(), 'model.pth')\n", + " torch.save(self.optimizer.state_dict(), 'optimizer.pth')\n", + " self.loss = np.mean(train_losses)\n", + " self.agg_method.on_train_end(self.model, self.train_loader, self.device, 'nll')\n", + " self.training_completed = True\n", + " self.next(self.local_model_validation)\n", + "\n", + " @collaborator\n", + " def local_model_validation(self):\n", + " self.local_validation_score = inference(self.model,self.test_loader, self.device)\n", + " print(f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')\n", + " self.next(self.join, exclude=['training_completed'])\n", + "\n", + " @aggregator\n", + " def join(self,inputs):\n", + " self.average_loss = sum(input.loss for input in inputs)/len(inputs)\n", + " self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs)/len(inputs)\n", + " self.local_model_accuracy = sum(input.local_validation_score for input in inputs)/len(inputs)\n", + " print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n", + " print(f'Average training loss = {self.average_loss}')\n", + " print(f'Average local model validation values = {self.local_model_accuracy}')\n", + " fedcurv_model_dict = fedcurv_weighted_average([input.model.state_dict() for input in inputs], [collaborators_weights_dict[col] for col in collaborators])\n", + " self.model.load_state_dict(fedcurv_model_dict)\n", + " self.next(self.check_round_completion)\n", + " \n", + " @aggregator\n", + " def check_round_completion(self):\n", + " if self.round != self.total_rounds:\n", + " if self.aggregated_model_accuracy > self.top_model_accuracy:\n", + " print(f'Accuracy improved to {self.aggregated_model_accuracy} for round {self.round}')\n", + " self.top_model_accuracy = self.aggregated_model_accuracy\n", + " \n", + " self.round += 1\n", + " 
print(20*\"#\")\n", + " print(f\"Round {self.round}\")\n", + " print(20*\"#\")\n", + " self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])\n", + " else:\n", + " self.next(self.end)\n", + "\n", + " @aggregator\n", + " def end(self):\n", + " print(f'This is the end of the flow')" + ] + }, + { + "cell_type": "markdown", + "id": "2aabf61e", + "metadata": {}, + "source": [ + "You'll notice in the `FederatedFlow` definition above that there were certain attributes that the flow was not initialized with, namely the `train_loader` and `test_loader` for each of the collaborators. These are **private_attributes** that are exposed only through the runtime. Each participant has its own set of private attributes: a dictionary where the key is the attribute name, and the value is the object that will be made accessible through that participant's task. \n", + "\n", + "Below, we segment shards of the MNIST dataset for **four collaborators**: Portland, Seattle, Chandler, and Bangalore. Each has their own slice of the dataset that's accessible via the `train_loader` or `test_loader` attribute. Note that the private attributes are flexible, and you can choose to pass in a completely different type of object to any of the collaborators or aggregator (with an arbitrary name). These private attributes will always be filtered out of the current state when transferring from collaborator to aggregator, or vice versa. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "forward-world", + "metadata": {}, + "outputs": [], + "source": [ + "# Setup participants\n", + "aggregator = Aggregator()\n", + "aggregator.private_attributes = {}\n", + "\n", + "# Setup collaborators with private attributes\n", + "collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore']\n", + "collaborators = [Collaborator(name=name) for name in collaborator_names]\n", + "# Keep a list of collaborator weights. 
The weights are decided by the number of samples for each collaborator\n", + "collaborators_weights_dict = {}\n", + "for idx, collaborator in enumerate(collaborators):\n", + " local_train = deepcopy(mnist_train)\n", + " local_test = deepcopy(mnist_test)\n", + " local_train.data = mnist_train.data[idx::len(collaborators)]\n", + " local_train.targets = mnist_train.targets[idx::len(collaborators)]\n", + " local_test.data = mnist_test.data[idx::len(collaborators)]\n", + " local_test.targets = mnist_test.targets[idx::len(collaborators)]\n", + " collaborator.private_attributes = {\n", + " 'train_loader': torch.utils.data.DataLoader(local_train,batch_size=batch_size_train, shuffle=True),\n", + " 'test_loader': torch.utils.data.DataLoader(local_test,batch_size=batch_size_train, shuffle=True)\n", + " }\n", + " collaborators_weights_dict[collaborator] = len(local_train.data)\n", + "\n", + "for col in collaborators_weights_dict:\n", + " collaborators_weights_dict[col] /= len(mnist_train.data)\n", + "\n", + "if len(collaborators_weights_dict) != 0:\n", + " assert np.abs(1.0 - sum(collaborators_weights_dict.values())) < 0.01, (\n", + " f'Collaborator weights do not sum to 1.0: {collaborators_weights_dict}'\n", + " )\n", + "\n", + "local_runtime = LocalRuntime(\n", + " aggregator=aggregator, collaborators=collaborators, backend=\"single_process\")\n", + "print(f'Local runtime collaborators = {local_runtime.collaborators}')" + ] + }, + { + "cell_type": "markdown", + "id": "278ad46b", + "metadata": {}, + "source": [ + "Now that we have our flow and runtime defined, let's run the experiment! 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16937a65", + "metadata": {}, + "outputs": [], + "source": [ + "model = Net()\n", + "best_model = Net()\n", + "optimizer = optim.SGD(model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + "top_model_accuracy = 0\n", + "total_rounds = 5\n", + "\n", + "flflow = FederatedFlow(model=model,\n", + " optimizer=optimizer,\n", + " total_rounds=total_rounds,\n", + " top_model_accuracy=top_model_accuracy)\n", + "\n", + "flflow.runtime = local_runtime\n", + "flflow.run()\n" + ] + }, + { + "cell_type": "markdown", + "id": "c32e0844", + "metadata": {}, + "source": [ + "Now that the flow has completed, let's get the final model and accuracy:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "863761fe", + "metadata": {}, + "outputs": [], + "source": [ + "print(f'Sample of the final model weights: {flflow.model.state_dict()[\"conv1.weight\"][0]}')\n", + "\n", + "print(f'\\nFinal aggregated model accuracy for {flflow.total_rounds} rounds of training: {flflow.aggregated_model_accuracy}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "712ca4fb-a31c-4420-b8ba-0a5114e3fe96", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/openfl-tutorials/experimental/Workflow_Interface_104_Synthetic_data_with_fedcurv.ipynb b/openfl-tutorials/experimental/Workflow_Interface_104_Synthetic_data_with_fedcurv.ipynb new file mode 100644 index 0000000000..8b38eadd0c --- /dev/null +++ 
b/openfl-tutorials/experimental/Workflow_Interface_104_Synthetic_data_with_fedcurv.ipynb @@ -0,0 +1,631 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "14821d97", + "metadata": {}, + "source": [ + "# Workflow Interface 104: Synthetic Data with Fedcurv implementation\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/openfl/blob/develop/openfl-tutorials/experimental/Workflow_Interface_104_Synthetic_data_with_fedcurv.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "a7989e72", + "metadata": {}, + "source": [ + "In this OpenFL workflow interface tutorial, we'll learn how to implement FedCurv aggregation algorithm using Synthetic dataset. For more information on comparison amongst various aggregation algorithms, visit the [FedProx tutorial]." + ] + }, + { + "cell_type": "markdown", + "id": "fc8e35da", + "metadata": {}, + "source": [ + "# Getting Started" + ] + }, + { + "cell_type": "markdown", + "id": "4dbb89b6", + "metadata": {}, + "source": [ + "First we start by installing the necessary dependencies for the workflow interface" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7f98600", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install git+https://github.com/intel/openfl.git\n", + "# !pip install -r https://raw.githubusercontent.com/intel/openfl/develop/openfl-tutorials/experimental/requirements_workflow_interface.txt\n", + "\n", + "# Uncomment this if running in Google Colab\n", + "#import os\n", + "#os.environ[\"USERNAME\"] = \"colab\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac08226b-6127-4387-9c29-3becb93bdbc2", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.utils.data as data\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import numpy as np\n", + "\n", + "import random\n", + "import collections\n", + "\n", + "import 
def softmax(x):
    """Return the softmax of ``x``, normalised over ALL of its elements.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but keeps ``np.exp`` from overflowing (and the
    result from turning into NaN) for large inputs.

    Args:
        x: Array-like of scores.

    Returns:
        numpy.ndarray with the same shape as ``x`` whose entries sum to 1.
        An empty input yields an empty array.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(x)
    ex = np.exp(x - np.max(x))
    return ex / np.sum(ex)
"\n", + " for i in range(NUM_USER):\n", + " if iid == 1:\n", + " mean_x[i] = np.ones(dimension) * B[i] # all zeros\n", + " else:\n", + " mean_x[i] = np.random.normal(B[i], 1, dimension)\n", + "\n", + " if iid == 1:\n", + " W_global = np.random.normal(0, 1, (dimension, NUM_CLASS))\n", + " b_global = np.random.normal(0, 1, NUM_CLASS)\n", + "\n", + " for i in range(NUM_USER):\n", + "\n", + " W = np.random.normal(mean_W[i], 1, (dimension, NUM_CLASS))\n", + " b = np.random.normal(mean_b[i], 1, NUM_CLASS)\n", + "\n", + " if iid == 1:\n", + " W = W_global\n", + " b = b_global\n", + "\n", + " xx = np.random.multivariate_normal(\n", + " mean_x[i], cov_x, samples_per_user[i])\n", + " yy = np.zeros(samples_per_user[i])\n", + "\n", + " for j in range(samples_per_user[i]):\n", + " tmp = np.dot(xx[j], W) + b\n", + " yy[j] = np.argmax(softmax(tmp))\n", + "\n", + " X_split[i] = xx.tolist()\n", + " y_split[i] = yy.tolist()\n", + "\n", + " return X_split, y_split\n", + "\n", + "\n", + "class SyntheticFederatedDataset:\n", + " def __init__(self, num_collaborators, batch_size=1, num_classes=10, **kwargs):\n", + " self.batch_size = batch_size\n", + " X, y = generate_synthetic(0.0, 0.0, 0, num_collaborators, num_classes)\n", + " X = [np.array([np.array(sample).astype(np.float32)\n", + " for sample in col]) for col in X]\n", + " y = [np.array([np.array(one_hot(int(sample), num_classes))\n", + " for sample in col]) for col in y]\n", + " self.X_train_all = np.array([col[:int(0.9 * len(col))] for col in X], dtype=np.ndarray)\n", + " self.X_valid_all = np.array([col[int(0.9 * len(col)):] for col in X], dtype=np.ndarray)\n", + " self.y_train_all = np.array([col[:int(0.9 * len(col))] for col in y], dtype=np.ndarray)\n", + " self.y_valid_all = np.array([col[int(0.9 * len(col)):] for col in y], dtype=np.ndarray)\n", + "\n", + " def split(self, collaborators):\n", + " for i, collaborator in enumerate(collaborators):\n", + " collaborator.private_attributes = {\n", + " \"train_loader\":\n", + " 
def cross_entropy(output, target, size_average=None):
    """Multi-class cross-entropy between logits and one-hot style targets.

    The class index of every sample is taken as the position of the largest
    entry in its ``target`` row, then passed to ``F.cross_entropy``.

    Args:
        output: Logits; one row per sample.
        target: One row per sample; the arg-max of each row is the label.
        size_average: Legacy switch kept for backward compatibility.
            ``None`` or a truthy value averages the per-sample losses,
            a falsy value sums them.

    Returns:
        Scalar tensor holding the reduced loss.
    """
    class_indices = torch.max(target, 1)[1]
    # ``size_average`` is deprecated in PyTorch; translate it into the
    # equivalent ``reduction`` mode instead of forwarding it.
    if size_average is None or size_average:
        reduction = 'mean'
    else:
        reduction = 'sum'
    return F.cross_entropy(output, class_indices, reduction=reduction)
target.argmax(dim=1, keepdim=True)\n", + " pred = output.argmax(dim=1, keepdim=True)\n", + " correct += pred.eq(tar).sum().cpu().numpy()\n", + " dataloader_size = len(dataloader.dataset)\n", + " test_loss /= dataloader_size\n", + " accuracy = float(correct / dataloader_size)\n", + " return accuracy, test_loss, correct" + ] + }, + { + "cell_type": "markdown", + "id": "cd268911", + "metadata": {}, + "source": [ + "Next we import the `FLSpec`, `LocalRuntime`, and placement decorators.\n", + "\n", + "- `FLSpec` – Defines the flow specification. User defined flows are subclasses of this.\n", + "- `Runtime` – Defines where the flow runs, infrastructure for task transitions (how information gets sent). The `LocalRuntime` runs the flow on a single node.\n", + "- `aggregator/collaborator` - placement decorators that define where the task will be assigned\n", + "\n", + "In addition to these, we also import `FedCurv` module along with `FedcurvWeightedAvg` aggregation algorithm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "precise-studio", + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "from openfl.experimental.interface import FLSpec, Aggregator, Collaborator\n", + "from openfl.experimental.runtime import LocalRuntime\n", + "from openfl.experimental.placement import aggregator, collaborator\n", + "\n", + "from openfl.experimental.interface.aggregation_functions.fedcurv_weighted_average import fedcurv_weighted_average\n", + "from openfl.experimental.utilities.fedcurv import FedCurv" + ] + }, + { + "cell_type": "markdown", + "id": "3e09ee12-ce1a-43dd-8c17-0403da643a1e", + "metadata": {}, + "source": [ + "Let us now define the Workflow for our experiment. 
class FederatedFlow(FLSpec):
    """FedCurv workflow: evaluate, select collaborators, train, aggregate.

    One call to ``run()`` walks through the steps
    ``start`` -> ``compute_loss_and_accuracy`` ->
    ``gather_results_and_take_weighted_average`` -> ``select_collaborators``
    -> ``train_selected_collaborators`` -> ``join`` -> ``end``.
    """

    def __init__(self, model=None, optimizer=None, agg_method=None,
                 n_selected_collaborators=10, total_rounds=10, **kwargs):
        """Store the model, optimizer, FedCurv helper and round settings.

        Args:
            model: Model to train. When ``None`` a fresh ``Net`` is built
                together with an SGD optimizer and a ``FedCurv`` helper.
            optimizer: Optimizer stored next to a user-supplied ``model``.
            agg_method: FedCurv helper for a user-supplied ``model``; when
                omitted a default ``FedCurv(model, importance=1e4)`` is made.
            n_selected_collaborators: Collaborators trained per round.
            total_rounds: Round count after which ``end`` stops advancing
                ``round_number``.
            **kwargs: Forwarded to ``FLSpec.__init__``.
        """
        super().__init__(**kwargs)
        self.n_selected_collaborators = n_selected_collaborators
        self.total_rounds = total_rounds
        self.round_number = 0
        if model is not None:
            self.model = model
            self.optimizer = optimizer
            # Training calls ``self.agg_method.on_train_begin`` etc., so a
            # missing helper is replaced by the same default used below.
            if agg_method is None:
                agg_method = FedCurv(self.model, importance=1e4)
            self.agg_method = agg_method
        else:
            # ``optim``, ``learning_rate`` and ``momentum`` are notebook-level
            # names that must exist by the time the flow is instantiated.
            self.model = Net()
            self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                       momentum=momentum)
            self.agg_method = FedCurv(self.model, importance=1e4)
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    @aggregator
    def start(self):
        """Announce the round and fan out to every collaborator."""
        print('Performing initialization for model')
        print(20*"#")
        print(f"Round {self.round_number}")
        print(20*"#")
        self.collaborators = self.runtime.collaborators
        self.next(self.compute_loss_and_accuracy, foreach='collaborators')

    @collaborator
    def compute_loss_and_accuracy(self):
        """Evaluate the current model on the local train and test loaders.

        Stores the training accuracy/loss, the validation accuracy/loss and
        both dataset sizes on ``self`` for the aggregation step.
        """
        # Metrics of the current model on the local training data
        self.training_accuracy, self.training_loss, _ = compute_loss_and_acc(
            self.model, self.train_loader)

        # Metrics of the current model on the local test data
        self.agg_validation_score, self.agg_validation_loss, test_correct = compute_loss_and_acc(
            self.model, self.test_loader)

        self.train_dataset_length = len(self.train_loader.dataset)
        self.test_dataset_length = len(self.test_loader.dataset)

        # One placeholder per argument: the collaborator name gets its own
        # field so the remaining values land under the right labels.
        print(
            "Collab: {:<5} | Train Round: {:<5} : Train Loss {:<.6f}, Test Acc: {:<.6f} [{}/{}]".format(
                self.input,
                self.round_number,
                self.training_loss,
                self.agg_validation_score,
                test_correct,
                self.test_dataset_length
            )
        )

        self.next(self.gather_results_and_take_weighted_average)

    @aggregator
    def gather_results_and_take_weighted_average(self, inputs):
        """Combine the collaborators' metrics into weighted averages.

        Each collaborator is weighted by its share of the total number of
        train (for the loss) or test (for the accuracy) samples.

        Args:
            inputs: Per-collaborator states from ``compute_loss_and_accuracy``.
        """
        total_train = sum(input_.train_dataset_length for input_ in inputs)
        total_test = sum(input_.test_dataset_length for input_ in inputs)

        self.train_weights = [
            input_.train_dataset_length / total_train for input_ in inputs
        ]
        self.test_weights = [
            input_.test_dataset_length / total_test for input_ in inputs
        ]

        training_losses = [input_.training_loss for input_ in inputs]
        test_accuracies = [input_.agg_validation_score for input_ in inputs]

        # Weighted average of training loss
        self.aggregated_model_training_loss = fedcurv_weighted_average(
            training_losses, self.train_weights)

        # Weighted average of aggregated model accuracy
        self.aggregated_model_test_accuracy = fedcurv_weighted_average(
            test_accuracies, self.test_weights)

        print(
            " | Train Round: {:<5} : Agg Train Loss {:<.6f}, Agg Test Acc: {:<.6f}".format(
                self.round_number,
                self.aggregated_model_training_loss,
                self.aggregated_model_test_accuracy
            )
        )

        self.next(self.select_collaborators)

    @aggregator
    def select_collaborators(self):
        """Randomly pick ``n_selected_collaborators`` distinct collaborators.

        NumPy's global RNG is re-seeded with the round number, so the
        selection for a given round is repeatable.
        """
        np.random.seed(self.round_number)
        self.selected_collaborator_indices = np.random.choice(
            range(len(self.collaborators)),
            self.n_selected_collaborators,
            replace=False,
        )
        self.selected_collaborators = [
            self.collaborators[idx] for idx in self.selected_collaborator_indices
        ]

        self.next(self.train_selected_collaborators, foreach="selected_collaborators")

    @collaborator
    def train_selected_collaborators(self):
        """Train the local model for ``local_epoch`` epochs with FedCurv.

        The loss of every batch is the cross-entropy plus the FedCurv
        penalty. After training, ``on_train_end`` is called on the helper.
        """
        self.train_dataset_length = len(self.train_loader.dataset)

        # A fresh optimizer is created for every local training run.
        self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                   momentum=momentum)

        self.agg_method.on_train_begin(self.model)
        self.model = self.model.to(self.device)
        self.model.train(mode=True)

        for epoch in range(local_epoch):
            train_loss = []
            correct = 0
            for data, target in self.train_loader:
                data = data.to(self.device)
                target = target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = cross_entropy(output, target) + self.agg_method.get_penalty(self.model, self.device)
                loss.backward()
                self.optimizer.step()
                pred = output.argmax(dim=1, keepdim=True)
                tar = target.argmax(dim=1, keepdim=True)
                correct += pred.eq(tar).sum().cpu().numpy()
                train_loss.append(loss.item())
            training_accuracy = float(correct / self.train_dataset_length)
            training_loss = np.mean(train_loss)
            # One placeholder per argument (see compute_loss_and_accuracy).
            print(
                "Collab: {:<5} | Train Round: {:<5} | Local Epoch: {:<3}: FedCurv Optimization Train Loss {:<.6f}, Train Acc: {:<.6f} [{}/{}]".format(
                    self.input,
                    self.round_number,
                    epoch,
                    training_loss,
                    training_accuracy,
                    correct,
                    len(self.train_loader.dataset)
                )
            )
        self.agg_method.on_train_end(self.model, self.train_loader, self.device, 'cross_entropy')
        self.next(self.join)

    @aggregator
    def join(self, inputs):
        """Merge the trained models, weighted by local training-set size.

        Args:
            inputs: Per-collaborator states from the training step.
        """
        train_datasize = sum(input_.train_dataset_length for input_ in inputs)

        train_weights = [
            input_.train_dataset_length / train_datasize for input_ in inputs
        ]
        model_state_dict_list = [input_.model.state_dict() for input_ in inputs]

        fedcurv_model_dict = fedcurv_weighted_average(model_state_dict_list, train_weights)
        self.model.load_state_dict(fedcurv_model_dict)
        self.next(self.end)

    @aggregator
    def end(self):
        """Finish this run; advance the round counter unless on the last round."""
        if self.round_number == self.total_rounds - 1:
            print(f'This is the end of the flow')
        else:
            self.round_number += 1
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af23a974-534d-4728-96a9-e0f67f455720", + "metadata": {}, + "outputs": [], + "source": [ + "num_collaborators = 30\n", + "\n", + "# Setup aggregator\n", + "aggregator = Aggregator()\n", + "aggregator.private_attributes = {}\n", + "\n", + "# Setup collaborators with private attributes\n", + "collaborator_names = [f\"col{i}\" for i in range(num_collaborators)]\n", + "\n", + "collaborators = [Collaborator(name=name) for name in collaborator_names]\n", + "\n", + "synthetic_federated_dataset = SyntheticFederatedDataset(\n", + " batch_size=batch_size, num_classes=10, num_collaborators=len(collaborators), seed=RANDOM_SEED)\n", + "synthetic_federated_dataset.split(collaborators)\n", + "\n", + "local_runtime = LocalRuntime(\n", + " aggregator=aggregator, collaborators=collaborators, backend=\"single_process\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89358389-c025-458b-bbec-1ceef2965317", + "metadata": {}, + "outputs": [], + "source": [ + "loss_and_acc = {\n", + " \"Fedcurv\": {\n", + " \"Train Loss\": [], \"Test Accuracy\": []\n", + " },\n", + " \"FedAvg\": {\n", + " \"Train Loss\": [], \"Test Accuracy\": []\n", + " }\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "278ad46b", + "metadata": {}, + "source": [ + "Now that we have our flow and runtime defined, let's run the experiment! 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16937a65", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "\n", + "n_selected_collaborators = 10\n", + "n_epochs = 200\n", + "local_epoch = 20\n", + "total_rounds = 5\n", + "\n", + "learning_rate = 0.01\n", + "momentum = 0.5\n", + "log_interval = 10\n", + "\n", + "flflow = FederatedFlow(n_selected_collaborators=n_selected_collaborators,\n", + " total_rounds=total_rounds)\n", + "\n", + "flflow.runtime = local_runtime\n", + "for i in range(n_epochs):\n", + " flflow.run()\n", + " aggregated_model_training_loss = flflow.aggregated_model_training_loss\n", + " aggregated_model_test_accuracy = flflow.aggregated_model_test_accuracy\n", + "\n", + " loss_and_acc[\"Fedcurv\"][\"Train Loss\"].append(aggregated_model_training_loss)\n", + " loss_and_acc[\"Fedcurv\"][\"Test Accuracy\"].append(aggregated_model_test_accuracy)\n" + ] + }, + { + "cell_type": "markdown", + "id": "f1daf1fd-7bf6-49cc-9bfe-1eb6359228a3", + "metadata": {}, + "source": [ + "**Comparison of aggregation algorithms**\n", + "\n", + "Now that we have demonstrated Fedcurv on synthetic dataset, let's run through the [FedProx tutorial] to see how Fedcurv compares to FedAvg and FedProx." 
def fedcurv_weighted_average(tensors, weights):
    """
    Aggregation function of FedCurv algorithm.

    Applies weighted average aggregation to all tensors
    except Fisher matrices variables (keys ending in ``_u``, ``_v``, ``_w``).
    These variables are summed without weights.

    FedCurv paper: https://arxiv.org/pdf/1910.07796.pdf

    Args:
        tensors: List of model / optimizer state dicts, or a list of plain
            values such as losses or accuracies.
        weights: Weight for each element of ``tensors``.

    Returns:
        dict: key -> NumPy result, when ``tensors`` holds state dicts.
        Otherwise whatever ``weighted_average`` returns for the plain list.
    """
    def _to_numpy(value):
        # Detach first so tensors that track gradients convert cleanly.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    # Plain lists (losses, accuracies) go through the regular weighted average.
    if not isinstance(tensors[0], dict):
        return wa(tensors, weights)

    # NOTE(review): the aggregated values are NumPy arrays, not torch tensors;
    # callers that feed the result to ``load_state_dict`` should confirm this
    # is accepted by their PyTorch version.
    aggregated = {}
    for key in tensors[0]:
        stacked = [_to_numpy(state[key]) for state in tensors]
        if key.endswith(('_u', '_v', '_w')):
            # Diagonal Fisher terms are summed, not averaged.
            aggregated[key] = np.sum(stacked, axis=0)
        else:
            aggregated[key] = np.average(stacked, weights=weights, axis=0)
    return aggregated
- """ - module_path, _, buffer_name = target.rpartition('.') - - mod: torch.nn.Module = module.get_submodule(module_path) - - if not hasattr(mod, buffer_name): - raise AttributeError(f'{mod._get_name()} has no attribute `{buffer_name}`') - - buffer: torch.Tensor = getattr(mod, buffer_name) - - if buffer_name not in mod._buffers: - raise AttributeError('`' + buffer_name + '` is not a buffer') - - return buffer - - class FedCurv: """Federated Curvature class. @@ -80,7 +58,7 @@ def _register_fisher_parameters(self, model): def _update_params(self, model): self._params = deepcopy({n: p for n, p in model.named_parameters() if p.requires_grad}) - def _diag_fisher(self, model, data_loader, device): + def _diag_fisher(self, model, data_loader, device='cpu', loss_fn='nll'): precision_matrices = {} for n, p in self._params.items(): p.data.zero_() @@ -93,7 +71,10 @@ def _diag_fisher(self, model, data_loader, device): sample = sample.to(device) target = target.to(device) output = model(sample) - loss = F.nll_loss(F.log_softmax(output, dim=1), target) + if loss_fn == 'cross_entropy': + loss = F.cross_entropy(output, target) + else: + loss = F.nll_loss(F.log_softmax(output, dim=1), target) loss.backward() for n, p in model.named_parameters(): @@ -102,7 +83,7 @@ def _diag_fisher(self, model, data_loader, device): return precision_matrices - def get_penalty(self, model): + def get_penalty(self, model, device='cpu'): """Calculate the penalty term for the loss function. 
Args: @@ -117,11 +98,11 @@ def get_penalty(self, model): for name, param in model.named_parameters(): if param.requires_grad: u_global, v_global, w_global = ( - get_buffer(model, target).detach() + model.get_buffer(target).detach().to(device) for target in (f'{name}_u', f'{name}_v', f'{name}_w') ) u_local, v_local, w_local = ( - getattr(self, name).detach() + getattr(self, name).detach().to(device) for name in (f'{name}_u', f'{name}_v', f'{name}_w') ) u = u_global - u_local @@ -140,7 +121,7 @@ def on_train_begin(self, model): """ self._update_params(model) - def on_train_end(self, model: torch.nn.Module, data_loader, device): + def on_train_end(self, model: torch.nn.Module, data_loader, device='cpu', loss_fn='nll'): """Post-train steps. Args: @@ -149,7 +130,7 @@ def on_train_end(self, model: torch.nn.Module, data_loader, device): device(str): Model device. loss_fn(Callable): Train loss function. """ - precision_matrices = self._diag_fisher(model, data_loader, device) + precision_matrices = self._diag_fisher(model, data_loader, device, loss_fn) for n, m in precision_matrices.items(): u = m.data.to(device) v = m.data * model.get_parameter(n) diff --git a/setup.py b/setup.py index c02129324f..13244c2736 100644 --- a/setup.py +++ b/setup.py @@ -104,6 +104,7 @@ def run(self): 'openfl.databases.utilities', 'openfl.experimental', 'openfl.experimental.interface', + 'openfl.experimental.interface.aggregation_functions', 'openfl.experimental.placement', 'openfl.experimental.runtime', 'openfl.experimental.utilities',