diff --git a/.github/workflows/android-release.yml b/.github/workflows/android-release.yml
index 35df8c8a9cfb..ba11e1ee85e7 100644
--- a/.github/workflows/android-release.yml
+++ b/.github/workflows/android-release.yml
@@ -15,6 +15,7 @@ jobs:
       run:
         working-directory: src/kotlin
     name: Release build and publish
+    if: github.repository == 'adap/flower'
     runs-on: ubuntu-latest
     steps:
       - name: Check out code
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 8bc4d93e8a61..d6640a34968b 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -32,7 +32,7 @@ jobs:
       - name: Build docs
         run: ./dev/build-docs.sh
       - name: Deploy docs
-        if: github.ref == 'refs/heads/main'
+        if: github.ref == 'refs/heads/main' && github.repository == 'adap/flower' && ${{ !github.event.pull_request.head.repo.fork }}
         env:
           AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
           AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml
index 3b70db43a6c8..3a58503ea66e 100644
--- a/.github/workflows/e2e.yml
+++ b/.github/workflows/e2e.yml
@@ -16,9 +16,40 @@ env:
   FLWR_TELEMETRY_ENABLED: 0
 
 jobs:
+  wheel:
+    runs-on: ubuntu-22.04
+    name: Build, test and upload wheel
+    steps:
+      - uses: actions/checkout@v3
+      - name: Bootstrap
+        uses: ./.github/actions/bootstrap
+      - name: Install dependencies (mandatory only)
+        run: python -m poetry install
+      - name: Build wheel
+        run: ./dev/build.sh
+      - name: Test wheel
+        run: ./dev/test-wheel.sh
+      - name: Upload wheel
+        if: github.repository == 'adap/flower' && ${{ !github.event.pull_request.head.repo.fork }}
+        id: upload
+        env:
+          AWS_DEFAULT_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
+          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+        run: |
+          cd ./dist
+          echo "WHL_PATH=$(ls *.whl)" >> "$GITHUB_OUTPUT"
+          sha_short=$(git rev-parse --short HEAD)
+          echo "SHORT_SHA=$sha_short" >> "$GITHUB_OUTPUT"
+          aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./ s3://artifact.flower.dev/py/${{ github.head_ref }}/$sha_short --recursive
+    outputs:
+      whl_path: ${{ steps.upload.outputs.WHL_PATH }}
+      short_sha: ${{ steps.upload.outputs.SHORT_SHA }}
+
   frameworks:
     runs-on: ubuntu-22.04
     timeout-minutes: 10
+    needs: wheel
     # Using approach described here:
     # https://docs.github.com/en/actions/using-jobs/using-a-matrix-for-your-jobs
     strategy:
@@ -89,6 +120,10 @@ jobs:
           python-version: 3.8
       - name: Install dependencies
         run: python -m poetry install
+      - name: Install Flower wheel from artifact store
+        if: github.repository == 'adap/flower' && ${{ !github.event.pull_request.head.repo.fork }}
+        run: |
+          python -m pip install https://artifact.flower.dev/py/${{ github.head_ref }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }}
       - name: Download dataset
         if: ${{ matrix.dataset }}
         run: python -c "${{ matrix.dataset }}"
@@ -102,6 +137,7 @@ jobs:
   strategies:
     runs-on: ubuntu-22.04
     timeout-minutes: 10
+    needs: wheel
     strategy:
       matrix:
         strat: ["FedMedian", "FedTrimmedAvg", "QFedAvg", "FaultTolerantFedAvg", "FedAvgM", "FedAdam", "FedAdagrad", "FedYogi"]
@@ -119,6 +155,10 @@ jobs:
      - name: Install dependencies
        run: |
          python -m poetry install
+      - name: Install Flower wheel from artifact store
+        if: github.repository == 'adap/flower' && ${{ !github.event.pull_request.head.repo.fork }}
+        run: |
+          python -m pip install https://artifact.flower.dev/py/${{ github.head_ref }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }}
       - name: Cache Datasets
         uses: actions/cache@v3
         with:
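The `wheel` job above exposes the wheel filename and commit SHA as job outputs, and the later jobs reassemble the download URL from them. A hedged sketch of fetching such a wheel by hand; the branch name, SHA, and wheel filename below are hypothetical placeholders:

```bash
# Hypothetical values: in CI they come from github.head_ref and the
# wheel job's SHORT_SHA/WHL_PATH outputs.
BRANCH=my-feature-branch
SHORT_SHA=0123abc
WHL=flwr-1.6.0-py3-none-any.whl
python -m pip install "https://artifact.flower.dev/py/${BRANCH}/${SHORT_SHA}/${WHL}"
```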
diff --git a/.github/workflows/flower-swift_sync.yml b/.github/workflows/flower-swift_sync.yml
index d3fce3b22a0f..836d905b2df2 100644
--- a/.github/workflows/flower-swift_sync.yml
+++ b/.github/workflows/flower-swift_sync.yml
@@ -12,6 +12,7 @@ concurrency:
 jobs:
   build:
     runs-on: ubuntu-latest
+    if: github.repository == 'adap/flower'
     steps:
       - uses: actions/checkout@v4
       - name: Pushes src/swift to flower-swift repository
diff --git a/.github/workflows/release-nightly.yml b/.github/workflows/release-nightly.yml
index 0ae9c43ddbf1..823ff1513790 100644
--- a/.github/workflows/release-nightly.yml
+++ b/.github/workflows/release-nightly.yml
@@ -11,6 +11,7 @@ jobs:
   release_nightly:
     runs-on: ubuntu-22.04
     name: Nightly
+    if: github.repository == 'adap/flower'
     steps:
       - uses: actions/checkout@v4
       - name: Bootstrap
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
deleted file mode 100644
index d8f4e403482b..000000000000
--- a/.github/workflows/release.yml
+++ /dev/null
@@ -1,26 +0,0 @@
-name: Release
-
-on:
-  schedule:
-    - cron: "0 23 * * *"
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number || github.ref }}
-  cancel-in-progress: true
-
-env:
-  FLWR_TELEMETRY_ENABLED: 0
-
-jobs:
-  nightly_release:
-    runs-on: ubuntu-22.04
-    name: Nightly
-    steps:
-      - uses: actions/checkout@v4
-      - name: Bootstrap
-        uses: ./.github/actions/bootstrap
-      - name: Release nightly
-        env:
-          PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
-        run: |
-          ./dev/publish-nightly.sh
diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml
index 207bb1283739..9edd7f7ff6e1 100644
--- a/.github/workflows/swift.yml
+++ b/.github/workflows/swift.yml
@@ -40,7 +40,7 @@ jobs:
   deploy_docs:
     needs: "build_docs"
-    if: github.ref == 'refs/heads/main'
+    if: github.ref == 'refs/heads/main' && github.repository == 'adap/flower' && ${{ !github.event.pull_request.head.repo.fork }}
     runs-on: macos-latest
     name: Deploy docs
     steps:
diff --git a/dev/get-latest-changelog.sh b/dev/get-latest-changelog.sh
new file mode 100755
index 000000000000..d7f4ca7db168
--- /dev/null
+++ b/dev/get-latest-changelog.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -e
+cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../
+
+# Extract the latest release notes from the changelog, which start at the line containing
+# the latest version tag and end one line before the previous version tag.
+tags=$(git tag --sort=-creatordate)
+new_version=$(echo "$tags" | sed -n '1p')
+old_version=$(echo "$tags" | sed -n '2p')
+
+awk -v start="$new_version" -v end="$old_version" '
+    $0 ~ start {flag=1; next}
+    $0 ~ end {flag=0}
+    flag && !printed && /^$/ {next} # skip the first blank line
+    flag && !printed {printed=1}
+    flag' doc/source/ref-changelog.md
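Assuming the clone contains at least two tags (the script reads the two most recent ones), a minimal usage sketch for the new helper; the redirect target is just an example:

```bash
# Print the changelog section for the latest release, e.g. for a GitHub
# release body.
./dev/get-latest-changelog.sh > latest-release-notes.md
```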
diff --git a/doc/source/index.rst b/doc/source/index.rst
index 48f8d59ea9b7..4ac99cc24c09 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -61,17 +61,15 @@ A learning-oriented series of federated learning tutorials, the best place to start.
 
 QUICKSTART TUTORIALS: :doc:`PyTorch ` | :doc:`TensorFlow ` | :doc:`🤗 Transformers ` | :doc:`JAX ` | :doc:`Pandas ` | :doc:`fastai ` | :doc:`PyTorch Lightning ` | :doc:`MXNet ` | :doc:`scikit-learn ` | :doc:`XGBoost ` | :doc:`Android ` | :doc:`iOS `
 
-.. grid:: 2
+We also made video tutorials for PyTorch:
 
-    .. grid-item-card:: PyTorch
+.. youtube:: jOmmuzMIQ4c
+   :width: 80%
 
-        .. youtube:: jOmmuzMIQ4c
-           :width: 100%
+And TensorFlow:
 
-    .. grid-item-card:: TensorFlow
-
-        .. youtube:: FGTc2TQq7VM
-           :width: 100%
+.. youtube:: FGTc2TQq7VM
+   :width: 80%
 
 How-to guides
 ~~~~~~~~~~~~~
diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md
index e1d90b01fb35..7355b4123347 100644
--- a/doc/source/ref-changelog.md
+++ b/doc/source/ref-changelog.md
@@ -16,11 +16,13 @@
 
 - FedProx ([#2210](https://github.com/adap/flower/pull/2210), [#2286](https://github.com/adap/flower/pull/2286))
 
+- **Update Flower Examples** ([#2384](https://github.com/adap/flower/pull/2384))
+
 - **General updates to baselines** ([#2301](https://github.com/adap/flower/pull/2301), [#2305](https://github.com/adap/flower/pull/2305), [#2307](https://github.com/adap/flower/pull/2307), [#2327](https://github.com/adap/flower/pull/2327))
 
-- **General updates to the simulation engine** ([#2331](https://github.com/adap/flower/pull/2331))
+- **General updates to the simulation engine** ([#2331](https://github.com/adap/flower/pull/2331), [#2448](https://github.com/adap/flower/pull/2448))
 
-- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [2317](https://github.com/adap/flower/pull/2317),[#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360))
+- **General improvements** ([#2309](https://github.com/adap/flower/pull/2309), [#2310](https://github.com/adap/flower/pull/2310), [#2313](https://github.com/adap/flower/pull/2313), [#2316](https://github.com/adap/flower/pull/2316), [#2317](https://github.com/adap/flower/pull/2317), [#2349](https://github.com/adap/flower/pull/2349), [#2360](https://github.com/adap/flower/pull/2360), [#2402](https://github.com/adap/flower/pull/2402), [#2446](https://github.com/adap/flower/pull/2446))
 
 Flower received many improvements under the hood, too many to list here.
diff --git a/examples/advanced-pytorch/requirements.txt b/examples/advanced-pytorch/requirements.txt
index 21c886d16e4d..ba7b284df90e 100644
--- a/examples/advanced-pytorch/requirements.txt
+++ b/examples/advanced-pytorch/requirements.txt
@@ -1,4 +1,4 @@
 flwr>=1.0, <2.0
 torch==1.13.1
 torchvision==0.14.1
-
+validators==0.18.2
diff --git a/examples/advanced-tensorflow/requirements.txt b/examples/advanced-tensorflow/requirements.txt
index 6420aab25ec8..7a70c46a8128 100644
--- a/examples/advanced-tensorflow/requirements.txt
+++ b/examples/advanced-tensorflow/requirements.txt
@@ -1,3 +1,3 @@
 flwr>=1.0, <2.0
-tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64"
 tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64"
+tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64"
diff --git a/examples/android-kotlin/pyproject.toml b/examples/android-kotlin/pyproject.toml
index dee6cbc35711..9cf0688d83b5 100644
--- a/examples/android-kotlin/pyproject.toml
+++ b/examples/android-kotlin/pyproject.toml
@@ -1,3 +1,7 @@
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
+
 [tool.poetry]
 name = "flower-android-kotlin"
 version = "0.1.0"
@@ -7,7 +11,3 @@ authors = ["Steven Hé (Sīchàng) "]
 [tool.poetry.dependencies]
 python = ">=3.8,<3.11"
 flwr = ">=1.0,<2.0"
-
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
diff --git a/examples/android/pyproject.toml b/examples/android/pyproject.toml
index 0ecaaa73989f..2b9cd8c978a7 100644
--- a/examples/android/pyproject.toml
+++ b/examples/android/pyproject.toml
@@ -10,7 +10,6 @@ authors = ["The Flower Authors "]
 [tool.poetry.dependencies]
 python = ">=3.8,<3.11"
-# flwr = { path = "../../", develop = true } # Development
 flwr = ">=1.0,<2.0"
 tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""}
 tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""}
diff --git a/examples/android/requirements.txt b/examples/android/requirements.txt
index 6420aab25ec8..7a70c46a8128 100644
--- a/examples/android/requirements.txt
+++ b/examples/android/requirements.txt
@@ -1,3 +1,3 @@
 flwr>=1.0, <2.0
-tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64"
 tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64"
+tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64"
diff --git a/examples/embedded-devices/Dockerfile b/examples/embedded-devices/Dockerfile
index add8d6d50d2e..ea63839bc9d6 100644
--- a/examples/embedded-devices/Dockerfile
+++ b/examples/embedded-devices/Dockerfile
@@ -1,28 +1,13 @@
-ARG BASE_IMAGE_TYPE=cpu
-# these images have been pushed to Dockerhub but you can find
-# each Dockerfile used in the `base_images` directory
-FROM jafermarq/jetsonfederated_$BASE_IMAGE_TYPE:latest
+ARG BASE_IMAGE
 
-RUN apt-get install wget -y
+# Pull the base image from NVIDIA
+FROM $BASE_IMAGE
 
-# Download and extract CIFAR-10
-# To keep things simple, we keep this as part of the docker image.
-# If the dataset is already in your system you can mount it instead.
-ENV DATA_DIR=/app/data/cifar-10
-RUN mkdir -p $DATA_DIR
-WORKDIR $DATA_DIR
-RUN wget https://www.cs.toronto.edu/\~kriz/cifar-10-python.tar.gz
-RUN tar -zxvf cifar-10-python.tar.gz
-
-WORKDIR /app
-# Scripts needed for Flower client
-ADD client.py /app
-ADD utils.py /app
-
-# update pip
+# Update pip
 RUN pip3 install --upgrade pip
 
-# making sure the latest version of flower is installed
-RUN pip3 install flwr>=1.0.0
+# Install Flower (quoted so the shell does not treat `>` as a redirect)
+RUN pip3 install "flwr>=1.0"
+RUN pip3 install tqdm==4.65.0
 
-ENTRYPOINT ["python3","-u","./client.py"]
+WORKDIR /client
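The rewritten Dockerfile expects the NVIDIA base image to be supplied at build time. A hedged sketch of a manual build; the `build_jetson_flower_client.sh` script introduced later in this patch wraps exactly this invocation, and the base image tag below mirrors its `--pytorch` default:

```bash
docker build \
  --build-arg BASE_IMAGE=nvcr.io/nvidia/l4t-pytorch:r35.1.0-pth1.13-py3 \
  . -t flower_client:latest
```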
diff --git a/examples/embedded-devices/README.md b/examples/embedded-devices/README.md
index 16cc47bf3992..b485f663e08f 100644
--- a/examples/embedded-devices/README.md
+++ b/examples/embedded-devices/README.md
@@ -1,142 +1,216 @@
 # Federated Learning on Embedded Devices with Flower
 
-This demo will show you how Flower makes it very easy to run Federated Learning workloads on edge devices. Here we'll be showing how to use NVIDIA Jetson devices and Raspberry Pi as Flower clients. This demo uses Flower with PyTorch. The source code used is mostly borrowed from the [example that Flower provides for CIFAR-10](https://github.com/adap/flower/tree/main/src/py/flwr_example/pytorch_cifar).
+This example will show you how Flower makes it very easy to run Federated Learning workloads on edge devices. Here we'll be showing how to use NVIDIA Jetson devices and Raspberry Pis as Flower clients. You can run this example using either PyTorch or TensorFlow. The FL workload (i.e. model, dataset and training loop) is mostly borrowed from the [quickstart-pytorch](https://github.com/adap/flower/tree/main/examples/simulation-pytorch) and [quickstart-tensorflow](https://github.com/adap/flower/tree/main/examples/quickstart-tensorflow) examples.
 
-## Getting things ready
+![Different ways of running Flower FL on embedded devices](_static/diagram.png)
 
-This is a list of components that you'll need:
+## Getting things ready
 
-- For server: A machine running Linux/macOS.
-- For clients: either a Rapsberry Pi 3 B+ (RPi 4 would work too) or a Jetson Xavier-NX (or any other recent NVIDIA-Jetson device).
-- A 32GB uSD card and ideally UHS-1 or better. (not needed if you plan to use a Jetson TX2 instead)
-- Software to flash the images to a uSD card (e.g. [Etcher](https://www.balena.io/etcher/))
+> This example is designed for beginners that know a bit about Flower and/or ML but that are less familiar with embedded devices. If you already have a couple of devices up and running, clone this example and start the Flower clients after launching the Flower server.
 
-What follows is a step-by-step guide on how to setup your client/s and the server. In order to minimize the amount of setup and potential issues that might arise due to the hardware/software heterogenity between clients we'll be running the clients inside a Docker. We provide two docker images: one built for Jetson devices and make use of their GPU; and the other for CPU-only training suitable for Raspberry Pi (but would also work on Jetson devices). The following diagram illustrates the setup for this demo:
+This tutorial allows for a variety of settings (some shown in the diagrams above). As long as you have access to one embedded device, you can follow along. This is a list of components that you'll need:
 
-
+- For Flower server: A machine running Linux/macOS/Windows (e.g. your laptop). You can run the server on an embedded device too!
+- For Flower clients (one or more): Raspberry Pi 4 (or Zero 2), or an NVIDIA Jetson Xavier-NX (or Nano), or anything similar to these.
+- A uSD card with 32GB or more.
+- Software to flash the images to a uSD card:
+  - For Raspberry Pi we recommend the [Raspberry Pi Imager](https://www.raspberrypi.com/software/)
+  - For other devices, [balenaEtcher](https://www.balena.io/etcher/) is a great option.
 
-![alt text](_static/diagram.png)
+What follows is a step-by-step guide on how to set up your client(s) and the server.
 
-## Clone this repo
+## Clone this example
 
-Start with cloning the Flower repo and checking out the example. We have prepared a single line which you can copy into your shell:
+Start with cloning this example on your laptop or desktop machine. Later you'll run the same command on your embedded devices. We have prepared a single line which you can copy and execute:
 
 ```bash
-$ git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/embedded-devices . && rm -rf flower && cd embedded-devices
+git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/embedded-devices . && rm -rf flower && cd embedded-devices
 ```
 
 ## Setting up the server
 
-The only requirement for the server is to have flower installed. You can do so by running `pip install flwr` inside your virtualenv or conda environment.
+The only requirement for the server is to have Flower installed alongside your ML framework of choice. Inside your Python environment run:
+
+```bash
+pip install -r requirements_pytorch.txt # to install Flower and PyTorch
+
+# or the below for TensorFlow
+# pip install -r requirements_tf.txt
+```
+
+If you are working on this tutorial on your laptop or desktop, it can host the Flower server that will orchestrate the entire FL process. You could also use an embedded device (e.g. a Raspberry Pi) as the Flower server. In order to do that, please follow the setup steps below.
+
+## Setting up a Raspberry Pi
+
+> Whether you use your RPi as a Flower server or a client, you need to follow these steps.
+
+![alt text](_static/rpi_imager.png)
+
+1. **Installing Ubuntu server on your Raspberry Pi** is easy with the [Raspberry Pi Imager](https://www.raspberrypi.com/software/). Before starting, ensure you have a uSD card attached to your PC/laptop and that it has sufficient space (ideally larger than 16GB). Then:
+
+   - Click on `CHOOSE OS` > `Other general-purpose OS` > `Ubuntu` > `Ubuntu Server 22.04.03 LTS (64-bit)`. Other versions of `Ubuntu Server` would likely work but try to use a `64-bit` one.
+   - Select the uSD you want to flash the OS onto. (This will be the uSD you insert in your Raspberry Pi.)
+   - Click on the gear icon on the bottom right of the `Raspberry Pi Imager` window (the icon only appears after choosing your OS image). Here you can very conveniently set the username/password to access your device over ssh. You'll see I use `piubuntu` as the username (you can choose something different). It's also the ideal place to select your WiFi network and add the password (this is of course not needed if you plan to connect the Raspberry Pi via Ethernet). Click "save" when you are done.
+   - Finally, click on `WRITE` to start flashing Ubuntu onto the uSD card.
+
+2. **Connecting to your Raspberry Pi**
+
+   After `ssh`-ing into your Raspberry Pi for the first time, make sure your OS is up-to-date.
+
+   - Run: `sudo apt update` to look for updates
+   - And then: `sudo apt upgrade -y` to apply updates (this might take a few minutes on the RPi Zero)
+   - Then reboot your RPi with `sudo reboot`. Then ssh into it again.
+
+3. **Preparations for your Flower experiments**
+
+   - Install `pip`. In the terminal type: `sudo apt install python3-pip -y`
+   - Now clone this example. You just need to execute the `git clone` command shown at the top of this README.md on your device.
+   - Install Flower and your ML framework: We have prepared some convenient installation scripts that will install everything you need. You are free to install other versions of these ML frameworks to suit your needs.
+     - If you want your clients to use PyTorch: `pip3 install -r requirements_pytorch.txt`
+     - If you want your clients to use TensorFlow: `pip3 install -r requirements_tf.txt`
+
+   > While preparing this example I noticed that installing TensorFlow on the **Raspberry Pi Zero** would fail due to lack of RAM (it only has 512MB). A workaround is to create a `swap` disk partition (non-existent by default) so the OS can offload some elements to disk. I followed the steps described [in this blog post](https://www.digitalocean.com/community/tutorials/how-to-add-swap-space-on-ubuntu-20-04) that I copy below. You can follow these steps if you often see your RPi Zero running out of memory:
+
+   ```bash
+   # Let's create a 1GB swap partition
+   sudo fallocate -l 1G /swapfile
+   sudo chmod 600 /swapfile
+   sudo mkswap /swapfile
+   # Enable swap
+   sudo swapon /swapfile # you should now be able to see the swap size on htop.
+   # make changes permanent after reboot
+   sudo cp /etc/fstab /etc/fstab.bak
+   echo '/swapfile none swap sw 0 0' | sudo tee -a /etc/fstab
+   ```
+
+   Please note that using swap as if it were RAM comes with a large performance penalty in terms of data movement.
+
+4. Run your Flower experiments following the steps in the [Running Embedded FL with Flower](https://github.com/adap/flower/tree/main/examples/embedded-devices#running-embedded-fl-with-flower) section.
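If you created the swap file from step 3, you can confirm it is active with standard Ubuntu tools (nothing example-specific):

```bash
free -h        # the "Swap:" row should now show ~1.0Gi total
swapon --show  # lists /swapfile with its size and current usage
```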
 
 ## Setting up a Jetson Xavier-NX
 
-> These steps have been validated for a Jetson Xavier-NX Dev Kit. An identical setup is needed for a Jetson Nano and Jetson TX2 once you get ssh access to them (i.e. jumping straight to point `4` below). For instructions on how to setup these devices please refer to the "getting started guides" for [Jetson Nano](https://developer.nvidia.com/embedded/learn/get-started-jetson-nano-devkit#intro) and [Jetson TX2](https://developer.nvidia.com/embedded/dlc/l4t-28-2-jetson-developer-kit-user-guide-ga).
+> These steps have been validated for a Jetson Xavier-NX Dev Kit. An identical setup is needed for a Jetson Nano once you get ssh access to it (i.e. jumping straight to point `4` below). For instructions on how to set up these devices please refer to the "getting started guide" for [Jetson Nano](https://developer.nvidia.com/embedded/learn/get-started-jetson-nano-devkit#intro).
+
+1. **Install JetPack 5.1.2 on your Jetson device**
 
-1. Download the Ubuntu 18.04 image from [NVIDIA-embedded](https://developer.nvidia.com/embedded/downloads), note that you'll need a NVIDIA developer account. This image comes with Docker pre-installed as well as PyTorch+Torchvision compiled with GPU support.
+   - Download the JetPack 5.1.2 image from [NVIDIA-embedded](https://developer.nvidia.com/embedded/jetpack-sdk-512), note that you might need an NVIDIA developer account. You can find the download link under the `SD Card Image Method` section on NVIDIA's site. This image comes with Docker pre-installed as well as PyTorch+Torchvision and TensorFlow compiled with GPU support.
 
-2. Extract the image (~14GB) and flash it onto the uSD card using Etcher (or equivalent).
+   - Extract the image (~18GB and named `sd-blob.img`) and flash it onto the uSD card using [balenaEtcher](https://www.balena.io/etcher/) (or equivalent).
 
-3. Follow [the instructions](https://developer.nvidia.com/embedded/learn/get-started-jetson-xavier-nx-devkit) to setup the device.
+2. **Follow [the instructions](https://developer.nvidia.com/embedded/learn/get-started-jetson-xavier-nx-devkit) to set up the device.** The first time you boot your Xavier-NX you should plug it into a display to complete the installation process. After that, a display is no longer needed for this example but you could still use it instead of connecting to your device over ssh.
 
-4. Installing Docker: Docker comes pre-installed with the Ubuntu image provided by NVIDIA. But for convinience we will create a new user group and add our user to it (with the idea of not having to use `sudo` for every command involving docker (e.g. `docker run`, `docker ps`, etc)). More details about what this entails can be found in the [Docker documentation](https://docs.docker.com/engine/install/linux-postinstall/). You can achieve this by doing:
+3. **Set up Docker**: Docker comes pre-installed with the Ubuntu image provided by NVIDIA. But for convenience, we will create a new user group and add our user to it (with the idea of not having to use `sudo` for every command involving docker (e.g. `docker run`, `docker ps`, etc)). More details about what this entails can be found in the [Docker documentation](https://docs.docker.com/engine/install/linux-postinstall/). You can achieve this by doing:
 
    ```bash
-   $ sudo usermod -aG docker $USER
+   sudo usermod -aG docker $USER
    # apply changes to current shell (or logout/reboot)
-   $ newgrp docker
+   newgrp docker
   ```
 
-5. The minimal installation to run this example only requires an additional package, `git`, in order to clone this repo. Install `git` by:
+4. **Update the OS and install utilities.** First, bring your OS up to date:
 
   ```bash
-   $ sudo apt-get update && sudo apt-get install git -y
+   sudo apt update && sudo apt upgrade -y
+   # now reboot
+   sudo reboot
   ```
 
-6. (optional) additional packages:
+   Log in again and (optionally) install the following packages:
 
-   - [jtop](https://github.com/rbonghi/jetson_stats), to monitor CPU/GPU utilization, power consumption and, many more.
+   - [jtop](https://github.com/rbonghi/jetson_stats), to monitor CPU/GPU utilization, power consumption, and more. You can read more about it in [this blog post](https://jetsonhacks.com/2023/02/07/jtop-the-ultimate-tool-for-monitoring-nvidia-jetson-devices/).
 
     ```bash
     # First we need to install pip3
-     $ sudo apt-get install python3-pip -y
-     # updated pip3
-     $ sudo pip3 install -U pip
+     sudo apt install python3-pip -y
      # finally, install jtop
-     $ sudo -H pip3 install -U jetson-stats
+     sudo pip3 install -U jetson-stats
+     # now reboot (or run `sudo systemctl restart jtop.service` and login again)
+     sudo reboot
     ```
 
+   Now that you have installed `jtop`, launch it by running the `jtop` command in your terminal. An interactive panel similar to the one shown on the right will show up. `jtop` allows you to monitor and control many features of your Jetson device.
+   Read more in the [jtop documentation](https://rnext.it/jetson_stats/jtop/jtop.html).
 
-   - [TMUX](https://github.com/tmux/tmux/wiki), a terminal multiplexer.
+   - [TMUX](https://github.com/tmux/tmux/wiki), a terminal multiplexer. As its name suggests, it allows you to divide a single terminal window into multiple panes. In this way, you could (for example) use one pane to show your shell and another to show `jtop`. That's precisely what the visualization on the right shows.
 
     ```bash
     # install tmux
-     $ sudo apt-get install tmux -y
+     sudo apt install tmux -y
     # add mouse support
-     $ echo set -g mouse on > ~/.tmux.conf
+     echo set -g mouse on > ~/.tmux.conf
     ```
 
-7. Power modes: The Jetson devices can operate at different power modes, each making use of more or less CPU cores clocked at different freqencies. The right power mode might very much depend on the application and scenario. When power consumption is not a limiting factor, we could use the highest 15W mode using all 6 CPU cores. On the other hand, if the devices are battery-powered we might want to make use of a low power mode using 10W and 2 CPU cores. All the details regarding the different power modes of a Jetson Xavier-NX can be found [here](https://docs.nvidia.com/jetson/l4t/index.html#page/Tegra%2520Linux%2520Driver%2520Package%2520Development%2520Guide%2Fpower_management_jetson_xavier.html%23wwpID0E0NO0HA). For this demo we'll be setting the device to the high performance mode:
+5. **Power modes**. The Jetson devices can operate at different power modes, each making use of more or less CPU cores clocked at different frequencies. The right power mode might very much depend on the application and scenario. When power consumption is not a limiting factor, we could use the highest 15W mode using all 6 CPU cores. On the other hand, if the devices are battery-powered we might want to make use of a low-power mode using 10W and 2 CPU cores. All the details regarding the different power modes of a Jetson Xavier-NX can be found [here](https://docs.nvidia.com/jetson/l4t/index.html#page/Tegra%2520Linux%2520Driver%2520Package%2520Development%2520Guide%2Fpower_management_jetson_xavier.html%23wwpID0E0NO0HA). For this demo, we'll be setting the device to high-performance mode:
 
   ```bash
-   $ sudo /usr/sbin/nvpmodel -m 2 # 15W with 6cpus @ 1.4GHz
+   sudo /usr/sbin/nvpmodel -m 2 # 15W with 6cpus @ 1.4GHz
   ```
 
-## Setting up a Raspberry Pi (3B+ or 4B)
-
-1. Install Ubuntu server 20.04 LTS 64-bit for Rapsberry Pi. You can do this by using one of the images provided [by Ubuntu](https://ubuntu.com/download/raspberry-pi) and then use Etcher. Alternativelly, astep-by-step installation guide, showing how to download and flash the image onto a uSD card and, go throught the first boot process, can be found [here](https://ubuntu.com/tutorials/how-to-install-ubuntu-on-your-raspberry-pi#1-overview). Please note that the first time you boot your RPi it will automatically update the system (which will lock `sudo` and prevent running the commands below for a few minutes)
+
+   Jetson Stats (which you launch via `jtop`) also allows you to see and set the power mode on your device. Navigate to the `CTRL` panel and click on one of the `NVM modes` available.
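To double-check which power mode is active before training, `nvpmodel` can also be queried (standard JetPack tooling; the mode name shown is just an example):

```bash
sudo /usr/sbin/nvpmodel -q   # e.g. prints "NV Power Mode: MODE_15W_6CORE"
```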
-2. Install docker (+ post-installation steps as in [Docker Docs](https://docs.docker.com/engine/install/linux-postinstall/)):
+6. **Build the base client image**. Before running a Flower client, we need to install Flower and other ML dependencies (i.e. PyTorch or TensorFlow). Instead of installing these manually via `pip3 install ...`, let's use the pre-built Docker images provided by NVIDIA. In this way, we can be confident that the ML infrastructure is optimized for these devices. Build your Flower client image with:
 
   ```bash
-   # make sure your OS is up-to-date
-   $ sudo apt-get update
+   # On your Jetson's terminal run
+   ./build_jetson_flower_client.sh --pytorch # or --tensorflow
+   # Bear in mind this might take a few minutes since the base images need to be downloaded (~7GB) and decompressed.
+   # Pass the additional flag `--no-cache` to the script above to rebuild the image.
+   ```
 
-   # get the installation script
-   $ curl -fsSL https://get.docker.com -o get-docker.sh
+   Once the script finishes, verify the `flower_client` Docker image is present. If you type `docker images`, you'll see something like the following:
 
-   # install docker
-   $ sudo sh get-docker.sh
+   ```bash
+   REPOSITORY      TAG       IMAGE ID       CREATED          SIZE
+   flower_client   latest    87e935a8ee37   18 seconds ago   12.6GB
+   ```
 
-   # add your user to the docker group
-   $ sudo usermod -aG docker $USER
+7. **Access your client image**. Before launching the Flower client, we need to run the image we just created. To keep things simple, let's run the image in interactive mode (`-it`), mount the entire repository you cloned inside the `/client` directory of your container (`` -v `pwd`:/client ``), and use the NVIDIA runtime so we can access the GPU (`--runtime nvidia`):
 
-   # apply changes to current shell (or logout/reboot)
-   $ newgrp docker
+   ```bash
+   # first ensure you are in the `embedded-devices` directory. If you are not, use the `cd` command to navigate to it
+
+   # run the client container (this won't launch your Flower client, it will just "take you inside docker". The client can be run following the steps in the next section of the readme)
+   docker run -it --rm --runtime nvidia -v `pwd`:/client flower_client
+   # this will take you to a shell that looks something like this:
+   root@6e6ce826b8bb:/client#
   ```
 
-. (optional) additional packages: you could install `TMUX` (see point `6` above) and `htop` as a replacement for `jtop` (which is only available for Jetson devices). Htop can be installed via: `sudo apt-get install htop -y`.
+8. **Run your FL experiments with Flower**. Follow the steps in the section below.
 
-## Running FL training with Flower
+## Running Embedded FL with Flower
 
-For this demo we'll be using [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html), a popular dataset for image classification comprised of 10 classes (e.g. car, bird, airplane) and a total of 60K `32x32` RGB images. The training set contains 50K images. The server will automatically download the dataset should it not be found in `./data`. To keep the client side simple, the datasets will be downloaded when building the docker image. This will happen as the first stage in both `run_pi.sh` and `run_jetson.sh`.
+For this demo, we'll be using [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html), a popular dataset for image classification comprised of 10 classes (e.g. car, bird, airplane) and a total of 60K `32x32` RGB images. The training set contains 50K images. The server will automatically download the dataset should it not be found in `./data`. The clients do the same. The dataset is by default split into 50 partitions (each to be assigned to a different client). This can be controlled with the `NUM_CLIENTS` global variable in the client scripts. In this example, each device will play the role of a specific user (specified via `--cid` -- we'll show this later) and therefore only do local training with that portion of the data. For CIFAR-10, clients will be training a MobileNet-v2/3 model.
 
-> If you'd like to make use of your own dataset you could [mount it](https://docs.docker.com/storage/volumes/) to the client docker container when calling `docker run`. We leave this an other more advanced topics for a future example.
+You can run this example using MNIST and a smaller CNN model by passing the flag `--mnist`. This is useful if you are using devices with a very limited amount of memory (e.g. Raspberry Pi Zero) or if you want the training on the embedded devices to be much faster (especially if they are CPU-only). The partitioning of the dataset is done in the same way.
 
-### Server
+### Start your Flower Server
 
-Launch the server and define the model you'd like to train. The current code (see `utils.py`) provides two models for CIFAR-10: a small CNN (more suitable for Raspberry Pi) and, a ResNet18, which will run well on the gpu. Each model can be specified using the `--model` flag with options `Net` or `ResNet18`. Launch a FL training setup with one client and doing three rounds as:
+On the machine of your choice, launch the server:
 
 ```bash
-# launch your server. It will be waiting until one client connects
-$ python server.py --server_address --rounds 3 --min_num_clients 1 --min_sample_size 1 --model ResNet18
+# Launch your server.
+# Will wait for at least 2 clients to be connected, then will train for 3 FL rounds
+# The command below will sample all clients connected (since sample_fraction=1.0)
+python server.py --rounds 3 --min_num_clients 2 --sample_fraction 1.0 # append `--mnist` if you want to use that dataset/model setting
 ```
 
-### Clients
+> If you are on macOS with Apple Silicon (i.e. M1, M2 chips), you might encounter a `grpcio`-related issue when launching your server. If you are in a conda environment you can solve this easily by doing: `pip uninstall grpcio` and then `conda install grpcio`.
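The workaround mentioned in the note above, spelled out (assumes a conda environment):

```bash
pip uninstall grpcio
conda install grpcio
```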
-Asuming you have cloned this repo onto the device/s, then execute the appropiate script to run the docker image, connect with the server and proceed with the training. Note that you can use both a Jetson and a RPi simultaneously, just make sure you modify the script above when launching the server so it waits until 2 clients are online.
+### Start the Flower Clients
 
-#### For Jetson
+It's time to launch your clients! Ensure you have followed the setup stages outlined above for the devices at your disposal.
 
-```bash
-$ ./run_jetson.sh --server_address= --cid=0 --model=ResNet18
-```
-
-#### For Raspberry Pi
+The first time you run this, the dataset will be downloaded. From the commands below, replace `` with either `pytorch` or `tf` to run the corresponding client Python file. In an FL setting, each client has its own unique dataset. In this example you can simulate this by manually assigning an ID to a client (`cid`), which should be an integer in `[0, NUM_CLIENTS-1]`, where `NUM_CLIENTS` is the total number of partitions or clients that could participate at any point. This is defined at the top of the client files and defaults to `50`. You can change this value to make each partition larger or smaller.
 
-Depending on the model of RapsberryPi you have, running the smaller `Net` model might be the only option due to the higher RAM budget needed for ResNet18. It should be fine for a RaspberryPi 4 with 4GB of RAM to run a RestNet18 (with an appropiate batch size) but bear in mind that each batch might take several second to complete. The following would run the smaller `Net` model:
+Launch your Flower clients as follows. Remember that if you are using a Jetson device, you first need to run your Docker container (see the last steps of the Jetson setup). If you are using Raspberry Pi Zero devices, it is normal if starting the clients takes a few seconds.
 
 ```bash
-# note that pulling the base image, extracting the content might take a while (specially on a RPi 3) the first time you run this.
-$ ./run_pi.sh --server_address= --cid=0 --model=Net
+# Run the default example (CIFAR-10)
+python3 client_.py --cid= --server_address=
+
+# Use MNIST (and a smaller model) if your devices require a more lightweight workload
+python3 client_.py --cid= --server_address= --mnist
 ```
+
+Repeat the above for as many devices as you have. Pass a different `CLIENT_ID` to each device. You can naturally run this example using different types of devices (e.g. RPi, RPi Zero, Jetson) at the same time as long as they are training the same model. If you want to start more clients than the number of embedded devices you currently have access to, you can launch clients on your laptop: simply open a new terminal and run one of the `python3 client_.py ...` commands above.
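Before involving real devices, the whole loop can be sanity-checked on a single machine. A hedged sketch with a hypothetical address and client IDs, using the lightweight MNIST setting:

```bash
# Terminal 1: the server, waiting for two clients
python server.py --rounds 3 --min_num_clients 2 --sample_fraction 1.0 --mnist

# Terminals 2 and 3: two PyTorch clients with different partition IDs
python3 client_pytorch.py --cid=0 --server_address=127.0.0.1:8080 --mnist
python3 client_pytorch.py --cid=1 --server_address=127.0.0.1:8080 --mnist
```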
diff --git a/examples/embedded-devices/_static/diagram.png b/examples/embedded-devices/_static/diagram.png
index 66d8855c859f..7eaa85fb24c6 100644
Binary files a/examples/embedded-devices/_static/diagram.png and b/examples/embedded-devices/_static/diagram.png differ
diff --git a/examples/embedded-devices/_static/rpi_imager.png b/examples/embedded-devices/_static/rpi_imager.png
new file mode 100644
index 000000000000..a59a3137334e
Binary files /dev/null and b/examples/embedded-devices/_static/rpi_imager.png differ
diff --git a/examples/embedded-devices/base_images/README.md b/examples/embedded-devices/base_images/README.md
deleted file mode 100644
index b24608a466f9..000000000000
--- a/examples/embedded-devices/base_images/README.md
+++ /dev/null
@@ -1,11 +0,0 @@
-## Building the base Docker images
-
-We provide the base images used in the Dockerfile of the parent directory (i.e. `jafermarq/jetsonfederated_cpu` and `jafermarq/jetsonfederated_gpu`). To make the process of running the demo as seamsless as possible (i.e. without long Docker build times) we have pre-built these images and uploaded them to dockerhub. In that way, the Dockerfile in the parent directory only requires adding a couple of python scripts to the image. If you want to build these images by yourself, you can do so by running the `build.sh` script in each directory. Note that building these images might take around one hour, depending on your system's specs.
-
-These images target a `aarch64` machine (e.g. RPi) but you'd probably will be building these images on a `x86_64` machine. To achieve this you'll need `qemu`. You should enable this before building the images by doing:
-
-```bash
-$ docker run --rm --privileged multiarch/qemu-user-static --reset -p yes
-```
-
-More details can be found in the [`qemu-user-static`](https://github.com/multiarch/qemu-user-static) repository.
diff --git a/examples/embedded-devices/base_images/cpu/Dockerfile b/examples/embedded-devices/base_images/cpu/Dockerfile
deleted file mode 100644
index 2f68406405a1..000000000000
--- a/examples/embedded-devices/base_images/cpu/Dockerfile
+++ /dev/null
@@ -1,43 +0,0 @@
-# From an ubuntu18 image bult for Arm, we build from source Pytorch and Torchvision
-# Then we install Flower
-# We also install dependencies for a very exciting future project
-
-FROM arm64v8/ubuntu:18.04
-
-# basics
-RUN apt-get update
-RUN apt-get install libopenblas-dev libopenmpi-dev python3-pip cmake -y
-
-# update pip
-RUN python3 -m pip install --upgrade pip
-
-RUN pip3 install Cython numpy
-
-RUN mkdir /app
-WORKDIR /app
-
-## Installing Pytorch + Torchvision
-RUN mkdir build
-WORKDIR build
-RUN apt-get install git bzip2 -y
-RUN pip3 install scikit-build ninja
-
-# PyTorch
-RUN git clone https://github.com/pytorch/pytorch.git
-WORKDIR pytorch
-RUN git checkout v1.6.0 && git submodule update --init --recursive
-ENV USE_NCCL=0 USE_QNNPACK=0 USE_PYTORCH_QNNPACK=0
-RUN pip3 install -r requirements.txt
-RUN python3 setup.py install
-
-# torchvision
-WORKDIR /app/build
-RUN git clone https://github.com/pytorch/vision.git
-# checkout v0.7.0 (the one compatible with PyTorch 1.6)
-WORKDIR vision
-RUN git checkout v0.7.0 && git submodule update --recursive
-RUN apt-get install libavcodec-dev libavformat-dev libswscale-dev libjpeg8-dev zlib1g-dev -y
-RUN python3 setup.py install
-
-WORKDIR /app
-RUN echo "done!"
diff --git a/examples/embedded-devices/base_images/cpu/build.sh b/examples/embedded-devices/base_images/cpu/build.sh
deleted file mode 100755
index be44de10de62..000000000000
--- a/examples/embedded-devices/base_images/cpu/build.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-#!/bin/bash
-
-if [ -z "${CI}" ]; then
-  BUILDKIT=1
-else
-  BUILDKIT=0
-fi
-
-DOCKER_BUILDKIT=${BUILDKIT} docker build $@ . -t jetsonfederated_cpu:latest
diff --git a/examples/embedded-devices/base_images/gpu/Dockerfile b/examples/embedded-devices/base_images/gpu/Dockerfile
deleted file mode 100644
index 50c0b8a895fd..000000000000
--- a/examples/embedded-devices/base_images/gpu/Dockerfile
+++ /dev/null
@@ -1,10 +0,0 @@
-# Uses a pre-built image from nvidia, then we install Flower
-# We also install dependencies for a very exciting future project
-
-# check this repo to learn more about the image below: https://github.com/dusty-nv/jetson-containers
-FROM nvcr.io/nvidia/l4t-pytorch:r32.5.0-pth1.6-py3
-
-RUN mkdir /app
-WORKDIR /app
-
-RUN echo "done!"
diff --git a/examples/embedded-devices/base_images/gpu/build.sh b/examples/embedded-devices/base_images/gpu/build.sh
deleted file mode 100755
index 7b0e91d4ad0b..000000000000
--- a/examples/embedded-devices/base_images/gpu/build.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-#!/bin/bash
-
-if [ -z "${CI}" ]; then
-  BUILDKIT=1
-else
-  BUILDKIT=0
-fi
-
-DOCKER_BUILDKIT=${BUILDKIT} docker build $@ . -t jetsonfederated_gpu:latest
diff --git a/examples/embedded-devices/build_image.sh b/examples/embedded-devices/build_image.sh
deleted file mode 100755
index ad19dd3a0e23..000000000000
--- a/examples/embedded-devices/build_image.sh
+++ /dev/null
@@ -1,12 +0,0 @@
-#!/bin/bash
-
-if [ -z "${CI}" ]; then
-  BUILDKIT=1
-else
-  BUILDKIT=0
-fi
-
-# TODO: should we do a `docker pull` here ?
-
-DOCKER_BUILDKIT=${BUILDKIT} docker build $@ . -t flower_client:latest
-
diff --git a/examples/embedded-devices/build_jetson_flower_client.sh b/examples/embedded-devices/build_jetson_flower_client.sh
new file mode 100755
index 000000000000..32725a58f1f7
--- /dev/null
+++ b/examples/embedded-devices/build_jetson_flower_client.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+if [ -z "${CI}" ]; then
+  BUILDKIT=1
+else
+  BUILDKIT=0
+fi
+
+# This script builds a docker image that's ready to run your flower client.
+# Depending on your choice of ML framework (TF or PyTorch), the appropriate
+# base image from NVIDIA will be pulled. This ensures you get the best
+# performance out of your Jetson device.
+
+BASE_PYTORCH=nvcr.io/nvidia/l4t-pytorch:r35.1.0-pth1.13-py3
+BASE_TF=nvcr.io/nvidia/l4t-tensorflow:r35.3.1-tf2.11-py3
+EXTRA=""
+
+while [[ $# -gt 0 ]]; do
+  case $1 in
+    -p|--pytorch)
+      BASE_IMAGE=$BASE_PYTORCH
+      shift
+      ;;
+    -t|--tensorflow)
+      BASE_IMAGE=$BASE_TF
+      shift
+      ;;
+    -r|--no-cache)
+      EXTRA="--no-cache"
+      shift
+      ;;
+    -*|--*)
+      echo "Unknown option $1 (pass either --pytorch or --tensorflow)"
+      exit 1
+      ;;
+  esac
+done
+
+DOCKER_BUILDKIT=${BUILDKIT} docker build $EXTRA \
+  --build-arg BASE_IMAGE=$BASE_IMAGE \
+  . \
+  -t flower_client:latest
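Typical invocations of the new script; the flags map one-to-one onto the `case` statement above:

```bash
./build_jetson_flower_client.sh --pytorch                 # L4T PyTorch base image
./build_jetson_flower_client.sh --tensorflow --no-cache   # L4T TF base, full rebuild
```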
diff --git a/examples/embedded-devices/client.py b/examples/embedded-devices/client.py
deleted file mode 100644
index c0b7d6989d7f..000000000000
--- a/examples/embedded-devices/client.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# Copyright 2020 Adap GmbH. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Flower client example using PyTorch for CIFAR-10 image classification."""
-
-
-import argparse
-import timeit
-from collections import OrderedDict
-from importlib import import_module
-
-import flwr as fl
-import numpy as np
-import torch
-import torchvision
-from flwr.common import (
-    EvaluateIns,
-    EvaluateRes,
-    FitIns,
-    FitRes,
-    ParametersRes,
-    NDArrays,
-)
-
-import utils
-
-# pylint: disable=no-member
-DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-# pylint: enable=no-member
-
-
-def get_weights(model: torch.nn.ModuleList) -> fl.common.NDArrays:
-    """Get model weights as a list of NumPy ndarrays."""
-    return [val.cpu().numpy() for _, val in model.state_dict().items()]
-
-
-def set_weights(model: torch.nn.ModuleList, weights: fl.common.NDArrays) -> None:
-    """Set model weights from a list of NumPy ndarrays."""
-    state_dict = OrderedDict(
-        {
-            k: torch.tensor(np.atleast_1d(v))
-            for k, v in zip(model.state_dict().keys(), weights)
-        }
-    )
-    model.load_state_dict(state_dict, strict=True)
-
-
-class CifarClient(fl.client.Client):
-    """Flower client implementing CIFAR-10 image classification using PyTorch."""
-
-    def __init__(
-        self,
-        cid: str,
-        model: torch.nn.Module,
-        trainset: torchvision.datasets.CIFAR10,
-        testset: torchvision.datasets.CIFAR10,
-    ) -> None:
-        self.cid = cid
-        self.model = model
-        self.trainset = trainset
-        self.testset = testset
-
-    def get_parameters(self, config) -> ParametersRes:
-        print(f"Client {self.cid}: get_parameters")
-
-        weights: NDArrays = get_weights(self.model)
-        parameters = fl.common.ndarrays_to_parameters(weights)
-        return ParametersRes(parameters=parameters)
-
-    def _instantiate_model(self, model_str: str):
-        # will load utils.model_str
-        m = getattr(import_module("utils"), model_str)
-        # instantiate model
-        self.model = m()
-
-    def fit(self, ins: FitIns) -> FitRes:
-        print(f"Client {self.cid}: fit")
-
-        weights: NDArrays = fl.common.parameters_to_ndarrays(ins.parameters)
-        config = ins.config
-        fit_begin = timeit.default_timer()
-
-        # Get training config
-        epochs = int(config["epochs"])
-        batch_size = int(config["batch_size"])
-        pin_memory = bool(config["pin_memory"])
-        num_workers = int(config["num_workers"])
-
-        # Set model parameters
-        set_weights(self.model, weights)
-
-        if torch.cuda.is_available():
-            kwargs = {
-                "num_workers": num_workers,
-                "pin_memory": pin_memory,
-                "drop_last": True,
-            }
-        else:
-            kwargs = {"drop_last": True}
-
-        # Train model
-        trainloader = torch.utils.data.DataLoader(
-            self.trainset, batch_size=batch_size, shuffle=True, **kwargs
-        )
-        utils.train(self.model, trainloader, epochs=epochs, device=DEVICE)
-
-        # Return the refined weights and the number of examples used for training
-        weights_prime: NDArrays = get_weights(self.model)
-        params_prime = fl.common.ndarrays_to_parameters(weights_prime)
-        num_examples_train = len(self.trainset)
-        metrics = {"duration": timeit.default_timer() - fit_begin}
-        return FitRes(
-            parameters=params_prime, num_examples=num_examples_train, metrics=metrics
-        )
-
-    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
-        print(f"Client {self.cid}: evaluate")
-
-        weights = fl.common.parameters_to_ndarrays(ins.parameters)
-
-        # Use provided weights to update the local model
-        set_weights(self.model, weights)
-
-        # Evaluate the updated model on the local dataset
-        testloader = torch.utils.data.DataLoader(
-            self.testset, batch_size=32, shuffle=False
-        )
-        loss, accuracy = utils.test(self.model, testloader, device=DEVICE)
-
-        # Return the number of evaluation examples and the evaluation result (loss)
-        metrics = {"accuracy": float(accuracy)}
-        return EvaluateRes(
-            loss=float(loss), num_examples=len(self.testset), metrics=metrics
-        )
-
-
-def main() -> None:
-    """Load data, create and start CifarClient."""
-    parser = argparse.ArgumentParser(description="Flower")
-    parser.add_argument(
-        "--server_address",
-        type=str,
-        required=True,
-        help=f"gRPC server address",
-    )
-    parser.add_argument(
-        "--cid", type=str, required=True, help="Client CID (no default)"
-    )
-    parser.add_argument(
-        "--log_host",
-        type=str,
-        help="Logserver address (no default)",
-    )
-    parser.add_argument(
-        "--data_dir",
-        type=str,
-        help="Directory where the dataset lives",
-    )
-    parser.add_argument(
-        "--model",
-        type=str,
-        default="ResNet18",
-        choices=["Net", "ResNet18"],
-        help="model to train",
-    )
-    args = parser.parse_args()
-
-    # Configure logger
-    fl.common.logger.configure(f"client_{args.cid}", host=args.log_host)
-
-    # model
-    model = utils.load_model(args.model)
-    model.to(DEVICE)
-    # load (local, on-device) dataset
-    trainset, testset = utils.load_cifar()
-
-    # Start client
-    client = CifarClient(args.cid, model, trainset, testset)
-    fl.client.start_client(server_address=args.server_address, client=client)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/embedded-devices/client_pytorch.py b/examples/embedded-devices/client_pytorch.py
new file mode 100644
index 000000000000..5d236c9e9389
--- /dev/null
+++ b/examples/embedded-devices/client_pytorch.py
@@ -0,0 +1,216 @@
+import argparse
+import warnings
+from collections import OrderedDict
+
+import flwr as fl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, random_split
+from torchvision.datasets import CIFAR10, MNIST
+from torchvision.transforms import Compose, Normalize, ToTensor
+from torchvision.models import mobilenet_v3_small
+from tqdm import tqdm
+
+parser = argparse.ArgumentParser(description="Flower Embedded devices")
+parser.add_argument(
+    "--server_address",
+    type=str,
+    default="0.0.0.0:8080",
+    help="gRPC server address (default '0.0.0.0:8080')",
+)
+parser.add_argument(
+    "--cid",
+    type=int,
+    required=True,
+    help="Client id. Should be an integer between 0 and NUM_CLIENTS",
+)
+parser.add_argument(
+    "--mnist",
+    action="store_true",
+    help="If you use Raspberry Pi Zero clients (which just have 512MB of RAM) use MNIST",
+)
+
+
+warnings.filterwarnings("ignore", category=UserWarning)
+NUM_CLIENTS = 50
+
+# a config for mobilenetv2 that works for
+# small input sizes (i.e. 32x32 as in CIFAR)
+mb2_cfg = [
+    (1, 16, 1, 1),
+    (6, 24, 2, 1),
+    (6, 32, 3, 2),
+    (6, 64, 4, 2),
+    (6, 96, 3, 1),
+    (6, 160, 3, 2),
+    (6, 320, 1, 1),
+]
+
+
+class Net(nn.Module):
+    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')."""
+
+    def __init__(self) -> None:
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 4 * 4, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 4 * 4)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+def train(net, trainloader, optimizer, epochs, device):
+    """Train the model on the training set."""
+    criterion = torch.nn.CrossEntropyLoss()
+    for _ in range(epochs):
+        for images, labels in tqdm(trainloader):
+            optimizer.zero_grad()
+            criterion(net(images.to(device)), labels.to(device)).backward()
+            optimizer.step()
+
+
+def test(net, testloader, device):
+    """Validate the model on the test set."""
+    criterion = torch.nn.CrossEntropyLoss()
+    correct, loss = 0, 0.0
+    with torch.no_grad():
+        for images, labels in tqdm(testloader):
+            outputs = net(images.to(device))
+            labels = labels.to(device)
+            loss += criterion(outputs, labels).item()
+            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+    accuracy = correct / len(testloader.dataset)
+    return loss, accuracy
+
+
+def prepare_dataset(use_mnist: bool):
+    """Get MNIST/CIFAR-10 and return client partitions and global testset."""
+    dataset = MNIST if use_mnist else CIFAR10
+    if use_mnist:
+        norm = Normalize((0.1307,), (0.3081,))
+    else:
+        norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    trf = Compose([ToTensor(), norm])
+    trainset = dataset("./data", train=True, download=True, transform=trf)
+    testset = dataset("./data", train=False, download=True, transform=trf)
+
+    print("Partitioning dataset (IID)...")
+
+    # Split trainset into `num_partitions` trainsets
+    num_images = len(trainset) // NUM_CLIENTS
+    partition_len = [num_images] * NUM_CLIENTS
+
+    trainsets = random_split(
+        trainset, partition_len, torch.Generator().manual_seed(2023)
+    )
+
+    val_ratio = 0.1
+
+    # Create dataloaders with train+val support
+    train_partitions = []
+    val_partitions = []
+    for trainset_ in trainsets:
+        num_total = len(trainset_)
+        num_val = int(val_ratio * num_total)
+        num_train = num_total - num_val
+
+        for_train, for_val = random_split(
+            trainset_, [num_train, num_val], torch.Generator().manual_seed(2023)
+        )
+
+        train_partitions.append(for_train)
+        val_partitions.append(for_val)
+
+    return train_partitions, val_partitions, testset
+
+
+# Flower client, adapted from Pytorch quickstart/simulation example
+class FlowerClient(fl.client.NumPyClient):
+    """A FlowerClient that trains a MobileNetV3 model for CIFAR-10 or a much smaller CNN
+    for MNIST."""
+
+    def __init__(self, trainset, valset, use_mnist):
+        self.trainset = trainset
+        self.valset = valset
+        # Instantiate model
+        if use_mnist:
+            self.model = Net()
+        else:
+            self.model = mobilenet_v3_small(num_classes=10)
+            # let's not reduce spatial resolution too early
+            self.model.features[0][0].stride = (1, 1)
+        # Determine device
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)  # send model to device
+
+    def set_parameters(self, params):
+        """Set model weights from a list of NumPy ndarrays."""
list of NumPy ndarrays.""" + params_dict = zip(self.model.state_dict().keys(), params) + state_dict = OrderedDict( + { + k: torch.Tensor(v) if v.shape != torch.Size([]) else torch.Tensor([0]) + for k, v in params_dict + } + ) + self.model.load_state_dict(state_dict, strict=True) + + def get_parameters(self, config): + return [val.cpu().numpy() for _, val in self.model.state_dict().items()] + + def fit(self, parameters, config): + print("Client sampled for fit()") + self.set_parameters(parameters) + # Read hyperparameters from config set by the server + batch, epochs = config["batch_size"], config["epochs"] + # Construct dataloader + trainloader = DataLoader(self.trainset, batch_size=batch, shuffle=True) + # Define optimizer + optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) + # Train + train(self.model, trainloader, optimizer, epochs=epochs, device=self.device) + # Return local model and statistics + return self.get_parameters({}), len(trainloader.dataset), {} + + def evaluate(self, parameters, config): + print("Client sampled for evaluate()") + self.set_parameters(parameters) + # Construct dataloader + valloader = DataLoader(self.valset, batch_size=64) + # Evaluate + loss, accuracy = test(self.model, valloader, device=self.device) + # Return statistics + return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)} + + +def main(): + args = parser.parse_args() + print(args) + + assert args.cid < NUM_CLIENTS + + use_mnist = args.mnist + # Download CIFAR-10 dataset and partition it + trainsets, valsets, _ = prepare_dataset(use_mnist) + + # Start Flower client setting its associated data partition + fl.client.start_numpy_client( + server_address=args.server_address, + client=FlowerClient( + trainset=trainsets[args.cid], valset=valsets[args.cid], use_mnist=use_mnist + ), + ) + + +if __name__ == "__main__": + main() diff --git a/examples/embedded-devices/client_tf.py b/examples/embedded-devices/client_tf.py new file mode 100644 index 000000000000..3457af1c7a66 --- /dev/null +++ b/examples/embedded-devices/client_tf.py @@ -0,0 +1,133 @@ +import math +import argparse +import warnings + +import flwr as fl +import tensorflow as tf +from tensorflow import keras as keras + +parser = argparse.ArgumentParser(description="Flower Embedded devices") +parser.add_argument( + "--server_address", + type=str, + default="0.0.0.0:8080", + help=f"gRPC server address (deafault '0.0.0.0:8080')", +) +parser.add_argument( + "--cid", + type=int, + required=True, + help="Client id. 
Should be an integer between 0 and NUM_CLIENTS", +) +parser.add_argument( + "--mnist", + action="store_true", + help="If you use Raspberry Pi Zero clients (which just have 512MB of RAM) use MNIST", +) + +warnings.filterwarnings("ignore", category=UserWarning) +NUM_CLIENTS = 50 + + +def prepare_dataset(use_mnist: bool): + """Download and partition the CIFAR-10/MNIST dataset.""" + if use_mnist: + (x_train, y_train), testset = tf.keras.datasets.mnist.load_data() + else: + (x_train, y_train), testset = tf.keras.datasets.cifar10.load_data() + partitions = [] + # We keep all partitions equal-sized in this example + partition_size = math.floor(len(x_train) / NUM_CLIENTS) + for cid in range(NUM_CLIENTS): + # Split dataset into non-overlapping NUM_CLIENTS partitions + idx_from, idx_to = int(cid) * partition_size, (int(cid) + 1) * partition_size + + x_train_cid, y_train_cid = ( + x_train[idx_from:idx_to] / 255.0, + y_train[idx_from:idx_to], + ) + + # Now partition into train/validation + # Use 10% of the client's training data for validation + split_idx = math.floor(len(x_train_cid) * 0.9) + + client_train = (x_train_cid[:split_idx], y_train_cid[:split_idx]) + client_val = (x_train_cid[split_idx:], y_train_cid[split_idx:]) + partitions.append((client_train, client_val)) + + return partitions, testset + + +class FlowerClient(fl.client.NumPyClient): + """A FlowerClient that uses MobileNetV3 for CIFAR-10 or a much smaller CNN for + MNIST.""" + + def __init__(self, trainset, valset, use_mnist: bool): + self.x_train, self.y_train = trainset + self.x_val, self.y_val = valset + # Instantiate model + if use_mnist: + # small model for MNIST + self.model = keras.Sequential( + [ + keras.Input(shape=(28, 28, 1)), + keras.layers.Conv2D(32, kernel_size=(5, 5), activation="relu"), + keras.layers.MaxPooling2D(pool_size=(2, 2)), + keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), + keras.layers.MaxPooling2D(pool_size=(2, 2)), + keras.layers.Flatten(), + keras.layers.Dropout(0.5), + keras.layers.Dense(10, activation="softmax"), + ] + ) + else: + # let's use a larger model for CIFAR-10 + self.model = tf.keras.applications.MobileNetV3Small( + (32, 32, 3), classes=10, weights=None + ) + self.model.compile( + "adam", "sparse_categorical_crossentropy", metrics=["accuracy"] + ) + + def get_parameters(self, config): + return self.model.get_weights() + + def set_parameters(self, params): + self.model.set_weights(params) + + def fit(self, parameters, config): + print("Client sampled for fit()") + self.set_parameters(parameters) + # Set hyperparameters from config sent by server/strategy + batch, epochs = config["batch_size"], config["epochs"] + # train + self.model.fit(self.x_train, self.y_train, epochs=epochs, batch_size=batch) + return self.get_parameters({}), len(self.x_train), {} + + def evaluate(self, parameters, config): + print("Client sampled for evaluate()") + self.set_parameters(parameters) + loss, accuracy = self.model.evaluate(self.x_val, self.y_val) + return loss, len(self.x_val), {"accuracy": accuracy} + + +def main(): + args = parser.parse_args() + print(args) + + assert args.cid < NUM_CLIENTS + + use_mnist = args.mnist + # Download the dataset (MNIST or CIFAR-10) and partition it + partitions, _ = prepare_dataset(use_mnist) + trainset, valset = partitions[args.cid] + + # Start Flower client setting its associated data partition + fl.client.start_numpy_client( + server_address=args.server_address, + client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist), + ) + + +if __name__ == 
"__main__": + main() diff --git a/examples/embedded-devices/requirements.txt b/examples/embedded-devices/requirements.txt deleted file mode 100644 index cdb29230ffeb..000000000000 --- a/examples/embedded-devices/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -flwr==1.3.0 -numpy==1.24.2 -torch==2.0.0 -torchvision==0.15.1 diff --git a/examples/embedded-devices/requirements_pytorch.txt b/examples/embedded-devices/requirements_pytorch.txt new file mode 100644 index 000000000000..797ca6db6244 --- /dev/null +++ b/examples/embedded-devices/requirements_pytorch.txt @@ -0,0 +1,4 @@ +flwr>=1.0, <2.0 +torch==1.13.1 +torchvision==0.14.1 +tqdm==4.65.0 diff --git a/examples/embedded-devices/requirements_tf.txt b/examples/embedded-devices/requirements_tf.txt new file mode 100644 index 000000000000..c7068d40b9c2 --- /dev/null +++ b/examples/embedded-devices/requirements_tf.txt @@ -0,0 +1,2 @@ +flwr>=1.0, <2.0 +tensorflow >=2.9.1, != 2.11.1 diff --git a/examples/embedded-devices/run_jetson.sh b/examples/embedded-devices/run_jetson.sh deleted file mode 100755 index 4939506081aa..000000000000 --- a/examples/embedded-devices/run_jetson.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -# Copyright 2020 Adap GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -# run as: ./run_jetson.sh --server_address= --cid=0 --model=ResNet18 - -echo "ARGS: ${@}" - -./build_image.sh --build-arg BASE_IMAGE_TYPE=gpu - -docker run --runtime nvidia --rm flower_client ${@} diff --git a/examples/embedded-devices/run_pi.sh b/examples/embedded-devices/run_pi.sh deleted file mode 100755 index 9a2dae23bf46..000000000000 --- a/examples/embedded-devices/run_pi.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -# Copyright 2020 Adap GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -# run as: ./run_pi.sh --server_address= --cid=0 --model=Net - -echo "ARGS: ${@}" - -./build_image.sh --build-arg BASE_IMAGE_TYPE=cpu - -docker run --rm flower_client ${@} diff --git a/examples/embedded-devices/server.py b/examples/embedded-devices/server.py index bdf12d6fe640..2a15f792297e 100644 --- a/examples/embedded-devices/server.py +++ b/examples/embedded-devices/server.py @@ -1,47 +1,22 @@ -# Copyright 2020 Adap GmbH. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Minimal example on how to start a simple Flower server.""" - - import argparse -from collections import OrderedDict -from typing import Callable, Dict, Optional, Tuple +from typing import List, Tuple import flwr as fl -import numpy as np -import torch -import torchvision - -import utils +from flwr.common import Metrics -# pylint: disable=no-member -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -# pylint: enable=no-member -parser = argparse.ArgumentParser(description="Flower") +parser = argparse.ArgumentParser(description="Flower Embedded devices") parser.add_argument( "--server_address", type=str, - required=True, - help=f"gRPC server address", + default="0.0.0.0:8080", + help="gRPC server address (default '0.0.0.0:8080')", ) parser.add_argument( "--rounds", type=int, - default=1, - help="Number of rounds of federated learning (default: 1)", + default=5, + help="Number of rounds of federated learning (default: 5)", ) parser.add_argument( "--sample_fraction", @@ -49,12 +24,6 @@ default=1.0, help="Fraction of available clients used for fit/evaluate (default: 1.0)", ) -parser.add_argument( - "--min_sample_size", - type=int, - default=2, - help="Minimum number of clients used for fit/evaluate (default: 2)", -) parser.add_argument( "--min_num_clients", type=int, @@ -62,107 +31,54 @@ help="Minimum number of available clients required for sampling (default: 2)", ) parser.add_argument( - "--log_host", - type=str, - help="Logserver address (no default)", -) -parser.add_argument( - "--model", - type=str, - default="ResNet18", - choices=["Net", "ResNet18"], - help="model to train", -) -parser.add_argument( - "--batch_size", - type=int, - default=32, - help="training batch size", -) -parser.add_argument( - "--num_workers", - type=int, - default=4, - help="number of workers for dataset reading", + "--mnist", + action="store_true", + help="If you use Raspberry Pi Zero clients (which just have 512MB of RAM) use MNIST", ) -parser.add_argument("--pin_memory", action="store_true") -args = parser.parse_args() -def main() -> None: - """Start server and train five rounds.""" +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + """This function averages the `accuracy` metric sent by the clients in an `evaluate` + stage (i.e. 
clients receive the global model and evaluate it on their local + validation sets).""" + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] - print(args) + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +def fit_config(server_round: int): + """Return a configuration with static batch size and (local) epochs.""" + config = { + "epochs": 3, # Number of local epochs done by clients + "batch_size": 16, # Batch size to use by clients during fit() + } + return config - assert ( - args.min_sample_size <= args.min_num_clients - ), f"Num_clients shouldn't be lower than min_sample_size" - # Configure logger - fl.common.logger.configure("server", host=args.log_host) +def main(): + args = parser.parse_args() - # Load evaluation data - _, testset = utils.load_cifar(download=True) + print(args) - # Create client_manager, strategy, and server - client_manager = fl.server.SimpleClientManager() + # Define strategy strategy = fl.server.strategy.FedAvg( fraction_fit=args.sample_fraction, - min_fit_clients=args.min_sample_size, - min_available_clients=args.min_num_clients, - evaluate_fn=get_eval_fn(testset), + fraction_evaluate=args.sample_fraction, + min_fit_clients=args.min_num_clients, on_fit_config_fn=fit_config, + evaluate_metrics_aggregation_fn=weighted_average, ) - server = fl.server.Server(client_manager=client_manager, strategy=strategy) - # Run server + # Start Flower server fl.server.start_server( server_address=args.server_address, - server=server, - config=fl.server.ServerConfig(num_rounds=args.rounds), - ) - - -def fit_config(server_round: int) -> Dict[str, fl.common.Scalar]: - """Return a configuration with static batch size and (local) epochs.""" - config = { - "epoch_global": str(server_round), - "epochs": str(1), - "batch_size": str(args.batch_size), - "num_workers": str(args.num_workers), - "pin_memory": str(args.pin_memory), - } - return config - - -def set_weights(model: torch.nn.ModuleList, weights: fl.common.NDArrays) -> None: - """Set model weights from a list of NumPy ndarrays.""" - state_dict = OrderedDict( - { - k: torch.tensor(np.atleast_1d(v)) - for k, v in zip(model.state_dict().keys(), weights) - } + config=fl.server.ServerConfig(num_rounds=args.rounds), + strategy=strategy, ) - model.load_state_dict(state_dict, strict=True) - - -def get_eval_fn( - testset: torchvision.datasets.CIFAR10, -) -> Callable[[fl.common.NDArrays], Optional[Tuple[float, float]]]: - """Return an evaluation function for centralized evaluation.""" - - def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]: - """Use the entire CIFAR-10 test set for evaluation.""" - - model = utils.load_model(args.model) - set_weights(model, weights) - model.to(DEVICE) - - testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False) - loss, accuracy = utils.test(model, testloader, device=DEVICE) - return loss, {"accuracy": accuracy} - - return evaluate if __name__ == "__main__": diff --git a/examples/embedded-devices/utils.py b/examples/embedded-devices/utils.py deleted file mode 100644 index c0946a758e45..000000000000 --- a/examples/embedded-devices/utils.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright 2020 Adap GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""PyTorch CIFAR-10 image classification. - -The code is generally adapted from 'PyTorch: A 60 Minute Blitz'. Further -explanations are given in the official PyTorch tutorial: - -https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html -""" - - -# mypy: ignore-errors -# pylint: disable=W0223 - - -from collections import OrderedDict -from pathlib import Path -from time import time -from typing import Tuple - -import flwr as fl -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.transforms as transforms -from torch import Tensor -from torchvision import datasets -from torchvision.models import resnet18 - -DATA_ROOT = Path("./data") - - -# pylint: disable=unsubscriptable-object -class Net(nn.Module): - """Simple CNN adapted from 'PyTorch: A 60 Minute Blitz'.""" - - def __init__(self) -> None: - super(Net, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - # pylint: disable=arguments-differ,invalid-name - def forward(self, x: Tensor) -> Tensor: - """Compute forward pass.""" - x = self.pool(F.relu(self.conv1(x))) - x = self.pool(F.relu(self.conv2(x))) - x = x.view(-1, 16 * 5 * 5) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - def get_weights(self) -> fl.common.NDArrays: - """Get model weights as a list of NumPy ndarrays.""" - return [val.cpu().numpy() for _, val in self.state_dict().items()] - - def set_weights(self, weights: fl.common.NDArrays) -> None: - """Set model weights from a list of NumPy ndarrays.""" - state_dict = OrderedDict( - {k: torch.tensor(v) for k, v in zip(self.state_dict().keys(), weights)} - ) - self.load_state_dict(state_dict, strict=True) - - -def ResNet18(): - """Returns a ResNet18 model from TorchVision adapted for CIFAR-10.""" - - model = resnet18(num_classes=10) - - # replace w/ smaller input layer - model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) - nn.init.kaiming_normal_(model.conv1.weight, mode="fan_out", nonlinearity="relu") - # no need for pooling if training for CIFAR-10 - model.maxpool = torch.nn.Identity() - - return model - - -def load_model(model_name: str) -> nn.Module: - if model_name == "Net": - return Net() - elif model_name == "ResNet18": - return ResNet18() - else: - raise NotImplementedError(f"model {model_name} is not implemented") - - -# pylint: disable=unused-argument -def load_cifar(download=False) -> Tuple[datasets.CIFAR10, datasets.CIFAR10]: - """Load CIFAR-10 (training and test set).""" - transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ] - ) - trainset = datasets.CIFAR10( - root=DATA_ROOT / "cifar-10", train=True, download=download, transform=transform - ) - testset = datasets.CIFAR10( - root=DATA_ROOT / "cifar-10", train=False, download=download, transform=transform - ) - 
return trainset, testset - - -def train( - net: Net, - trainloader: torch.utils.data.DataLoader, - epochs: int, - device: torch.device, # pylint: disable=no-member -) -> None: - """Train the network.""" - # Define loss and optimizer - criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - - print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each") - t = time() - # Train the network - for epoch in range(epochs): # loop over the dataset multiple times - running_loss = 0.0 - for i, data in enumerate(trainloader, 0): - images, labels = data[0].to(device), data[1].to(device) - - # zero the parameter gradients - optimizer.zero_grad() - - # forward + backward + optimize - outputs = net(images) - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - - # print statistics - running_loss += loss.item() - if i % 2000 == 1999: # print every 2000 mini-batches - print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 2000)) - running_loss = 0.0 - - print(f"Epoch took: {time() - t:.2f} seconds") - - -def test( - net: Net, - testloader: torch.utils.data.DataLoader, - device: torch.device, # pylint: disable=no-member -) -> Tuple[float, float]: - """Validate the network on the entire test set.""" - criterion = nn.CrossEntropyLoss() - correct, loss = 0, 0.0 - with torch.no_grad(): - for data in testloader: - images, labels = data[0].to(device), data[1].to(device) - outputs = net(images) - loss += criterion(outputs, labels).item() - _, predicted = torch.max(outputs.data, 1) # pylint: disable=no-member - correct += (predicted == labels).sum().item() - accuracy = correct / len(testloader.dataset) - return loss, accuracy diff --git a/examples/ios/pyproject.toml b/examples/ios/pyproject.toml index 531e9253e0d1..c1bdbb815bd5 100644 --- a/examples/ios/pyproject.toml +++ b/examples/ios/pyproject.toml @@ -10,4 +10,4 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = "^1.0.0" +flwr = ">=1.0,<2.0" diff --git a/examples/ios/requirements.txt b/examples/ios/requirements.txt index 9d6b364ee36c..236ca6a487fa 100644 --- a/examples/ios/requirements.txt +++ b/examples/ios/requirements.txt @@ -1,2 +1 @@ -flwr~=1.4.0 -numpy~=1.21.1 +flwr>=1.0, <2.0 diff --git a/examples/mt-pytorch/pyproject.toml b/examples/mt-pytorch/pyproject.toml index f285af016499..4978035495ea 100644 --- a/examples/mt-pytorch/pyproject.toml +++ b/examples/mt-pytorch/pyproject.toml @@ -10,8 +10,7 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr-nightly = { version = "^1.5.0.dev20230629", extras = ["simulation", "rest"] } -# flwr = { path = "../../", develop = true, extras = ["simulation", "rest"] } -torch = "^2.0.1" -torchvision = "^0.15.2" +flwr-nightly = {version = ">=1.0,<2.0", extras = ["rest", "simulation"]} +torch = "1.13.1" +torchvision = "0.14.1" tqdm = "4.65.0" diff --git a/examples/mt-pytorch/requirements.txt b/examples/mt-pytorch/requirements.txt index 98b7617e776d..ae0a65386f2b 100644 --- a/examples/mt-pytorch/requirements.txt +++ b/examples/mt-pytorch/requirements.txt @@ -1,4 +1,4 @@ -flwr-nightly[simulation,rest] +flwr-nightly[rest,simulation]>=1.0, <2.0 torch==1.13.1 -torchvision==0.13.0 +torchvision==0.14.1 tqdm==4.65.0 diff --git a/examples/opacus/pyproject.toml b/examples/opacus/pyproject.toml index 8ee2cc7d10b8..af0eaf596fbf 100644 --- a/examples/opacus/pyproject.toml +++ b/examples/opacus/pyproject.toml @@ -9,9 +9,7 @@ description = "Differentially Private 
Federated Learning with Opacus and Flower" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = "^3.8" -flwr = "^1.0.0" -# flwr = { path = "../../", develop = true } # Development -opacus = "^1.4.0" -torch = "^1.13.1" -torchvision = "^0.14.0" +python = ">=3.8,<3.11" +flwr = ">=1.0,<2.0" +opacus = "1.4.0" +torchvision = "0.15.2" diff --git a/examples/opacus/requirements.txt b/examples/opacus/requirements.txt index e6e5dbb2fdfa..f17b78fbf311 100644 --- a/examples/opacus/requirements.txt +++ b/examples/opacus/requirements.txt @@ -1,4 +1,3 @@ -flwr~=1.4.0 -numpy~=1.21.1 -torch~=2.0.1 -torchvision~=0.15.2 +flwr>=1.0, <2.0 +opacus==1.4.0 +torchvision==0.15.2 diff --git a/examples/quickstart-jax/pyproject.toml b/examples/quickstart-jax/pyproject.toml index 6a67cff6f4b5..41b4462d0a14 100644 --- a/examples/quickstart-jax/pyproject.toml +++ b/examples/quickstart-jax/pyproject.toml @@ -5,12 +5,11 @@ description = "JAX example training a linear regression model with federated lea authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = "^3.8" -flwr = "^1.0.0" -jax = "^0.4.0" -jaxlib = "^0.4.0" -scikit-learn = "^1.1.1" -numpy = "^1.21.4" +python = ">=3.8,<3.11" +flwr = "1.0.0" +jax = "0.4.17" +jaxlib = "0.4.17" +scikit-learn = "1.1.1" [build-system] requires = ["poetry-core>=1.4.0"] diff --git a/examples/quickstart-jax/requirements.txt b/examples/quickstart-jax/requirements.txt index bf7a9c64d66f..964f07a51b7d 100644 --- a/examples/quickstart-jax/requirements.txt +++ b/examples/quickstart-jax/requirements.txt @@ -1,4 +1,4 @@ -flwr~=1.4.0 -jax~=0.4.10 -numpy~=1.21.1 -scikit_learn~=1.2.2 +flwr>=1.0,<2.0 +jax==0.4.17 +jaxlib==0.4.17 +scikit-learn==1.1.1 diff --git a/examples/secaggplus-mt/README.md b/examples/secaggplus-mt/README.md index 164174509d4c..0b3b4db3942e 100644 --- a/examples/secaggplus-mt/README.md +++ b/examples/secaggplus-mt/README.md @@ -1,6 +1,6 @@ # Secure Aggregation with Driver API -This example contains highly experimental code. Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart_pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced_pytorch)) to learn how to use Flower with PyTorch. +This example contains highly experimental code. Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. 
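The `driver.py` changes further below adopt the workload-scoped Driver API: a workload is created once up front, and its numeric ID accompanies every subsequent request. A minimal sketch of that flow, assuming the `Driver` helper class from `src/py/flwr/driver/driver.py` and the request names used in this diff (import path and constructor defaults are assumptions, not confirmed by this patch):

```python
from flwr.driver.driver import Driver
from flwr.proto import driver_pb2

driver = Driver()  # assumed default address "[::]:9091"
driver.connect()

# Create a workload once; its uint64 ID scopes all later requests
create_res = driver.create_workload(req=driver_pb2.CreateWorkloadRequest())
workload_id = create_res.workload_id

# Node queries are now filtered by that workload ID
get_nodes_res = driver.get_nodes(
    req=driver_pb2.GetNodesRequest(workload_id=workload_id)
)
node_ids = [node.node_id for node in get_nodes_res.nodes]

driver.disconnect()
```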
## Installing Dependencies diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index c168edf070af..4e0a53ed1c91 100644 --- a/examples/secaggplus-mt/driver.py +++ b/examples/secaggplus-mt/driver.py @@ -23,7 +23,7 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task: task_pb2.TaskIns( task_id="", # Do not set, will be created and set by the DriverAPI group_id="", - workload_id="", + workload_id=workload_id, task=merge( task, task_pb2.Task( @@ -84,8 +84,14 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK driver.connect() +create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload( + req=driver_pb2.CreateWorkloadRequest() +) # -------------------------------------------------------------------------- Driver SDK +workload_id = create_workload_res.workload_id +print(f"Created workload id {workload_id}") + history = History() for server_round in range(num_rounds): print(f"Commencing server round {server_round + 1}") @@ -113,7 +119,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # loop and wait until enough client nodes are available. while True: # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest() + get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id) # ---------------------------------------------------------------------- Driver SDK get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( @@ -121,7 +127,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: ) # ---------------------------------------------------------------------- Driver SDK - all_node_ids: List[int] = get_nodes_res.node_ids + all_node_ids: List[int] = [node.node_id for node in get_nodes_res.nodes] if len(all_node_ids) >= num_client_nodes_per_round: # Sample client nodes diff --git a/pyproject.toml b/pyproject.toml index 91d7d810f810..dfdd75ba11ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ cryptography = "^41.0.2" pycryptodome = "^3.18.0" iterators = "^0.0.2" # Optional dependencies (VCE) -ray = { version = "==2.6.3", extras = ["default"], optional = true } +ray = { version = "==2.6.3", optional = true } pydantic = { version = "<2.0.0", optional = true } # Optional dependencies (REST transport layer) requests = { version = "^2.31.0", optional = true } diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index 1cfb77135d5a..1caaad88a0da 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -36,10 +36,10 @@ service Driver { // CreateWorkload message CreateWorkloadRequest {} -message CreateWorkloadResponse { string workload_id = 1; } +message CreateWorkloadResponse { uint64 workload_id = 1; } // GetNodes messages -message GetNodesRequest { string workload_id = 1; } +message GetNodesRequest { uint64 workload_id = 1; } message GetNodesResponse { repeated Node nodes = 1; } // PushTaskIns messages diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 29e07641bb1c..d87fb39c2637 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -36,14 +36,14 @@ message Task { message TaskIns { string task_id = 1; string group_id = 2; - string workload_id = 3; + uint64 workload_id = 3; Task task = 4; } message TaskRes { string task_id = 1; string group_id = 2; - string workload_id = 3; + uint64 workload_id = 3; Task 
task = 4; } diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index eda869d3a326..cc64ec9a268a 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -117,7 +117,7 @@ def receive() -> TaskIns: return TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id="", + workload_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 9b26a9bd5ca0..f50923450f62 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -76,7 +76,7 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]: task_res = TaskRes( task_id="", group_id="", - workload_id="", + workload_id=0, task=Task( ancestry=[], sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 31cbb00edf63..1fc2269ad75d 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -120,7 +120,7 @@ def test_client_without_get_properties() -> None: task_ins: TaskIns = TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id="", + workload_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), @@ -146,7 +146,7 @@ def test_client_without_get_properties() -> None: TaskRes( task_id=str(uuid.uuid4()), group_id="", - workload_id="", + workload_id=0, ) ) # pylint: disable=no-member @@ -183,7 +183,7 @@ def test_client_with_get_properties() -> None: task_ins = TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id="", + workload_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), @@ -209,7 +209,7 @@ def test_client_with_get_properties() -> None: TaskRes( task_id=str(uuid.uuid4()), group_id="", - workload_id="", + workload_id=0, ) ) # pylint: disable=no-member diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py index 03688c52ac8f..b48c7433c1da 100644 --- a/src/py/flwr/client/message_handler/task_handler.py +++ b/src/py/flwr/client/message_handler/task_handler.py @@ -129,7 +129,7 @@ def wrap_client_message_in_task_res(client_message: ClientMessage) -> TaskRes: return TaskRes( task_id="", group_id="", - workload_id="", + workload_id=0, task=Task(ancestry=[], legacy_client_message=client_message), ) diff --git a/src/py/flwr/client/message_handler/task_handler_test.py b/src/py/flwr/client/message_handler/task_handler_test.py index 347b9ad32c4b..e1b7fac69d24 100644 --- a/src/py/flwr/client/message_handler/task_handler_test.py +++ b/src/py/flwr/client/message_handler/task_handler_test.py @@ -92,7 +92,7 @@ def test_validate_task_res() -> None: assert not validate_task_res(task_res) task_res.Clear() - task_res.workload_id = "123" + task_res.workload_id = 61016 assert not validate_task_res(task_res) task_res.Clear() diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py index 792bd84b6106..4fcd924f8432 100644 --- a/src/py/flwr/driver/app_test.py +++ b/src/py/flwr/driver/app_test.py @@ -43,7 +43,7 @@ def test_simple_client_manager_update(self) -> None: ] driver = 
MagicMock() driver.stub = "driver stub" - driver.create_workload.return_value = CreateWorkloadResponse(workload_id="1") + driver.create_workload.return_value = CreateWorkloadResponse(workload_id=1) driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes) client_manager = SimpleClientManager() lock = threading.Lock() diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index 64e61ec4cb61..130cd2bbc707 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -23,7 +23,17 @@ from flwr.common import EventType, event from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.proto import driver_pb2, driver_pb2_grpc +from flwr.proto.driver_pb2 import ( + CreateWorkloadRequest, + CreateWorkloadResponse, + GetNodesRequest, + GetNodesResponse, + PullTaskResRequest, + PullTaskResResponse, + PushTaskInsRequest, + PushTaskInsResponse, +) +from flwr.proto.driver_pb2_grpc import DriverStub DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091" @@ -46,7 +56,7 @@ def __init__( self.driver_service_address = driver_service_address self.certificates = certificates self.channel: Optional[grpc.Channel] = None - self.stub: Optional[driver_pb2_grpc.DriverStub] = None + self.stub: Optional[DriverStub] = None def connect(self) -> None: """Connect to the Driver API.""" @@ -58,7 +68,7 @@ def connect(self) -> None: server_address=self.driver_service_address, root_certificates=self.certificates, ) - self.stub = driver_pb2_grpc.DriverStub(self.channel) + self.stub = DriverStub(self.channel) log(INFO, "[Driver] Connected to %s", self.driver_service_address) def disconnect(self) -> None: @@ -73,9 +83,7 @@ def disconnect(self) -> None: channel.close() log(INFO, "[Driver] Disconnected") - def create_workload( - self, req: driver_pb2.CreateWorkloadRequest - ) -> driver_pb2.CreateWorkloadResponse: + def create_workload(self, req: CreateWorkloadRequest) -> CreateWorkloadResponse: """Request for workload ID.""" # Check if channel is open if self.stub is None: @@ -83,10 +91,10 @@ def create_workload( raise Exception("`Driver` instance not connected") # Call Driver API - res: driver_pb2.CreateWorkloadResponse = self.stub.CreateWorkload(request=req) + res: CreateWorkloadResponse = self.stub.CreateWorkload(request=req) return res - def get_nodes(self, req: driver_pb2.GetNodesRequest) -> driver_pb2.GetNodesResponse: + def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: """Get client IDs.""" # Check if channel is open if self.stub is None: @@ -94,12 +102,10 @@ def get_nodes(self, req: driver_pb2.GetNodesRequest) -> driver_pb2.GetNodesRespo raise Exception("`Driver` instance not connected") # Call Driver API - res: driver_pb2.GetNodesResponse = self.stub.GetNodes(request=req) + res: GetNodesResponse = self.stub.GetNodes(request=req) return res - def push_task_ins( - self, req: driver_pb2.PushTaskInsRequest - ) -> driver_pb2.PushTaskInsResponse: + def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: """Schedule tasks.""" # Check if channel is open if self.stub is None: @@ -107,12 +113,10 @@ def push_task_ins( raise Exception("`Driver` instance not connected") # Call Driver API - res: driver_pb2.PushTaskInsResponse = self.stub.PushTaskIns(request=req) + res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) return res - def pull_task_res( - self, req: driver_pb2.PullTaskResRequest - ) -> driver_pb2.PullTaskResResponse: + def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: """Get task results.""" # 
Check if channel is open if self.stub is None: @@ -120,5 +124,5 @@ def pull_task_res( raise Exception("`Driver` instance not connected") # Call Driver API - res: driver_pb2.PullTaskResResponse = self.stub.PullTaskRes(request=req) + res: PullTaskResResponse = self.stub.PullTaskRes(request=req) return res diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index cd5d36cafdd7..deb472458a15 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -31,7 +31,7 @@ class DriverClientProxy(ClientProxy): """Flower client proxy which delegates work using the Driver API.""" - def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: str): + def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: int): super().__init__(str(node_id)) self.node_id = node_id self.driver = driver diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py index fa2a29e88687..f413b8d8d99d 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/driver/driver_client_proxy_test.py @@ -52,7 +52,7 @@ def test_get_properties(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id="", + workload_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( get_properties_res=ClientMessage.GetPropertiesRes( @@ -64,7 +64,7 @@ def test_get_properties(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id="" + node_id=1, driver=self.driver, anonymous=True, workload_id=0 ) request_properties: Config = {"tensor_type": "str"} ins: flwr.common.GetPropertiesIns = flwr.common.GetPropertiesIns( @@ -88,7 +88,7 @@ def test_get_parameters(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id="", + workload_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( get_parameters_res=ClientMessage.GetParametersRes( @@ -100,7 +100,7 @@ def test_get_parameters(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id="" + node_id=1, driver=self.driver, anonymous=True, workload_id=0 ) get_parameters_ins = GetParametersIns(config={}) @@ -123,7 +123,7 @@ def test_fit(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id="", + workload_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( fit_res=ClientMessage.FitRes( @@ -136,7 +136,7 @@ def test_fit(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id="" + node_id=1, driver=self.driver, anonymous=True, workload_id=0 ) parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))]) ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) @@ -160,7 +160,7 @@ def test_evaluate(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id="", + workload_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( evaluate_res=ClientMessage.EvaluateRes( @@ -172,7 +172,7 @@ def test_evaluate(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id="" + node_id=1, driver=self.driver, anonymous=True, workload_id=0 ) parameters = flwr.common.Parameters(tensors=[], tensor_type="np") evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) diff --git 
a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index c18d9c593c28..6ac066d7eab3 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -16,7 +16,7 @@ from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x17\n\x15\x43reateWorkloadRequest\"-\n\x16\x43reateWorkloadResponse\x12\x13\n\x0bworkload_id\x18\x01 \x01(\t\"&\n\x0fGetNodesRequest\x12\x13\n\x0bworkload_id\x18\x01 \x01(\t\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xd0\x02\n\x06\x44river\x12Y\n\x0e\x43reateWorkload\x12!.flwr.proto.CreateWorkloadRequest\x1a\".flwr.proto.CreateWorkloadResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x17\n\x15\x43reateWorkloadRequest\"-\n\x16\x43reateWorkloadResponse\x12\x13\n\x0bworkload_id\x18\x01 \x01(\x04\"&\n\x0fGetNodesRequest\x12\x13\n\x0bworkload_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xd0\x02\n\x06\x44river\x12Y\n\x0e\x43reateWorkload\x12!.flwr.proto.CreateWorkloadRequest\x1a\".flwr.proto.CreateWorkloadResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 486bddb0f76f..8b940972cb6d 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -23,10 +23,10 @@ global___CreateWorkloadRequest = CreateWorkloadRequest class CreateWorkloadResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor WORKLOAD_ID_FIELD_NUMBER: builtins.int - workload_id: typing.Text + workload_id: builtins.int def __init__(self, *, - workload_id: typing.Text = ..., + workload_id: builtins.int = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ... 
global___CreateWorkloadResponse = CreateWorkloadResponse @@ -35,10 +35,10 @@ class GetNodesRequest(google.protobuf.message.Message): """GetNodes messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor WORKLOAD_ID_FIELD_NUMBER: builtins.int - workload_id: typing.Text + workload_id: builtins.int def __init__(self, *, - workload_id: typing.Text = ..., + workload_id: builtins.int = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ... global___GetNodesRequest = GetNodesRequest diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 42d3952f61df..69bad48d0d37 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -16,7 +16,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xbe\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12)\n\x02sa\x18\x07 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"a\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\t\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"a\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\t\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xf3\x03\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12\x33\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x1c.flwr.proto.Value.DoubleListH\x00\x12\x33\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x1c.flwr.proto.Value.Sint64ListH\x00\x12/\n\tbool_list\x18\x17 \x01(\x0b\x32\x1a.flwr.proto.Value.BoolListH\x00\x12\x33\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x1c.flwr.proto.Value.StringListH\x00\x12\x31\n\nbytes_list\x18\x19 \x01(\x0b\x32\x1b.flwr.proto.Value.BytesListH\x00\x1a\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\x1a\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\x1a\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\x1a\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\x1a\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xbe\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 
\x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12)\n\x02sa\x18\x07 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"a\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\x04\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"a\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\x04\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xf3\x03\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12\x33\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x1c.flwr.proto.Value.DoubleListH\x00\x12\x33\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x1c.flwr.proto.Value.Sint64ListH\x00\x12/\n\tbool_list\x18\x17 \x01(\x0b\x32\x1a.flwr.proto.Value.BoolListH\x00\x12\x33\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x1c.flwr.proto.Value.StringListH\x00\x12\x31\n\nbytes_list\x18\x19 \x01(\x0b\x32\x1b.flwr.proto.Value.BytesListH\x00\x1a\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\x1a\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\x1a\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\x1a\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\x1a\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index dcd4686944bc..7cf96cb61edf 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -63,14 +63,14 @@ class TaskIns(google.protobuf.message.Message): TASK_FIELD_NUMBER: builtins.int task_id: typing.Text group_id: typing.Text - workload_id: typing.Text + workload_id: builtins.int @property def task(self) -> global___Task: ... def __init__(self, *, task_id: typing.Text = ..., group_id: typing.Text = ..., - workload_id: typing.Text = ..., + workload_id: builtins.int = ..., task: typing.Optional[global___Task] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... @@ -85,14 +85,14 @@ class TaskRes(google.protobuf.message.Message): TASK_FIELD_NUMBER: builtins.int task_id: typing.Text group_id: typing.Text - workload_id: typing.Text + workload_id: builtins.int @property def task(self) -> global___Task: ... def __init__(self, *, task_id: typing.Text = ..., group_id: typing.Text = ..., - workload_id: typing.Text = ..., + workload_id: builtins.int = ..., task: typing.Optional[global___Task] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... 
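With the stubs above regenerated, `workload_id` is an integer end to end, so every Python call site that previously passed `""` now passes `0` (or a real uint64 ID), as the test updates below show. A small illustrative sketch of constructing a `TaskIns` against the new schema (field values are placeholders, not meaningful IDs):

```python
from flwr.proto.node_pb2 import Node
from flwr.proto.task_pb2 import Task, TaskIns

task_ins = TaskIns(
    task_id="",     # assigned by the Driver API on push
    group_id="",
    workload_id=0,  # now uint64; passing a str would raise a TypeError
    task=Task(
        producer=Node(node_id=0, anonymous=True),
        consumer=Node(node_id=0, anonymous=True),
    ),
)
```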
diff --git a/src/py/flwr/server/fleet/message_handler/message_handler_test.py b/src/py/flwr/server/fleet/message_handler/message_handler_test.py index 10f678e3479e..da92b267f082 100644 --- a/src/py/flwr/server/fleet/message_handler/message_handler_test.py +++ b/src/py/flwr/server/fleet/message_handler/message_handler_test.py @@ -109,7 +109,7 @@ def test_push_task_res() -> None: TaskRes( task_id="", group_id="", - workload_id="", + workload_id=0, task=Task(), ), ], diff --git a/src/py/flwr/server/state/in_memory_state.py b/src/py/flwr/server/state/in_memory_state.py index 075ba2cf304d..d6292571cd6d 100644 --- a/src/py/flwr/server/state/in_memory_state.py +++ b/src/py/flwr/server/state/in_memory_state.py @@ -32,7 +32,7 @@ class InMemoryState(State): def __init__(self) -> None: self.node_ids: Set[int] = set() - self.workload_ids: Set[str] = set() + self.workload_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} self.task_res_store: Dict[UUID, TaskRes] = {} @@ -194,7 +194,7 @@ def unregister_node(self, node_id: int) -> None: raise ValueError(f"Node {node_id} is not registered") self.node_ids.remove(node_id) - def get_nodes(self, workload_id: str) -> Set[int]: + def get_nodes(self, workload_id: int) -> Set[int]: """Return all available client nodes. Constraints @@ -206,14 +206,13 @@ def get_nodes(self, workload_id: str) -> Set[int]: return set() return self.node_ids - def create_workload(self) -> str: + def create_workload(self) -> int: """Create one workload.""" - # String representation of random integer from 0 to 9223372036854775807 - random_workload_id: int = random.randrange(9223372036854775808) - workload_id = str(random_workload_id) + # Sample random integer from 0 to 9223372036854775807 + workload_id: int = random.randrange(9223372036854775808) if workload_id not in self.workload_ids: self.workload_ids.add(workload_id) return workload_id log(ERROR, "Unexpected workload creation failure.") - return "" + return 0 diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py index e971c11da2f5..0c853409b844 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/state/sqlite_state.py @@ -39,7 +39,7 @@ SQL_CREATE_TABLE_WORKLOAD = """ CREATE TABLE IF NOT EXISTS workload( - workload_id TEXT UNIQUE + workload_id INTEGER UNIQUE ); """ @@ -47,7 +47,7 @@ CREATE TABLE IF NOT EXISTS task_ins( task_id TEXT UNIQUE, group_id TEXT, - workload_id TEXT, + workload_id INTEGER, producer_anonymous BOOLEAN, producer_node_id INTEGER, consumer_anonymous BOOLEAN, @@ -67,7 +67,7 @@ CREATE TABLE IF NOT EXISTS task_res( task_id TEXT UNIQUE, group_id TEXT, - workload_id TEXT, + workload_id INTEGER, producer_anonymous BOOLEAN, producer_node_id INTEGER, consumer_anonymous BOOLEAN, @@ -479,7 +479,7 @@ def unregister_node(self, node_id: int) -> None: query = "DELETE FROM node WHERE node_id = :node_id;" self.query(query, {"node_id": node_id}) - def get_nodes(self, workload_id: str) -> Set[int]: + def get_nodes(self, workload_id: int) -> Set[int]: """Retrieve all currently stored node IDs as a set. 
Constraints @@ -498,11 +498,10 @@ def get_nodes(self, workload_id: str) -> Set[int]: result: Set[int] = {row["node_id"] for row in rows} return result - def create_workload(self) -> str: + def create_workload(self) -> int: """Create one workload and store it in state.""" - # String representation of random integer from 0 to 9223372036854775807 - random_workload_id: int = random.randrange(9223372036854775808) - workload_id = str(random_workload_id) + # Sample random integer from 0 to 9223372036854775807 + workload_id: int = random.randrange(9223372036854775808) # Check conflicts query = "SELECT COUNT(*) FROM workload WHERE workload_id = ?;" @@ -512,7 +511,7 @@ def create_workload(self) -> str: self.query(query, {"workload_id": workload_id}) return workload_id log(ERROR, "Unexpected workload creation failure.") - return "" + return 0 def dict_factory( diff --git a/src/py/flwr/server/state/sqlite_state_test.py b/src/py/flwr/server/state/sqlite_state_test.py index e3bb72e34118..b9c0df9ed134 100644 --- a/src/py/flwr/server/state/sqlite_state_test.py +++ b/src/py/flwr/server/state/sqlite_state_test.py @@ -27,7 +27,7 @@ class SqliteStateTest(unittest.TestCase): def test_ins_res_to_dict(self) -> None: """Check if all required keys are included in return value.""" # Prepare - ins_res = create_task_ins(consumer_node_id=1, anonymous=True, workload_id="") + ins_res = create_task_ins(consumer_node_id=1, anonymous=True, workload_id=0) expected_keys = [ "task_id", "group_id", diff --git a/src/py/flwr/server/state/state.py b/src/py/flwr/server/state/state.py index cfd68c589b6e..a0b9e663f637 100644 --- a/src/py/flwr/server/state/state.py +++ b/src/py/flwr/server/state/state.py @@ -140,7 +140,7 @@ def unregister_node(self, node_id: int) -> None: """Remove `node_id` from state.""" @abc.abstractmethod - def get_nodes(self, workload_id: str) -> Set[int]: + def get_nodes(self, workload_id: int) -> Set[int]: """Retrieve all currently stored node IDs as a set. 
Constraints @@ -150,5 +150,5 @@ def get_nodes(self, workload_id: str) -> Set[int]: """ @abc.abstractmethod - def create_workload(self) -> str: + def create_workload(self) -> int: """Create one workload.""" diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index e80bd55352ed..bc3015ba5cc2 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -283,7 +283,7 @@ def test_task_ins_store_invalid_workload_id_and_fail(self) -> None: # Prepare state: State = self.state_factory() task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id="I'm invalid" + consumer_node_id=0, anonymous=True, workload_id=61016 ) # Execute @@ -362,7 +362,7 @@ def test_get_nodes_invalid_workload_id(self) -> None: # Prepare state: State = self.state_factory() state.create_workload() - invalid_workload_id = "" + invalid_workload_id = 61016 node_id = 2 # Execute @@ -420,7 +420,7 @@ def test_num_task_res(self) -> None: def create_task_ins( consumer_node_id: int, anonymous: bool, - workload_id: str, + workload_id: int, delivered_at: str = "", ) -> TaskIns: """Create a TaskIns for testing.""" @@ -448,7 +448,7 @@ def create_task_res( producer_node_id: int, anonymous: bool, ancestry: List[str], - workload_id: str, + workload_id: int, ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index 533e3a236572..54840731048f 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -135,7 +135,7 @@ def create_task_ins( task = TaskIns( task_id="", group_id="", - workload_id="", + workload_id=0, task=Task( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), @@ -162,7 +162,7 @@ def create_task_res( task_res = TaskRes( task_id="", group_id="", - workload_id="", + workload_id=0, task=Task( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 6fcfb82be2c0..5c7a3e7423a3 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -18,6 +18,7 @@ import sys import threading import traceback +import warnings from logging import ERROR, INFO from typing import Any, Dict, List, Optional, Type, Union @@ -68,7 +69,7 @@ """ -def start_simulation( # pylint: disable=too-many-arguments +def start_simulation( # pylint: disable=too-many-arguments,too-many-statements *, client_fn: ClientFn, num_clients: Optional[int] = None, @@ -214,6 +215,12 @@ def start_simulation( # pylint: disable=too-many-arguments cluster_resources, ) + log( + INFO, + "Optimize your simulation with Flower VCE: " + "https://flower.dev/docs/framework/how-to-run-simulations.html", + ) + # Log the resources that a single client will be able to use if client_resources is None: log( @@ -222,6 +229,15 @@ def start_simulation( # pylint: disable=too-many-arguments ) client_resources = {"num_cpus": 1, "num_gpus": 0.0} + # Each client needs at the very least one CPU + if "num_cpus" not in client_resources: + warnings.warn( + "No `num_cpus` specified in `client_resources`. " + "Using `num_cpus=1` for each client.", + stacklevel=2, + ) + client_resources["num_cpus"] = 1 + log( INFO, "Flower VCE: Resources for each Virtual Client: %s",