From e0060f9fb213049d24099b3ade02cf7fb319ff44 Mon Sep 17 00:00:00 2001 From: gozdeg Date: Mon, 21 Aug 2023 18:24:52 +0100 Subject: [PATCH] FedDropoutAvg Tutorial FedDropoutAvg[*] Tutorial Using Workflow Interface On CIFAR10 [Gunesli, G. N., Bilal, M., Raza, S. E. A., & Rajpoot, N. M. (2021). Feddropoutavg: Generalizable federated learning for histopathology image classification. arXiv preprint arXiv:2111.13230.] --- ..._Using_Workflow_Interface_On_CIFAR10.ipynb | 1629 +++++++++++++++++ 1 file changed, 1629 insertions(+) create mode 100644 openfl-tutorials/experimental/FedDropoutAvg_Tutorial_Using_Workflow_Interface_On_CIFAR10.ipynb diff --git a/openfl-tutorials/experimental/FedDropoutAvg_Tutorial_Using_Workflow_Interface_On_CIFAR10.ipynb b/openfl-tutorials/experimental/FedDropoutAvg_Tutorial_Using_Workflow_Interface_On_CIFAR10.ipynb new file mode 100644 index 0000000000..c9fea4a72d --- /dev/null +++ b/openfl-tutorials/experimental/FedDropoutAvg_Tutorial_Using_Workflow_Interface_On_CIFAR10.ipynb @@ -0,0 +1,1629 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "e52ac176", + "metadata": {}, + "source": [ + "# FedDropoutAvg Tutorial using OpenFL Workflow Interface - PyTorch CIFAR10" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "b83202c1", + "metadata": {}, + "source": [ + "* This notebook provides an implementation of the __\"FedDropoutAvg\" algorithm__ __[[arXiv link]](https://arxiv.org/abs/2111.13230)__, together with the ResNet18 model with GroupNorm layers used in the paper.
\n", + "\n", + " * In a nutshell, FedDropoutAvg applies dropout mechanisms when aggregating the parameters of deep neural network models trained at different client sites into a single federated model. \n", + "\n", + " * Dropout is applied in two ways: \n", + " 1. __client selection__: random dropout of clients in each round of federated training,\n", + " 2. __federated averaging (aggregation)__: random dropout of the parameters of the locally trained models during aggregation into the federated model.\n", + " \n", + "
\n", + " \n", + " * FedDropoutAvg is designed to mitigate the effects of heterogeneity in real-world multi-institutional histopathology datasets. In this tutorial, however, we use a toy dataset (CIFAR10) and divide the data randomly between the collaborators. \n", + " \n", + "\n", + "
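The two dropout mechanisms described above can be sketched in a few lines of NumPy. This is a minimal illustrative sketch, not the notebook's actual implementation: the function names `select_clients` and `feddropoutavg`, the weighting by local dataset size, and the fallback to a plain weighted average when every client happens to be dropped for a given parameter are all assumptions made for exposition.

```python
import numpy as np

def select_clients(clients, cdr, rng):
    """Client-level dropout: each round, train only a random subset of clients.

    `cdr` is the client dropout rate; at least one client is always kept.
    """
    n_keep = max(1, int(round(len(clients) * (1.0 - cdr))))
    chosen = rng.choice(len(clients), size=n_keep, replace=False)
    return [clients[i] for i in chosen]

def feddropoutavg(client_params, client_sizes, fdr, rng):
    """Parameter-level dropout during aggregation.

    Each individual parameter of each client is dropped with probability
    `fdr`; the surviving values are averaged, weighted by local dataset
    size. Entries for which every client was dropped fall back to the
    plain weighted average so that no parameter is left undefined.
    """
    stacked = np.stack(client_params)                        # (n_clients, *shape)
    sizes = np.asarray(client_sizes, dtype=float).reshape(
        -1, *([1] * (stacked.ndim - 1)))                     # broadcastable weights
    keep = rng.random(stacked.shape) >= fdr                  # per-parameter dropout mask
    denom = (sizes * keep).sum(axis=0)
    masked_avg = (stacked * sizes * keep).sum(axis=0) / np.where(denom > 0, denom, 1.0)
    plain_avg = (stacked * sizes).sum(axis=0) / sizes.sum()  # fallback: ordinary FedAvg
    return np.where(denom > 0, masked_avg, plain_avg)
```

Setting `fdr=0` recovers ordinary weighted federated averaging and `cdr=0` keeps every client in every round, which makes both dropout rates easy to sanity-check.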
\n", + "\n", + "\n", + "* This tutorial is adapted from the OpenFL tutorial __[\"Workflow_Interface_101_MNIST.ipynb\"](https://github.com/securefederatedai/openfl/blob/develop/openfl-tutorials/experimental/Workflow_Interface_101_MNIST.ipynb)__\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "66deb5ec", + "metadata": {}, + "source": [ + "## Getting Started \n", + "First we start by installing the necessary dependencies for the workflow interface" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f7f98600", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting git+https://github.com/intel/openfl.git\n", + " Cloning https://github.com/intel/openfl.git to /tmp/pip-req-build-dv8v4swi\n", + " Running command git clone --filter=blob:none --quiet https://github.com/intel/openfl.git /tmp/pip-req-build-dv8v4swi\n", + " Resolved https://github.com/intel/openfl.git to commit ed501ebbd6ffab6d10b4347a8c34369564a373b2\n", + " Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", + "\u001b[?25h[... pip dependency-resolution output trimmed: all requirements of openfl==1.5 and requirements_workflow_interface.txt were already satisfied in this environment ...]\n" + ] + } + ], + "source": [ + "!pip install git+https://github.com/intel/openfl.git\n", + "!pip install -r requirements_workflow_interface.txt\n", + "\n", + "# Uncomment this if running in Google Colab\n", + "#!pip install -r https://raw.githubusercontent.com/intel/openfl/develop/openfl-tutorials/experimental/requirements_workflow_interface.txt\n", + "#import os\n", + "#os.environ[\"USERNAME\"] = \"colab\"" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "7fec81df", + "metadata": {}, + "source": [ + "## Defining our dataloaders, model, optimizer, some helper functions, and the _`cdr` (client dropout rate)_ and _`fdr` (federated dropout rate)_ parameters which will be used for the FedDropoutAvg." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7e85e030", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/u2058145/anaconda3/envs/tiatoolbox-dev/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import torchvision\n", + "from torchvision import models\n", + "import numpy as np\n", + "\n", + "n_rounds = 2 # number of rounds\n", + "batch_size_train = 256\n", + "batch_size_test = 1000\n", + "learning_rate = 0.01\n", + "momentum = 0.5\n", + "log_interval = 1\n", + "\n", + "# FedDropoutAvg parameters; if fdr == 0 and cdr == 0, this is equivalent to FedAvg\n", + "fdr = 0.3 # federated dropout rate\n", + "cdr = 0.2 # client dropout rate\n", + "\n", + "random_seed = 1\n", + "torch.backends.cudnn.enabled = True\n", + "torch.manual_seed(random_seed)\n", + "\n", + "\n", + "transforms_train = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),\n", + "                                                   torchvision.transforms.ToTensor(),\n", + "                                                   torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", + "                                                   ])\n", + "\n", + "transforms_test = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),\n", + "                                                  torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", + "                                                  ])\n", + "\n", + "\n", + "cifar10_train = torchvision.datasets.CIFAR10('files/', train=True, download=True, transform=transforms_train)\n", + "cifar10_test = torchvision.datasets.CIFAR10('files/', train=False, download=True, transform=transforms_test)\n", + "\n", + "\n", + "class GroupNorm32(nn.GroupNorm):\n", + "    def __init__(self, num_channels, num_groups=32, **kwargs):\n", + "        super().__init__(num_groups, num_channels, **kwargs)\n", + "    \n", + "\n", + "class ResNet18(nn.Module):\n", + "    def 
__init__(self, norm = 'gn'):\n", + " # Default norm: norm layer type is GroupNorm32. If norm == 'bn', BatchNorm2d will be used - not performing well with FL\n", + " super(ResNet18, self).__init__()\n", + " \n", + " if norm == 'gn':\n", + " norm_layer = GroupNorm32\n", + " elif norm == 'bn':\n", + " norm_layer = nn.BatchNorm2d\n", + "\n", + " self.model_ft = models.resnet18(pretrained = False, norm_layer = norm_layer, num_classes = 10)\n", + " \n", + " self.model_ft = nn.Sequential(self.model_ft)\n", + "\n", + " def forward(self, x):\n", + " x = self.model_ft(x)\n", + " return F.log_softmax(x) #x \n", + "\n", + "\n", + "\n", + " \n", + "def inference(network,test_loader):\n", + " network.eval()\n", + " test_loss = 0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for data, target in test_loader:\n", + " output = network(data)\n", + " test_loss += F.nll_loss(output, target, size_average=False).item()\n", + " pred = output.data.max(1, keepdim=True)[1]\n", + " correct += pred.eq(target.data.view_as(pred)).sum()\n", + " test_loss /= len(test_loader.dataset)\n", + " print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", + " test_loss, correct, len(test_loader.dataset),\n", + " 100. 
* correct / len(test_loader.dataset)))\n", + " accuracy = float(correct / len(test_loader.dataset))\n", + " return accuracy" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "bf5d38ef", + "metadata": {}, + "source": [ + "## Implementation of the FedDropoutAvg class" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "precise-studio", + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "#from copy import deepcopy\n", + "\n", + "\n", + "# FedAvg Algo from the original tutorial\n", + "def FedAvg(models, weights=None):\n", + " new_model = models[0]\n", + " state_dicts = [model.state_dict() for model in models]\n", + " state_dict = new_model.state_dict()\n", + " for key in models[1].state_dict():\n", + " state_dict[key] = torch.from_numpy(np.average([state[key].numpy() for state in state_dicts],\n", + " axis=0, \n", + " weights=weights))\n", + " new_model.load_state_dict(state_dict)\n", + " return new_model\n", + "\n", + "\n", + "\n", + "# FedDropoutAvg class, implementing random client selection and model aggregation with dropout. 
\n", + "class FedDropoutAvg():\n", + "    def __init__(self, workers_dataset_sizes=None, fdr=0.3, cdr=0.2):\n", + "        \n", + "        self.workers_dataset_sizes = workers_dataset_sizes\n", + "        self.simple_average = (workers_dataset_sizes is None) # Simple unweighted average\n", + "        self.fdr = fdr # federated dropout rate\n", + "        self.cdr = cdr # client dropout rate\n", + "        print('* fed_drop_avg init *')\n", + "        print(\"workers_dataset_sizes : {}\".format(workers_dataset_sizes))\n", + "        print()\n", + "    \n", + "    def get_fed_avg_weights(self, selected_worker_ids):\n", + "        \n", + "        n_clients = len(selected_worker_ids)\n", + "        \n", + "        if self.simple_average:\n", + "            fed_avg_weights = [1 / n_clients for _ in range(n_clients)]\n", + "        else: # Weighted according to the number of samples\n", + "            size_list = [self.workers_dataset_sizes[id] for id in selected_worker_ids]\n", + "            total_data_points = np.asarray(size_list).sum()\n", + "            fed_avg_weights = [size_list[r] / total_data_points for r in range(n_clients)]\n", + "\n", + "        print('* get_fed_avg_weights *')\n", + "        print(\"FedAvg Weights: {}\".format(fed_avg_weights))\n", + "        \n", + "        return fed_avg_weights\n", + "    \n", + "    def aggregate(self, models, selected_worker_ids): # Updates the model using state_dicts\n", + "\n", + "        print(\"FedDropoutAvg aggregation step. 
# of models to aggregate = \", len(models))\n", + " new_model = models[0]\n", + " state_dicts = [model.state_dict() for model in models]\n", + "\n", + " dr_rate = self.fdr\n", + " new_state_dict = {}\n", + "\n", + " fed_avg_weights = self.get_fed_avg_weights(selected_worker_ids) # contribution weights \n", + "\n", + " keys = state_dicts[0].keys()\n", + "\n", + " for key in keys:\n", + " curr_shape = state_dicts[0][key].shape\n", + " selection_shape = np.asarray(list(curr_shape) + [len(state_dicts)])\n", + " selection_arr = (np.random.random(selection_shape) >= dr_rate).astype(int) \n", + " #print('selection_arr : ', selection_arr.shape)\n", + " #print(fed_avg_weights)\n", + "\n", + " curr_sum = np.asarray([fed_avg_weights[i] * selection_arr[...,i] for i in range(len(state_dicts))])\n", + " curr_sum = sum(curr_sum)\n", + "\n", + " for r in range(len(state_dicts)):\n", + "\n", + " # Recalculating the contribution weights for each parameter of each model after parameter dropout\n", + " curr_weights = (selection_arr[...,r] * fed_avg_weights[r] / curr_sum)\n", + "\n", + " curr_weights = np.asarray(curr_weights) # for some cases (i.e., with bn layers) where 'curr_weights' becomes a 'numpy.float64' object\n", + " curr_weights[np.isnan(curr_weights)] = 0 # for rare cases where curr_sum was 0\n", + " \n", + " if(key not in new_state_dict.keys()): \n", + " new_state_dict[key] = copy.deepcopy(state_dicts[r][key]) * curr_weights\n", + " else:\n", + " new_state_dict[key] += copy.deepcopy(state_dicts[r][key]) * curr_weights\n", + "\n", + " # Load new model weights\n", + " new_model.load_state_dict(new_state_dict)\n", + " return new_model\n", + "\n", + "\n", + " def select_random_clients(self, worker_ids):\n", + "\n", + " # Random worker (collaborator) selection for the round \n", + " # Uses random choice, so always same number of clients each round\n", + "\n", + " num_selected = int((1-self.cdr)*len(worker_ids))\n", + " if(num_selected == 0):\n", + " print(\"ERR: num_selected == 
0\")\n", + "            return None\n", + "        selected_workers_this_round = np.concatenate([np.ones(num_selected, dtype=bool), np.zeros(len(worker_ids) - num_selected, dtype=bool)])\n", + "        np.random.shuffle(selected_workers_this_round)\n", + "\n", + "        selected_worker_ids_this_round = []\n", + "        for ind in range(len(worker_ids)):\n", + "            if selected_workers_this_round[ind]:\n", + "                selected_worker_ids_this_round.append(worker_ids[ind])\n", + "\n", + "        print()\n", + "        print('client_dropout_rate = ', self.cdr)\n", + "        print('selected_workers_this_round = ', selected_workers_this_round)\n", + "        print('selected_worker_ids_this_round = ', selected_worker_ids_this_round)\n", + "\n", + "        return selected_worker_ids_this_round" + ] + }, + { + "cell_type": "markdown", + "id": "cd268911", + "metadata": {}, + "source": [ + "Next we import the `FLSpec`, `LocalRuntime`, and placement decorators.\n", + "\n", + "- `FLSpec` – Defines the flow specification. User-defined flows are subclasses of this.\n", + "- `Runtime` – Defines where the flow runs, and the infrastructure for task transitions (how information gets sent). The `LocalRuntime` runs the flow on a single node.\n", + "- `aggregator/collaborator` – placement decorators that define where a task will be assigned" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3e0bf04b", + "metadata": {}, + "outputs": [], + "source": [ + "from openfl.experimental.interface import FLSpec, Aggregator, Collaborator\n", + "from openfl.experimental.runtime import LocalRuntime\n", + "from openfl.experimental.placement import aggregator, collaborator" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "fa89dfdd", + "metadata": {}, + "source": [ + "* Now we come to the flow definition. The OpenFL Workflow Interface adopts the convention, set by Metaflow, that every workflow begins with a `start` task and concludes with an `end` task. 
The aggregator begins with an optionally passed-in model and optimizer. The aggregator begins the flow with the `start` task, where the list of collaborators is extracted and then used as the list of participants to run the task listed in `self.next`, `aggregated_model_validation`. The model, the optimizer, and anything that is not explicitly excluded from the next function will be passed from the `start` function on the aggregator to the `aggregated_model_validation` task on the collaborator. Where each task runs is determined by the placement decorator that precedes its definition (`@aggregator` or `@collaborator`). Once each of the collaborators (defined in the runtime) completes the `aggregated_model_validation` task, they pass their current state on to the `train` task, from `train` to `local_model_validation`, and then finally to `join` at the aggregator. It is in `join` that an average is taken of the model weights, and the next round can begin.\n", + "\n", + "* In the __`FederatedDropoutAvgFlow`__ defined here,\n", + "    * At the `start` task (at the start of the flow) and at the `join` task (at the end of each round), some collaborators are randomly selected for the next round from `self.runtime.collaborators` using the `FedDropoutAvg.select_random_clients` method. 
So, not every collaborator participates in every training round.\n", + "    * At the `join` task, model aggregation is done using the `FedDropoutAvg.aggregate` method.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "difficult-madrid", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Aggregator step \"start\" registered\n", + "Collaborator step \"aggregated_model_validation\" registered\n", + "Collaborator step \"train\" registered\n", + "Collaborator step \"local_model_validation\" registered\n", + "Aggregator step \"join\" registered\n", + "Aggregator step \"end\" registered\n" + ] + } + ], + "source": [ + "class FederatedDropoutAvgFlow(FLSpec):\n", + "\n", + "    def __init__(self, model=None, optimizer=None, rounds=3, fdr=0.3, cdr=0.2, train_set_sizes=None, **kwargs):\n", + "        super().__init__(**kwargs)\n", + "        if model is not None:\n", + "            self.model = model\n", + "            self.optimizer = optimizer\n", + "        else:\n", + "            self.model = ResNet18(norm='gn')\n", + "            self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n", + "                                       momentum=momentum)\n", + "        \n", + "        self.rounds = rounds\n", + "        \n", + "        # FedDropoutAvg aggregation helper\n", + "        self.FDRaggregator = FedDropoutAvg(workers_dataset_sizes=train_set_sizes, fdr=fdr, cdr=cdr)\n", + "\n", + "    @aggregator\n", + "    def start(self):\n", + "        print('Performing initialization for model')\n", + "\n", + "        # FedDropoutAvg random collaborator selection for the first round\n", + "        self.collaborators = self.FDRaggregator.select_random_clients(self.runtime.collaborators)\n", + "\n", + "        self.private = 10\n", + "        self.current_round = 0\n", + "        self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])\n", + "\n", + "    @collaborator\n", + "    def aggregated_model_validation(self):\n", + "        print(f'Performing aggregated model validation for collaborator {self.input}')\n", + "        self.agg_validation_score = inference(self.model,self.test_loader)\n", + "        
print(f'{self.input} value of {self.agg_validation_score}')\n", + " self.next(self.train)\n", + "\n", + " @collaborator\n", + " def train(self):\n", + " self.model.train()\n", + " self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " train_losses = []\n", + " for batch_idx, (data, target) in enumerate(self.train_loader):\n", + " self.optimizer.zero_grad()\n", + " output = self.model(data)\n", + " loss = F.nll_loss(output, target)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " if batch_idx % log_interval == 0:\n", + " print('Train Epoch: 1 [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " batch_idx * len(data), len(self.train_loader.dataset),\n", + " 100. * batch_idx / len(self.train_loader), loss.item()))\n", + " self.loss = loss.item()\n", + " torch.save(self.model.state_dict(), 'model.pth')\n", + " torch.save(self.optimizer.state_dict(), 'optimizer.pth')\n", + " self.training_completed = True\n", + " self.next(self.local_model_validation)\n", + "\n", + " @collaborator\n", + " def local_model_validation(self):\n", + " self.local_validation_score = inference(self.model,self.test_loader)\n", + " print(f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')\n", + " self.next(self.join, exclude=['training_completed'])\n", + "\n", + " @aggregator\n", + " def join(self,inputs):\n", + " self.average_loss = sum(input.loss for input in inputs)/len(inputs)\n", + " self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs)/len(inputs)\n", + " self.local_model_accuracy = sum(input.local_validation_score for input in inputs)/len(inputs)\n", + " \n", + " print(f'\\n* Ending round = {self.current_round}')\n", + " print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n", + " print(f'Average training loss = {self.average_loss}')\n", + " print(f'Average local model validation values = {self.local_model_accuracy}')\n", + 
" \n", + "        models = [input.model for input in inputs]\n", + "        \n", + "        #self.model = FedAvg(models)\n", + "        self.model = self.FDRaggregator.aggregate(models, self.collaborators)\n", + "\n", + "        self.optimizer = [input.optimizer for input in inputs][0]\n", + "        self.current_round += 1\n", + "\n", + "        if self.current_round < self.rounds:\n", + "            # FedDropoutAvg random collaborator selection for the next round\n", + "            self.collaborators = self.FDRaggregator.select_random_clients(self.runtime.collaborators)\n", + "            self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])\n", + "        else:\n", + "            self.next(self.end)\n", + "        \n", + "    @aggregator\n", + "    def end(self):\n", + "        print('This is the end of the flow')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "2aabf61e", + "metadata": {}, + "source": [ + "Below, we segment shards of the CIFAR10 dataset for **ten collaborators**. Each has its own slice of the dataset, accessible via the `train_loader` or `test_loader` attribute." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "forward-world", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Local runtime collaborators = ['Portland', 'Seattle', 'Tokyo', 'New York', 'Mumbai', 'Budapest', 'Vienna', 'London', 'York', 'Istanbul']\n", + "train_set_sizes = {'Portland': 5000, 'Seattle': 5000, 'Tokyo': 5000, 'New York': 5000, 'Mumbai': 5000, 'Budapest': 5000, 'Vienna': 5000, 'London': 5000, 'York': 5000, 'Istanbul': 5000}\n" + ] + } + ], + "source": [ + "# Setup participants\n", + "aggregator = Aggregator()\n", + "aggregator.private_attributes = {}\n", + "\n", + "# Setup collaborators with private attributes\n", + "collaborator_names = ['Portland', 'Seattle', 'Tokyo', 'New York', 'Mumbai', 'Budapest', 'Vienna', 'London', 'York', 'Istanbul']\n", + "\n", + "collaborators = [Collaborator(name=name) for name in collaborator_names]\n", + "train_set_sizes = {}\n", + "\n", + "for idx, collaborator in enumerate(collaborators):\n", + "    local_train = copy.deepcopy(cifar10_train)\n", + "    local_test = copy.deepcopy(cifar10_test)\n", + "\n", + "    local_train.data = cifar10_train.data[idx::len(collaborators)]\n", + "    local_train.targets = cifar10_train.targets[idx::len(collaborators)]\n", + "    train_set_sizes[collaborator_names[idx]] = len(local_train.data)\n", + "\n", + "    local_test.data = cifar10_test.data[idx::len(collaborators)]\n", + "    local_test.targets = cifar10_test.targets[idx::len(collaborators)]\n", + "    collaborator.private_attributes = {\n", + "        'train_loader': torch.utils.data.DataLoader(local_train, batch_size=batch_size_train, shuffle=True),\n", + "        'test_loader': torch.utils.data.DataLoader(local_test, batch_size=batch_size_test, shuffle=True)\n", + "    }\n", + "\n", + "local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators, backend='single_process')\n", + "print(f'Local runtime collaborators = {local_runtime.collaborators}')\n", + "print(f'train_set_sizes 
= {train_set_sizes}')" + ] + }, + { + "cell_type": "markdown", + "id": "278ad46b", + "metadata": {}, + "source": [ + "Now that we have our flow and runtime defined, let's run the experiment! " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "16937a65", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* fed_drop_avg init *\n", + "workers_dataset_sizes : {'Portland': 5000, 'Seattle': 5000, 'Tokyo': 5000, 'New York': 5000, 'Mumbai': 5000, 'Budapest': 5000, 'Vienna': 5000, 'London': 5000, 'York': 5000, 'Istanbul': 5000}\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/u2058145/anaconda3/envs/tiatoolbox-dev/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", + " warnings.warn(\n", + "/home/u2058145/anaconda3/envs/tiatoolbox-dev/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. 
The current behavior is equivalent to passing `weights=None`.\n", + " warnings.warn(msg)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created flow FederatedDropoutAvgFlow\n", + "\n", + "Calling start\n", + "Performing initialization for model\n", + "\n", + "client_dropout_rate = 0.2\n", + "selected_workers_this_round = [ True True True True True False True True True False]\n", + "selected_worker_ids_this_round = ['Portland', 'Seattle', 'Tokyo', 'New York', 'Mumbai', 'Vienna', 'London', 'York']\n", + "\n", + "Saving data artifacts for start\n", + "Saved data artifacts for start\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator Portland\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_914876/1662705434.py:63: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n", + " return F.log_softmax(x) #x\n", + "/home/u2058145/anaconda3/envs/tiatoolbox-dev/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", + " warnings.warn(warning.format(ret))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Test set: Avg. 
loss: 2.5109, Accuracy: 95/1000 (10%)\n", + "\n", + "Portland value of 0.0949999988079071\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.575697\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.620125\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.710043\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.459184\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.506419\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.380469\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.517232\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.416172\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.375356\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.305528\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.188028\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.179701\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.149461\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.091488\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.197521\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.265912\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.270894\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.292987\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.162272\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 1.973521\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.1073, Accuracy: 205/1000 (20%)\n", + "\n", + "Doing local model validation for collaborator Portland: 0.20499999821186066\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator Seattle\n", + "\n", + "Test set: Avg. 
loss: 2.5109, Accuracy: 94/1000 (9%)\n", + "\n", + "Seattle value of 0.09399999678134918\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.559444\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.582267\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.495960\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.343980\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.426143\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.369343\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.397931\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.445926\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.401167\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.300503\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.279062\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.145592\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.168720\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.185812\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.206379\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.178717\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.166626\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.177924\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.083949\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 2.080594\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.1677, Accuracy: 199/1000 (20%)\n", + "\n", + "Doing local model validation for collaborator Seattle: 0.19900000095367432\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator Tokyo\n", + "\n", + "Test set: Avg. 
loss: 2.4976, Accuracy: 102/1000 (10%)\n", + "\n", + "Tokyo value of 0.10199999809265137\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.557174\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.617011\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.550061\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.599874\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.705568\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.435205\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.470760\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.298102\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.223739\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.191273\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.340257\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.175023\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.167267\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.210913\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.188567\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.075622\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.134003\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.445056\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.544509\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 2.529624\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.2867, Accuracy: 219/1000 (22%)\n", + "\n", + "Doing local model validation for collaborator Tokyo: 0.21899999678134918\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator New York\n", + "\n", + "Test set: Avg. 
loss: 2.5167, Accuracy: 84/1000 (8%)\n", + "\n", + "New York value of 0.08399999886751175\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.435062\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.439531\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.629870\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.614744\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.676745\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.519074\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.416895\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.428877\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.316732\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.315463\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.322055\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.273499\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.263719\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.156230\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.246406\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.241237\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.123760\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.109340\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.026899\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 2.047124\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.0436, Accuracy: 286/1000 (29%)\n", + "\n", + "Doing local model validation for collaborator New York: 0.28600001335144043\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator Mumbai\n", + "\n", + "Test set: Avg. 
loss: 2.4931, Accuracy: 92/1000 (9%)\n", + "\n", + "Mumbai value of 0.09200000017881393\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.475485\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.478536\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.409363\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.475603\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.469703\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.503402\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.502491\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.387769\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.261808\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.367011\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.237673\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.197210\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.214212\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.222021\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.205699\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.092920\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.200138\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.252608\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.263825\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 2.122983\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.1101, Accuracy: 209/1000 (21%)\n", + "\n", + "Doing local model validation for collaborator Mumbai: 0.20900000631809235\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator Vienna\n", + "\n", + "Test set: Avg. 
loss: 2.5170, Accuracy: 86/1000 (9%)\n", + "\n", + "Vienna value of 0.0860000029206276\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.540579\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.595216\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.679543\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.576197\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.462202\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.289311\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.284979\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.221318\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.257116\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.238572\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.240233\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.251343\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.206989\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.193255\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.206035\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.143415\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.057267\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.146593\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.356225\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 2.148012\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.3087, Accuracy: 175/1000 (18%)\n", + "\n", + "Doing local model validation for collaborator Vienna: 0.17499999701976776\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator London\n", + "\n", + "Test set: Avg. 
loss: 2.4569, Accuracy: 94/1000 (9%)\n", + "\n", + "London value of 0.09399999678134918\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.502911\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.532863\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.360670\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.391947\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.444503\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.330840\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.337116\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.293673\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.296888\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.212420\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.214494\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.193125\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.109787\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.173886\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.128002\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.124499\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.185050\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.312467\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.427933\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 2.516885\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.2781, Accuracy: 243/1000 (24%)\n", + "\n", + "Doing local model validation for collaborator London: 0.24300000071525574\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator York\n", + "\n", + "Test set: Avg. 
loss: 2.4967, Accuracy: 110/1000 (11%)\n", + "\n", + "York value of 0.10999999940395355\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.524751\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.566129\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.457451\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.490930\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.478456\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.397409\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.388125\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.386880\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.387819\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.214784\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.230114\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.188889\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.142556\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.111259\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.105552\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.224369\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.207006\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.141580\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.414766\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 2.104281\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. 
loss: 2.1318, Accuracy: 230/1000 (23%)\n", + "\n", + "Doing local model validation for collaborator York: 0.23000000417232513\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling join\n", + "\n", + "* Ending round = 0\n", + "Average aggregated model validation values = 0.09462499897927046\n", + "Average training loss = 2.190378025174141\n", + "Average local model validation values = 0.2207500021904707\n", + "FedDropoutAvg aggregation step. # of models to aggregate = 8\n", + "* get_fed_avg_weights *\n", + "FedAvg Weights: [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_914876/3699002234.py:74: RuntimeWarning: invalid value encountered in divide\n", + " curr_weights = (selection_arr[...,r] * fed_avg_weights[r] / curr_sum)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "client_dropout_rate = 0.2\n", + "selected_workers_this_round = [False True True True False True True True True True]\n", + "selected_worker_ids_this_round = ['Seattle', 'Tokyo', 'New York', 'Budapest', 'Vienna', 'London', 'York', 'Istanbul']\n", + "\n", + "Saving data artifacts for join\n", + "Saved data artifacts for join\n", + "Sending state from aggregator to collaborators\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator Seattle\n", + "\n", + "Test set: Avg. 
loss: 2.0144, Accuracy: 279/1000 (28%)\n", + "\n", + "Seattle value of 0.27900001406669617\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.022604\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.084951\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.053971\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 1.991255\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 1.949353\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 1.968518\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 1.979830\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.040371\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.049551\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.164682\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.030992\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 1.970891\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 1.979251\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 1.913164\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.093827\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.245971\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.083004\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 1.890087\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 1.916753\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 2.062671\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.0156, Accuracy: 248/1000 (25%)\n", + "\n", + "Doing local model validation for collaborator Seattle: 0.24799999594688416\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator Tokyo\n", + "\n", + "Test set: Avg. 
loss: 2.0383, Accuracy: 290/1000 (29%)\n", + "\n", + "Tokyo value of 0.28999999165534973\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.036776\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 1.978999\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.022073\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.011048\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.021605\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 1.997957\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 1.988759\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.031986\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.034485\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.092000\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 1.965717\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.030584\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.097492\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.080436\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 1.975417\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 1.958862\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 1.971893\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.088602\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.106205\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 2.101292\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.3143, Accuracy: 219/1000 (22%)\n", + "\n", + "Doing local model validation for collaborator Tokyo: 0.21899999678134918\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator New York\n", + "\n", + "Test set: Avg. 
loss: 2.0127, Accuracy: 281/1000 (28%)\n", + "\n", + "New York value of 0.2809999883174896\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.022024\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.074870\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.022179\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 1.997844\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 1.996588\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 1.982725\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.015751\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.091727\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.146847\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.042103\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.170175\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 1.978170\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 1.992675\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 1.984998\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 1.917243\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.003946\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 1.919209\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 1.936433\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.021804\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 2.197349\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 1.9833, Accuracy: 284/1000 (28%)\n", + "\n", + "Doing local model validation for collaborator New York: 0.2840000092983246\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator Budapest\n", + "\n", + "Test set: Avg. 
loss: 2.0319, Accuracy: 260/1000 (26%)\n", + "\n", + "Budapest value of 0.25999999046325684\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.061434\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 1.963689\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.041071\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.008275\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.053854\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.043153\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.014827\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.174351\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.165076\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.057882\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.080222\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.058807\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 1.980807\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.040826\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 1.896287\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.010593\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.035838\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.006614\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 1.879216\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 1.978764\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.1175, Accuracy: 262/1000 (26%)\n", + "\n", + "Doing local model validation for collaborator Budapest: 0.2619999945163727\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator Vienna\n", + "\n", + "Test set: Avg. 
loss: 2.0389, Accuracy: 256/1000 (26%)\n", + "\n", + "Vienna value of 0.25600001215934753\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.032061\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.087704\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.077573\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 1.963778\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.130500\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.131711\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.147037\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.127461\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 1.950479\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 1.931547\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.077667\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 1.964857\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 1.894112\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 1.908302\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 1.967790\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.187759\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.321854\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.113746\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 1.994545\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 1.938500\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 2.0496, Accuracy: 249/1000 (25%)\n", + "\n", + "Doing local model validation for collaborator Vienna: 0.24899999797344208\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator London\n", + "\n", + "Test set: Avg. 
loss: 2.0057, Accuracy: 299/1000 (30%)\n", + "\n", + "London value of 0.29899999499320984\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.048916\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.051465\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.051873\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.061419\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.033423\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.052135\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.017332\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.035055\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.074398\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 1.871239\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.026396\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.331779\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.524098\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.276949\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.115633\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.073865\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 2.099614\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 1.896590\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 1.986085\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 1.849730\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 1.9641, Accuracy: 244/1000 (24%)\n", + "\n", + "Doing local model validation for collaborator London: 0.24400000274181366\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator York\n", + "\n", + "Test set: Avg. 
loss: 2.0181, Accuracy: 288/1000 (29%)\n", + "\n", + "York value of 0.2879999876022339\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 2.043256\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.163815\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.102471\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.089473\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.040744\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.007309\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.038326\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 1.953669\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 2.016808\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.138150\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 2.159809\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.178420\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.047402\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.356047\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 2.178318\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.045300\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 1.998527\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 1.969887\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 2.025174\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 1.884651\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. loss: 1.9107, Accuracy: 307/1000 (31%)\n", + "\n", + "Doing local model validation for collaborator York: 0.3070000112056732\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling aggregated_model_validation\n", + "Performing aggregated model validation for collaborator Istanbul\n", + "\n", + "Test set: Avg. 
loss: 2.0169, Accuracy: 270/1000 (27%)\n", + "\n", + "Istanbul value of 0.27000001072883606\n", + "Saving data artifacts for aggregated_model_validation\n", + "Saved data artifacts for aggregated_model_validation\n", + "\n", + "Calling train\n", + "Train Epoch: 1 [0/5000 (0%)]\tLoss: 1.997646\n", + "Train Epoch: 1 [256/5000 (5%)]\tLoss: 2.022416\n", + "Train Epoch: 1 [512/5000 (10%)]\tLoss: 2.016927\n", + "Train Epoch: 1 [768/5000 (15%)]\tLoss: 2.069757\n", + "Train Epoch: 1 [1024/5000 (20%)]\tLoss: 2.070618\n", + "Train Epoch: 1 [1280/5000 (25%)]\tLoss: 2.113342\n", + "Train Epoch: 1 [1536/5000 (30%)]\tLoss: 2.142785\n", + "Train Epoch: 1 [1792/5000 (35%)]\tLoss: 2.013771\n", + "Train Epoch: 1 [2048/5000 (40%)]\tLoss: 1.901719\n", + "Train Epoch: 1 [2304/5000 (45%)]\tLoss: 2.072110\n", + "Train Epoch: 1 [2560/5000 (50%)]\tLoss: 1.948489\n", + "Train Epoch: 1 [2816/5000 (55%)]\tLoss: 2.113527\n", + "Train Epoch: 1 [3072/5000 (60%)]\tLoss: 2.109587\n", + "Train Epoch: 1 [3328/5000 (65%)]\tLoss: 2.169302\n", + "Train Epoch: 1 [3584/5000 (70%)]\tLoss: 1.941946\n", + "Train Epoch: 1 [3840/5000 (75%)]\tLoss: 2.009107\n", + "Train Epoch: 1 [4096/5000 (80%)]\tLoss: 1.976783\n", + "Train Epoch: 1 [4352/5000 (85%)]\tLoss: 2.018227\n", + "Train Epoch: 1 [4608/5000 (90%)]\tLoss: 1.928527\n", + "Train Epoch: 1 [2584/5000 (95%)]\tLoss: 1.904041\n", + "Saving data artifacts for train\n", + "Saved data artifacts for train\n", + "\n", + "Calling local_model_validation\n", + "\n", + "Test set: Avg. 
loss: 2.1580, Accuracy: 188/1000 (19%)\n", + "\n", + "Doing local model validation for collaborator Istanbul: 0.18799999356269836\n", + "Saving data artifacts for local_model_validation\n", + "Saved data artifacts for local_model_validation\n", + "Should transfer from local_model_validation to join\n", + "\n", + "Calling join\n", + "\n", + "* Ending round = 1\n", + "Average aggregated model validation values = 0.27787499874830246\n", + "Average training loss = 1.9896245896816254\n", + "Average local model validation values = 0.25012500025331974\n", + "FedDropoutAvg aggregation step. # of models to aggregate = 8\n", + "* get_fed_avg_weights *\n", + "FedAvg Weights: [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]\n", + "Saving data artifacts for join\n", + "Saved data artifacts for join\n", + "\n", + "Calling end\n", + "This is the end of the flow\n", + "Saving data artifacts for end\n", + "Saved data artifacts for end\n" + ] + } + ], + "source": [ + "model = None\n", + "best_model = None\n", + "optimizer = None\n", + "\n", + "flflow = FederatedDropoutAvgFlow(model, optimizer, rounds=n_rounds, fdr=fdr, cdr=cdr, train_set_sizes=train_set_sizes, checkpoint=True)\n", + "\n", + "flflow.runtime = local_runtime\n", + "flflow.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "24acb66e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['Seattle',\n", + " 'Tokyo',\n", + " 'New York',\n", + " 'Budapest',\n", + " 'Vienna',\n", + " 'London',\n", + " 'York',\n", + " 'Istanbul']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The collaborators selected in the last round:\n", + "flflow.collaborators" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "21a08533", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['Portland',\n", + " 'Seattle',\n", + " 'Tokyo',\n", + " 'New York',\n", + " 'Mumbai',\n", + " 'Budapest',\n", + " 'Vienna',\n", 
+ " 'London',\n", + " 'York',\n", + " 'Istanbul']" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# All collaborators available in the runtime:\n", + "flflow.runtime.collaborators" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c32e0844", + "metadata": {}, + "source": [ + "Now that the flow has completed, we can retrieve the final model and all other aggregator attributes.\n", + "\n", + "Let's get the final aggregated model accuracy:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "863761fe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Final aggregated model accuracy for 2 rounds of training: 0.27787499874830246\n" + ] + } + ], + "source": [ + "print(f'\\nFinal aggregated model accuracy for {flflow.rounds} rounds of training: {flflow.aggregated_model_accuracy}')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "426f2395", + "metadata": {}, + "source": [ + "## This is the end of the FedDropoutAvg tutorial.\n", + "\n", + "## Feel free to change the _`cdr` (client dropout rate)_ and _`fdr` (federated dropout rate)_ parameters of the algorithm, and/or try it on different datasets." + ] + }, + { + "cell_type": "markdown", + "id": "b07d6d42", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tiatoolbox-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}