diff --git a/.github/workflows/refresh-one-notebook.yaml b/.github/workflows/refresh-one-notebook.yaml
index d3be4a806..1761e7562 100644
--- a/.github/workflows/refresh-one-notebook.yaml
+++ b/.github/workflows/refresh-one-notebook.yaml
@@ -6,7 +6,6 @@ on:
# --- refresh_notebooks_list.py: refresh list of notebooks currently available [START] ---
# --- do not edit, auto generated part by `make refresh_notebooks_list` ---
description: "Notebook file name only in: \n
- - Cifar10 \n
- CifarInFhe \n
- CifarInFheWithSmallerAccumulators \n
- CifarQuantizationAwareTraining \n
@@ -51,7 +50,6 @@ env:
ACTION_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
# --- refresh_notebooks_list.py: refresh list of notebook paths currently available [START] ---
# --- do not edit, auto generated part by `make refresh_notebooks_list` ---
- Cifar10: "use_case_examples/cifar/cifar_brevitas_with_model_splitting/Cifar10.ipynb"
CifarInFhe: "use_case_examples/cifar/cifar_brevitas_finetuning/CifarInFhe.ipynb"
CifarInFheWithSmallerAccumulators: "use_case_examples/cifar/cifar_brevitas_finetuning/CifarInFheWithSmallerAccumulators.ipynb"
CifarQuantizationAwareTraining: "use_case_examples/cifar/cifar_brevitas_finetuning/CifarQuantizationAwareTraining.ipynb"
diff --git a/README.md b/README.md
index cae47fcfd..7452fab51 100644
--- a/README.md
+++ b/README.md
@@ -200,7 +200,6 @@ Concrete ML built-in models have APIs that are almost identical to their scikit-
- [Titanic](use_case_examples/titanic/KaggleTitanic.ipynb): solving the [Kaggle Titanic competition](https://www.kaggle.com/c/titanic/). Implemented with XGBoost from Concrete ML, this example comes as a companion of the [Kaggle notebook](https://www.kaggle.com/code/concretemlteam/titanic-with-privacy-preserving-machine-learning), and was the subject of a blogpost in [KDnuggets](https://www.kdnuggets.com/2022/08/machine-learning-encrypted-data.html).
- [CIFAR10 FHE-friendly model with Brevitas](use_case_examples/cifar/cifar_brevitas_training): training a VGG9 FHE-compatible neural network using Brevitas, and a script to run the neural network in FHE. Execution in FHE takes ~4 minutes per image and shows an accuracy of 88.7%.
- [CIFAR10 / CIFAR100 FHE-friendly models with Transfer Learning approach](use_case_examples/cifar/cifar_brevitas_finetuning): a series of three notebooks that convert a pre-trained FP32 VGG11 neural network into a quantized model using Brevitas. The model is fine-tuned on the CIFAR data-sets, converted for FHE execution with Concrete ML and evaluated using FHE simulation. For CIFAR10 and CIFAR100, respectively, our simulations show an accuracy of 90.2% and 68.2%.
-- [FHE neural network splitting for client/server deployment](use_case_examples/cifar/cifar_brevitas_with_model_splitting): explaining how to split a computationally-intensive neural network model in two parts. First, we execute the first part on the client side in the clear, and the output of this step is encrypted. Next, to complete the computation, the second part of the model is evaluated with FHE. This tutorial also shows the impact of FHE speed/accuracy trade-off on CIFAR10, limiting PBS to 8-bit, and thus achieving 62% accuracy.
*If you have built awesome projects using Concrete ML, please let us know and we will be happy to showcase them here!*
diff --git a/use_case_examples/cifar/README.md b/use_case_examples/cifar/README.md
index 758180e3f..60809c2b4 100644
--- a/use_case_examples/cifar/README.md
+++ b/use_case_examples/cifar/README.md
@@ -7,7 +7,6 @@ This repository provides resources and documentation on different use-cases for
1. [Use-Cases](#use-cases)
- [Fine-Tuning VGG11 CIFAR-10/100](#fine-tuning-cifar)
- [Training Ternary VGG9 on CIFAR10](#training-ternary-vgg-on-cifar10)
- - [CIFAR-10 VGG9 with one client-side layer](#cifar-10-with-a-split-model)
1. [Installation](#installation)
1. [Further Reading & Resources](#further-reading)
@@ -33,14 +32,6 @@ Notebooks:
[Results & Metrics](./cifar_brevitas_training/README.md#accuracy-and-performance)
-### CIFAR-10 with a Split Model
-
-- **Description**: This method divides the model into two segments: one that operates in plaintext (clear) and the other in Fully Homomorphic Encryption (FHE). This division allows for greater precision in the input layer while taking advantage of FHE's privacy-preserving capabilities in the subsequent layers.
-- **Model Design**: Targets 8-bit accumulators to speed up FHE inference. The design incorporates pruning techniques and employs 2-bit weights to meet this goal.
-- **Implementation**: Provides step-by-step guidance on how to execute the hybrid clear/FHE model, focusing on the details and decisions behind selecting the optimal `p_error` value. Special attention is given to the binary search method to balance accuracy and FHE performance.
-
-[Results & Metrics](./cifar_brevitas_with_model_splitting/README.md#results)
-
## Installation
All use-cases can be quickly set up with:
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/8_bit_model.pt b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/8_bit_model.pt
deleted file mode 100644
index 327e7561d..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/8_bit_model.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:fb76cc33e43565fe15014c13764fcd0d88422e65d4b23cf5ad1ea4b3b12abfec
-size 18670189
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/Cifar10.ipynb b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/Cifar10.ipynb
deleted file mode 100644
index 53af3a831..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/Cifar10.ipynb
+++ /dev/null
@@ -1,472 +0,0 @@
-{
- "cells": [
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "442b1f40-24b9-4651-83f0-dee2f042a17a",
- "metadata": {},
- "source": [
- "# CIFAR-10 FHE classification with 8-bit split VGG\n",
- "\n",
- "As mentioned in the [README](./README.md) we present in this notebook how to compile to FHE a split torch model.\n",
- "The model we will be considering is a CIFAR-10 classifier based on the VGG architecture. It was trained with pruning and accumulator bit-width monitoring so that the classifier does not exceed the 8 bit-width accumulator constraint.\n",
- "\n",
- "The first layers of the models should be run on the clear data on the client's side and the rest of the model in FHE on the server's side."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "423cb30c-febf-4b5a-b5ce-505ac632b8b4",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'\n"
- ]
- }
- ],
- "source": [
- "import time\n",
- "\n",
- "import pandas as pd\n",
- "import torch\n",
- "import torchvision\n",
- "from model import CNV # pylint: disable=no-name-in-module\n",
- "from sklearn.metrics import top_k_accuracy_score\n",
- "from torchvision import transforms\n",
- "\n",
- "from concrete.ml.torch.compile import compile_brevitas_qat_model"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5bff5c60-669a-4dcc-a0e2-fd17d7e2cdd7",
- "metadata": {},
- "source": [
- "In `model.py` we define our model architecture.\n",
- "\n",
- "As one can see we split the main model `CNV` into two sub-models `ClearModule` and `EncryptedModule`.\n",
- "\n",
- "- `ClearModule` will be used to run on clear data on the client's side. It can do any float operations and does not require quantization.\n",
- "- `EncryptedModule` will run on the server side. This part of the model running in FHE we need to quantize it, thus why we leverage Brevitas for Quantization Aware Training."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "16bb6a59-1002-4744-bfe7-baef97a4c7b1",
- "metadata": {},
- "outputs": [],
- "source": [
- "model = CNV(num_classes=10, weight_bit_width=2, act_bit_width=2, in_bit_width=3, in_ch=3)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "ea920f40-82e1-4c00-8c65-035de4b11d78",
- "metadata": {},
- "source": [
- "We won't be training the model is this notebook as it would be quite computationnaly intensive but we provide an already trained model that satisfies the 8-bit accumulator size constraint and that performs better than random on CIFAR-10."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "e5cd2622-b0f4-4632-8ff7-8d17a82ac812",
- "metadata": {},
- "outputs": [],
- "source": [
- "loaded = torch.load(\"./8_bit_model.pt\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "476ddd8e-bb04-4d42-844c-cde14b5817e3",
- "metadata": {},
- "outputs": [],
- "source": [
- "model.load_state_dict(loaded[\"model_state_dict\"])\n",
- "model = model.eval()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "29e7ce1e-c287-48e5-ae6d-29df15a12110",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to .data/cifar-10-python.tar.gz\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "65c15f224f214f9798a29c77b0433e90",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/170498071 [00:00, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting .data/cifar-10-python.tar.gz to .data/\n",
- "(Dataset CIFAR10\n",
- " Number of datapoints: 50000\n",
- " Root location: .data/\n",
- " Split: Train\n",
- " StandardTransform\n",
- "Transform: Compose(\n",
- " ToTensor()\n",
- " Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))\n",
- " ), Dataset CIFAR10\n",
- " Number of datapoints: 10000\n",
- " Root location: .data/\n",
- " Split: Test\n",
- " StandardTransform\n",
- "Transform: Compose(\n",
- " ToTensor()\n",
- " Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))\n",
- " ))\n"
- ]
- }
- ],
- "source": [
- "IMAGE_TRANSFORM = transforms.Compose(\n",
- " [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
- ")\n",
- "\n",
- "try:\n",
- " train_set = torchvision.datasets.CIFAR10(\n",
- " root=\".data/\",\n",
- " train=True,\n",
- " download=False,\n",
- " transform=IMAGE_TRANSFORM,\n",
- " target_transform=None,\n",
- " )\n",
- "except RuntimeError:\n",
- " train_set = torchvision.datasets.CIFAR10(\n",
- " root=\".data/\",\n",
- " train=True,\n",
- " download=True,\n",
- " transform=IMAGE_TRANSFORM,\n",
- " target_transform=None,\n",
- " )\n",
- "test_set = torchvision.datasets.CIFAR10(\n",
- " root=\".data/\",\n",
- " train=False,\n",
- " download=False,\n",
- " transform=IMAGE_TRANSFORM,\n",
- " target_transform=None,\n",
- ")\n",
- "\n",
- "print((train_set, test_set))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "576c9182-ae7b-47f3-b810-d623bddad1dc",
- "metadata": {},
- "source": [
- "We use a sub-sample of the training set for the FHE compilation to maintain acceptable compilation times and avoid out-of-memory errors."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "e9de82bb-a05d-4b34-b6ed-15ffe75014ad",
- "metadata": {},
- "outputs": [],
- "source": [
- "num_samples = 1000\n",
- "train_sub_set = torch.stack(\n",
- " [train_set[index][0] for index in range(min(num_samples, len(train_set)))]\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "d841a675-185c-4723-9c8f-a1bcc9e17cf2",
- "metadata": {},
- "source": [
- "Since we will be compiling only a part of the network we need to give it representative inputs, in our case the first feature map of the network."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "e8ed8dfb-5f5c-4dad-ba2b-6bc74dfccc97",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Pre-processing -> images -> feature maps\n",
- "with torch.no_grad():\n",
- " train_features_sub_set = model.clear_module(train_sub_set)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6f93f390-88d6-41a5-acf4-67dc42b80fe4",
- "metadata": {},
- "source": [
- "# FHE Simulation\n",
- "\n",
- "In a first time we can make sure that our FHE constraints are respected."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "9e2af71c-baba-4edb-8cc3-f3795ae31313",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Compiling the model\n",
- "Compilation finished in 93.09 seconds\n",
- "Max bitwidth: 8 bits!\n"
- ]
- }
- ],
- "source": [
- "optional_kwargs = {}\n",
- "\n",
- "# Compile the model\n",
- "compilation_onnx_path = \"compilation_model.onnx\"\n",
- "print(\"Compiling the model\")\n",
- "start_compile = time.time()\n",
- "quantized_numpy_module = compile_brevitas_qat_model(\n",
- " # our encrypted model\n",
- " torch_model=model.encrypted_module,\n",
- " # a representative input-set to be used for compilation\n",
- " torch_inputset=train_features_sub_set,\n",
- " **optional_kwargs,\n",
- " output_onnx_file=compilation_onnx_path,\n",
- ")\n",
- "\n",
- "end_compile = time.time()\n",
- "print(f\"Compilation finished in {end_compile - start_compile:.2f} seconds\")\n",
- "\n",
- "# Check that the network is compatible with FHE constraints\n",
- "assert quantized_numpy_module.fhe_circuit is not None\n",
- "bitwidth = quantized_numpy_module.fhe_circuit.graph.maximum_integer_bit_width()\n",
- "print(f\"Max bit-width: {bitwidth} bits!\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "114f4fad-9f9a-4006-9337-6a0ee7a1c0c8",
- "metadata": {},
- "outputs": [],
- "source": [
- "img, _ = train_set[0]\n",
- "with torch.no_grad():\n",
- " feature_maps = model.clear_module(img[None, :])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "2e1e3763-9925-4537-94e1-bb6d32bd41cc",
- "metadata": {},
- "outputs": [],
- "source": [
- "output_simulated = quantized_numpy_module.forward(feature_maps.numpy(), fhe=\"simulate\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "id": "9db4f92d-6092-418f-b5e1-0aba544ffa6f",
- "metadata": {},
- "outputs": [],
- "source": [
- "with torch.no_grad():\n",
- " torch_output = model(img[None, :])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "19c3aa51-fb4d-4da3-9564-6be636b42799",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([[ 0.0171, 0.0171, -0.0215, 0.0122, 0.0232, -0.0144, 0.0042, -0.0115,\n",
- " 0.0180, 0.0065]], dtype=torch.float64)\n"
- ]
- }
- ],
- "source": [
- "print(torch_output - output_simulated)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "3398cdb2-b9aa-4937-97c2-0e5c41d1ac00",
- "metadata": {},
- "source": [
- "We see that we have some differences between the output of the torch model output and the FHE simulation.\n",
- "\n",
- "This is expected but as we can see in the following code blocks we have no difference in top-k accuracies between Pytorch and the FHE simulation mode.\n",
- "\n",
- "It appears that there are some differences between the output of the Torch model and the FHE simulation. While this outcome was expected, it is important to note that, as demonstrated in the following code blocks, there are no differences in the top-k accuracies between PyTorch and the FHE simulation mode."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "id": "6c34ffdb-6aca-4464-b26d-bb12c4afb5cd",
- "metadata": {},
- "outputs": [],
- "source": [
- "def evaluate(file_path: str, k=3):\n",
- " predictions = pd.read_csv(file_path)\n",
- " prob_columns = [elt for elt in predictions.columns if elt.endswith(\"_prob\")]\n",
- " predictions[\"pred_label\"] = predictions[prob_columns].values.argmax(axis=1)\n",
- "\n",
- " # Equivalent to top-1-accuracy\n",
- " for k_ in range(1, k + 1):\n",
- " print(\n",
- " f\"top-{k}-accuracy: \",\n",
- " top_k_accuracy_score(\n",
- " y_true=predictions[\"label\"], y_score=predictions[prob_columns], k=k_\n",
- " ),\n",
- " )"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "732267b9-c7cb-42ee-a40a-b21337d41430",
- "metadata": {},
- "source": [
- "We can use the `infer_fhe_simulation.py` script to generate the predictions of the model using Pytorch for the first layer and FHE simulation for the rest of the network."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "id": "c5918f4d-91dc-434a-90c8-baf79e13f8dc",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Compiling the model\n",
- "Compilation finished in 86.79 seconds\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Finished inference\n",
- "top-3-accuracy: 0.6231\n",
- "top-3-accuracy: 0.8072\n",
- "top-3-accuracy: 0.8906\n"
- ]
- }
- ],
- "source": [
- "%run infer_fhe_simulation.py\n",
- "evaluate(\"./fhe_simulated_predictions.csv\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "07ee3362-62c0-4e4a-a4c2-4af63710e7dc",
- "metadata": {},
- "source": [
- "And the `infer.py` script to generate the pure Pytorch predictions."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "id": "65849856-cf80-4b93-99c1-56deeac83c0a",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Finished inference\n",
- "top-3-accuracy: 0.6231\n",
- "top-3-accuracy: 0.8072\n",
- "top-3-accuracy: 0.8906\n"
- ]
- }
- ],
- "source": [
- "%run infer_torch.py\n",
- "evaluate(\"./predictions.csv\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "713a2eb9-7e4f-4890-b8c5-71765bb8bb18",
- "metadata": {},
- "source": [
- "# FHE execution results\n",
- "\n",
- "In this notebook we showed how to compile a split-VGG model trained to classify CIFAR-10 images in FHE.\n",
- "\n",
- "While satisfying the FHE constraints the model achieves the following performances:\n",
- "\n",
- "- top-1-accuracy: 0.6234\n",
- "- top-2-accuracy: 0.8075\n",
- "- top-3-accuracy: 0.8905\n",
- "\n",
- "*We don't launch the inference in FHE in this notebook as it takes quite some time just to infer on one image.*\n",
- "\n",
- "For reference we ran the inference of one image on an AWS c6i.metal compute machine, using the `fhe_inference.py` script, and got the following timings:\n",
- "\n",
- "- Time to compile: 103 seconds\n",
- "- Time to keygen: 639 seconds\n",
- "- Time to infer: ~1800 seconds"
- ]
- }
- ],
- "metadata": {
- "execution": {
- "timeout": 10800
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
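For quick reference, the client/server hand-off described in the removed notebook reduces to the sketch below. It reuses the exact Concrete ML calls that appear in `infer_fhe.py` further down in this diff; `train_set`, `model` and `quantized_numpy_module` are assumed to be defined as in the notebook.

```python
import torch

# Client side: run the first layers in the clear, then quantize and encrypt
img, _ = train_set[0]
with torch.no_grad():
    feature_maps = model.clear_module(img[None, :])
quantized_features = quantized_numpy_module.quantize_input(feature_maps.numpy())
encrypted_features = quantized_numpy_module.fhe_circuit.encrypt(quantized_features)

# Server side: run the rest of the network homomorphically
encrypted_output = quantized_numpy_module.fhe_circuit.run(encrypted_features)

# Client side: decrypt and de-quantize the result
quantized_output = quantized_numpy_module.fhe_circuit.decrypt(encrypted_output)
output = quantized_numpy_module.dequantize_output(quantized_output)
```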
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/Makefile b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/Makefile
deleted file mode 100644
index 53d36d2e7..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/Makefile
+++ /dev/null
@@ -1,12 +0,0 @@
-# Useful for jupyter notebooks
-export LC_ALL=en_US.UTF-8
-export LANG=en_US.UTF-8
-
-EXAMPLE_NAME=cifar_brevitas_with_model_splitting
-JUPYTER_RUN=jupyter nbconvert --to notebook --inplace --execute
-TIME_NB="${USE_CASE_DIR}/time_notebook_execution.sh"
-
-run_example: one
-
-one:
- @$(TIME_NB) Cifar10.ipynb
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/README.md b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/README.md
deleted file mode 100644
index 5bcc8a2cb..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/README.md
+++ /dev/null
@@ -1,74 +0,0 @@
-# CIFAR-10 classification with a split clear/FHE model
-
-In this [notebook](./Cifar10.ipynb) we show how to train and compile an FHE VGG-like model to achieve a good speed/accuracy trade-off on CIFAR-10 images.
-
-## Installation
-
-To use this code, you need to have Python 3.8 and install the following dependencies:
-
-```
-pip install -r requirements.txt
-```
-
-## Model design
-
-As there is a trade-off between accumulator bit-width and FHE inference speed, this tutorial targets
-8-bit accumulators, to achieve faster FHE inference times. Moreover, we split the model into two parts to allow
-higher precision in the input layer. For the FHE part of the model, we used pruning and 2-bit weights to try to satisfy this constraint.
-
-The first layer of any deep vision model processes the raw images that are usually represented using 8-bit integers.
-Under FHE constraints, such large input bit-widths are a bottleneck with regards to the accumulator size constraint. Therefore, we opted to split the model into 2 sub-models:
-
-- The first layer of the VGG model will run in floats and in clear,
-- The rest of the network will run in integers and in FHE.
-
-The method is generic and can be applied to any neural network model, but the compilation and deployment steps are a bit more intricate in this case. We show how to compile this split model in the [notebook](./Cifar10.ipynb).
-
-## Running this example
-
-To run this notebook properly, you will need the usual Concrete ML dependencies plus the extra dependencies from `requirements.txt`, which you can install using `pip install -r requirements.txt`.
-
-We also provide a script to run the model in FHE.
-
-The naive approach to run this model in FHE would be to choose a low `p_error` or `global_p_error` and compile the model with it to run the FHE inference.
-By trial and error we found that a `global_p_error` of 0.15 was one of the lowest values for which we could find crypto-parameters.
-On an AWS c6i.metal compute machine, doing the inference of one CIFAR-10 image with a `global_p_error` of 0.15, we got the following timings:
-
-- Time to compile: 112 seconds
-- Time to keygen: 1231 seconds
-- Time to infer: 35619 seconds (around 10 hours)
-
-But this can be improved by searching for a better `p_error`.
-
-One way to do this is to run a binary search, using FHE simulation to estimate the impact of the `p_error` on the final accuracy of our model.
-Using the first 1000 samples of the CIFAR-10 train set, we ran the search to find the highest `p_error` such that the difference in accuracy between the FHE-simulated and the clear model was below 1 point. This search yielded a `p_error` of approximately 0.05.
-We use only a subset of the training set to keep the search time acceptable, but one can either increase this number, or use bootstrapping, to get a better estimate.
-We provide a [script](./p_error_search.py) to run the `p_error` search. Results may differ between runs since the search relies on randomized simulation.
-
-Obviously, the accuracy difference observed is only a simulation on these 1000 samples, so this result needs to be verified. We validated this `p_error` choice by running the inference of the 1000 samples 40 times using simulation; the maximum difference in accuracy that we observed was 2 points, which seemed acceptable.
-
-Once this `p_error` was validated, we re-ran the FHE inference with it on the same machine (c6i.metal) and got the following results:
-
-- Time to compile: 109 seconds
-- Time to keygen: 30 seconds
-- Time to infer: 1738 seconds
-
-We see a 20x improvement from a simple change in the `p_error` parameter. For more details on how to handle `p_error`, please refer to the [documentation](../../../docs/explanations/advanced_features.md#approximate-computations).
-
-## Results
-
-Anyone can reproduce the FHE inference results using the [dedicated script](./infer_fhe.py).
-We also provide `infer_fhe_simulation.py` and `infer_torch.py` to infer using FHE simulation or directly through PyTorch.
-
-The full PyTorch model, and the hybrid inference (PyTorch for the first layer, FHE simulation for the encrypted part), yielded the same top-k accuracies:
-
-- top-1-accuracy: 0.6234
-- top-2-accuracy: 0.8075
-- top-3-accuracy: 0.8905
-
-which are decent metrics for a traditional VGG model under such constraints.
-
-The accuracy of the model running in FHE was not measured because of the computational cost it would require.
-This is something we plan on measuring once FHE runtimes become more acceptable.
-
-
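The binary search described above can be sketched as follows. This is a minimal illustration of the procedure, not the actual `p_error_search.py` script (which relies on Concrete ML's `BinarySearch` class); `simulated_accuracy` is a hypothetical helper that compiles the encrypted module with the candidate `p_error` and measures its accuracy with FHE simulation on the calibration samples.

```python
def search_p_error(accuracy_0: float, threshold: float = 0.01, max_iterations: int = 10) -> float:
    """Find the largest p_error whose simulated accuracy stays within
    `threshold` of the reference accuracy `accuracy_0` (p_error ~ 0)."""
    lower, upper = 0.0, 0.9
    p_error = (lower + upper) / 2
    for _ in range(max_iterations):
        accuracy = simulated_accuracy(p_error)  # hypothetical helper
        if abs(accuracy_0 - accuracy) <= threshold:
            # Objective met: this p_error is acceptable, try a larger one
            lower = p_error
        else:
            # Too much accuracy loss: back off to a smaller p_error
            upper = p_error
        p_error = (lower + upper) / 2
    return lower
```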
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/brevitas_utils.py b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/brevitas_utils.py
deleted file mode 100644
index b38f2e844..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/brevitas_utils.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# Code coming from https://github.com/Xilinx/brevitas/tree/master/src/brevitas_examples/bnn_pynq
-# MIT License
-#
-# Copyright (c) 2019 Xilinx
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.init as init
-import torch.nn.utils.prune as prune
-from brevitas.core.bit_width import BitWidthImplType
-from brevitas.core.quant import QuantType
-from brevitas.core.restrict_val import FloatToIntImplType, RestrictValueType
-from brevitas.core.scaling import ScalingImplType
-from brevitas.core.zero_point import ZeroZeroPoint
-from brevitas.inject import ExtendedInjector
-from brevitas.nn import QuantConv2d, QuantIdentity, QuantLinear
-from brevitas.quant import Int8WeightPerTensorFloat
-from brevitas.quant.solver import ActQuantSolver, WeightQuantSolver
-from dependencies import value
-from torch.nn import AvgPool2d, BatchNorm1d, BatchNorm2d, Conv2d, MaxPool2d, Module, ModuleList
-
-# -- Quantizers --
-
-
-class CommonQuant(ExtendedInjector):
- bit_width_impl_type = BitWidthImplType.CONST
- scaling_impl_type = ScalingImplType.CONST
- restrict_scaling_type = RestrictValueType.FP
- zero_point_impl = ZeroZeroPoint
- float_to_int_impl_type = FloatToIntImplType.ROUND
- scaling_per_output_channel = False
- narrow_range = True
- signed = True
-
- @value
- def quant_type(bit_width):
- if bit_width is None:
- return QuantType.FP
- elif bit_width == 1:
- return QuantType.BINARY
- else:
- return QuantType.INT
-
-
-class CommonIntWeightPerTensorQuant(Int8WeightPerTensorFloat):
- """
- Common per-tensor weight quantizer with bit-width set to None so that it is forced to be
- specified by each layer.
- """
-
- scaling_min_val = 2e-16
- bit_width = None
-
-
-class CommonIntWeightPerChannelQuant(CommonIntWeightPerTensorQuant):
- """
- Common per-channel weight quantizer with bit-width set to None so that it is forced to be
- specified by each layer.
- """
-
- scaling_per_output_channel = True
-
-
-class CommonWeightQuant(CommonQuant, WeightQuantSolver):
- scaling_const = 1.0
-
-
-class CommonActQuant(CommonQuant, ActQuantSolver):
- min_val = -1.0
- max_val = 1.0
-
-
-# -- Custom layer --
-
-
-class TensorNorm(nn.Module):
- def __init__(self, eps=1e-4, momentum=0.1):
- super().__init__()
-
- self.eps = eps
- self.momentum = momentum
- self.weight = nn.Parameter(torch.rand(1))
- self.bias = nn.Parameter(torch.rand(1))
- self.register_buffer("running_mean", torch.zeros(1))
- self.register_buffer("running_var", torch.ones(1))
- self.reset_running_stats()
-
- def reset_running_stats(self):
- self.running_mean.zero_()
- self.running_var.fill_(1)
- init.ones_(self.weight)
- init.zeros_(self.bias)
-
- def forward(self, x):
- if self.training:
- mean = x.mean()
- unbias_var = x.var(unbiased=True)
- biased_var = x.var(unbiased=False)
- self.running_mean = (
- 1 - self.momentum
- ) * self.running_mean + self.momentum * mean.detach()
- self.running_var = (
- 1 - self.momentum
- ) * self.running_var + self.momentum * unbias_var.detach()
- inv_std = 1 / (biased_var + self.eps).pow(0.5)
- return (x - mean) * inv_std * self.weight + self.bias
- else:
- return (
- (x - self.running_mean) / (self.running_var + self.eps).pow(0.5)
- ) * self.weight + self.bias
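A short usage sketch of `TensorNorm` (with made-up shapes), assuming the class as defined above: in training mode it normalizes with the current batch statistics and updates the running statistics via an exponential moving average; in evaluation mode it reuses the running statistics.

```python
import torch

norm = TensorNorm()
x = torch.randn(8, 10)

norm.train()
y_train = norm(x)  # normalized with x.mean() / x.var(); running stats updated

norm.eval()
y_eval = norm(x)  # normalized with the stored running_mean / running_var
```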
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/clear_module.pt b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/clear_module.pt
deleted file mode 100644
index 0cf721eab..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/clear_module.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:eda3f0ed14c2f066008c1bda58a578fc43cedb598d41c516396cca29151a66fa
-size 12928
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/clear_module.py b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/clear_module.py
deleted file mode 100644
index dadcc3654..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/clear_module.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import torch
-from brevitas.nn import QuantIdentity
-from brevitas_utils import CommonActQuant
-from constants import CNV_OUT_CH_POOL, KERNEL_SIZE, SPLIT_INDEX
-from torch.nn import AvgPool2d, BatchNorm2d, Conv2d
-
-
-# First layers of the model.
-# They will ultimately run in the clear on the client side.
-# They can use float operations or quantization.
-class ClearModule(torch.nn.Module):
- def __init__(
- self,
- in_ch: int,
- out_bit_width: int,
- ):
- super().__init__()
-
- self.in_ch = in_ch
- out_ch = in_ch
-
- self.conv_features = torch.nn.ModuleList()
- for out_ch, is_pool_enabled in CNV_OUT_CH_POOL[:SPLIT_INDEX]:
- self.conv_features.append(
- Conv2d(
- kernel_size=KERNEL_SIZE,
- in_channels=in_ch,
- out_channels=out_ch,
- bias=True,
- )
- )
- in_ch = out_ch
- self.conv_features.append(BatchNorm2d(in_ch, eps=1e-4))
- if is_pool_enabled:
- self.conv_features.append(AvgPool2d(kernel_size=2))
- self.out_ch = out_ch
-
- self.conv_features.append(
- QuantIdentity(
- act_quant=CommonActQuant,
- return_quant_tensor=False,
- bit_width=out_bit_width,
- min_val=-1.0,
- max_val=1.0 - 2.0 ** (-7),
- narrow_range=True,
- )
- )
-
- def forward(self, x):
- for mod in self.conv_features:
- x = mod(x)
- return x
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/constants.py b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/constants.py
deleted file mode 100644
index 92a92d38e..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/constants.py
+++ /dev/null
@@ -1,16 +0,0 @@
-CNV_OUT_CH_POOL = [
- (64, False),
- (64, True),
- (128, False),
- (128, True),
- (256, False),
- (256, False),
-]
-
-INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)]
-LAST_FC_IN_FEATURES = 512
-LAST_FC_PER_OUT_CH_SCALING = False
-POOL_SIZE = 2
-KERNEL_SIZE = 3
-EPSILON_VALUE = 0.5
-SPLIT_INDEX = 1
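For clarity, `SPLIT_INDEX` is the pivot that partitions the convolutional blocks between the two sub-modules: `clear_module.py` above slices `CNV_OUT_CH_POOL[:SPLIT_INDEX]`, and `encrypted_module.py` below slices `CNV_OUT_CH_POOL[SPLIT_INDEX:]`. A sketch of the resulting partition:

```python
# With SPLIT_INDEX = 1, the (out_channels, is_pool_enabled) blocks are divided as:
clear_blocks = CNV_OUT_CH_POOL[:SPLIT_INDEX]  # [(64, False)] -> ClearModule (client, clear)
fhe_blocks = CNV_OUT_CH_POOL[SPLIT_INDEX:]    # five remaining blocks -> EncryptedModule (server, FHE)
```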
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/encrypted_module.pt b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/encrypted_module.pt
deleted file mode 100644
index ba3b435e2..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/encrypted_module.pt
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7e6f9d2d4cb71ec669d4375db47ca7ad526d71b0c7b16902c938a8b73bedcc7e
-size 6248401
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/encrypted_module.py b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/encrypted_module.py
deleted file mode 100644
index b95025f34..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/encrypted_module.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import torch
-from brevitas.nn import QuantConv2d, QuantIdentity, QuantLinear
-from brevitas_utils import CommonActQuant, CommonWeightQuant, TensorNorm
-from constants import (
- CNV_OUT_CH_POOL,
- INTERMEDIATE_FC_FEATURES,
- KERNEL_SIZE,
- LAST_FC_IN_FEATURES,
- SPLIT_INDEX,
-)
-from torch.nn import AvgPool2d, BatchNorm1d, BatchNorm2d
-
-
-class EncryptedModule(torch.nn.Module):
- def __init__(
- self,
- num_classes: int,
- weight_bit_width: int,
- act_bit_width: int,
- in_bit_width: int,
- in_ch: int,
- ):
- super().__init__()
-
- self.num_classes = num_classes
- self.weight_bit_width = weight_bit_width
- self.act_bit_width = act_bit_width
-
- self.conv_features = torch.nn.ModuleList()
- self.linear_features = torch.nn.ModuleList()
-
- self.conv_features.append(
- QuantIdentity( # for Q1.7 input format
- act_quant=CommonActQuant,
- return_quant_tensor=True,
- bit_width=in_bit_width,
- min_val=-1.0,
- max_val=1.0 - 2.0 ** (-7),
- narrow_range=True,
- # restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
- )
- )
-
- for out_ch, is_pool_enabled in CNV_OUT_CH_POOL[SPLIT_INDEX:]:
- self.conv_features.append(
- QuantConv2d(
- kernel_size=KERNEL_SIZE,
- in_channels=in_ch,
- out_channels=out_ch,
- bias=False,
- weight_quant=CommonWeightQuant,
- weight_bit_width=weight_bit_width,
- )
- )
- in_ch = out_ch
- self.conv_features.append(BatchNorm2d(in_ch, eps=1e-4))
- self.conv_features.append(
- QuantIdentity(act_quant=CommonActQuant, bit_width=act_bit_width)
- )
- if is_pool_enabled:
- self.conv_features.append(AvgPool2d(kernel_size=2))
- self.conv_features.append(
- QuantIdentity(act_quant=CommonActQuant, bit_width=act_bit_width)
- )
-
- for in_features, out_features in INTERMEDIATE_FC_FEATURES:
- self.linear_features.append(
- QuantLinear(
- in_features=in_features,
- out_features=out_features,
- bias=False,
- weight_quant=CommonWeightQuant,
- weight_bit_width=weight_bit_width,
- )
- )
- self.linear_features.append(BatchNorm1d(out_features, eps=1e-4))
- self.linear_features.append(
- QuantIdentity(act_quant=CommonActQuant, bit_width=act_bit_width)
- )
-
- self.linear_features.append(
- QuantLinear(
- in_features=LAST_FC_IN_FEATURES,
- out_features=num_classes,
- bias=False,
- weight_quant=CommonWeightQuant,
- weight_bit_width=weight_bit_width,
- )
- )
- self.linear_features.append(TensorNorm())
-
- def forward(self, x):
- for mod in self.conv_features:
- x = mod(x)
- x = torch.flatten(x, 1)
- for mod in self.linear_features:
- x = mod(x)
- return x
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/infer_fhe.py b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/infer_fhe.py
deleted file mode 100644
index d2c14e717..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/infer_fhe.py
+++ /dev/null
@@ -1,181 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-import os
-import time
-from pathlib import Path
-from typing import List
-
-import torch
-import torchvision
-import torchvision.transforms as transforms
-from concrete.fhe import Circuit, Configuration
-from model import CNV
-
-from concrete.ml.deployment.fhe_client_server import FHEModelDev
-from concrete.ml.torch.compile import compile_brevitas_qat_model
-
-NUM_SAMPLES = int(os.environ.get("NUM_SAMPLES", 400))
-
-
-def main():
- model = CNV(num_classes=10, weight_bit_width=2, act_bit_width=2, in_bit_width=3, in_ch=3)
- loaded = torch.load(Path(__file__).parent / "8_bit_model.pt")
- model.load_state_dict(loaded["model_state_dict"])
- model = model.eval()
- IMAGE_TRANSFORM = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
- )
-
- try:
- train_set = torchvision.datasets.CIFAR10(
- root=".data/",
- train=True,
- download=False,
- transform=IMAGE_TRANSFORM,
- target_transform=None,
- )
-    except RuntimeError:
- train_set = torchvision.datasets.CIFAR10(
- root=".data/",
- train=True,
- download=True,
- transform=IMAGE_TRANSFORM,
- target_transform=None,
- )
-
- num_samples = 1000
- train_sub_set = torch.stack(
- [train_set[index][0] for index in range(min(num_samples, len(train_set)))]
- )
-
-    # Create a representative input-set that will be used for both computing quantization
- # parameters and compiling the model
- with torch.no_grad():
- train_features_sub_set = model.clear_module(train_sub_set)
-
-    # The multi-parameter strategy is used in order to speed up the FHE execution
- configuration = Configuration(show_optimizer=True)
-
- compilation_onnx_path = "compilation_model.onnx"
- print("Compiling the model ...")
- start_compile = time.time()
-
- # Compile the quantized model
- quantized_numpy_module = compile_brevitas_qat_model(
- torch_model=model.encrypted_module,
- torch_inputset=train_features_sub_set,
- configuration=configuration,
- p_error=0.05,
- output_onnx_file=compilation_onnx_path,
- )
- end_compile = time.time()
- print(f"Compilation finished in {end_compile - start_compile:.2f} seconds")
-
- # Save the graph and mlir
- print("Saving graph and mlir to disk.")
- open("cifar10.graph", "w").write(str(quantized_numpy_module.fhe_circuit))
- open("cifar10.mlir", "w").write(quantized_numpy_module.fhe_circuit.mlir)
-
- dev = FHEModelDev(path_dir="./client_server", model=quantized_numpy_module)
- dev.save()
-
- # Key generation
- print("Generating keys ...")
- start_keygen = time.time()
- assert isinstance(quantized_numpy_module.fhe_circuit, Circuit)
- quantized_numpy_module.fhe_circuit.keygen()
- end_keygen = time.time()
- print(f"Keygen finished in {end_keygen - start_keygen:.2f} seconds")
-
- # Initialize file
- inference_file = Path("inference_results.csv")
- open(inference_file, "w", encoding="utf-8").close()
-
- # Inference part
- columns: List[str] = []
- for image_index in range(NUM_SAMPLES):
- print("Infering ...")
- img, label = train_set[image_index] # Get the image
-
- # Clear extraction of the feature maps
- feature_extraction_start = time.time()
- with torch.no_grad():
- feature_maps = model.clear_module(img[None, :])
- feature_extraction_end = time.time()
- feature_extraction_time = feature_extraction_end - feature_extraction_start
-
- # Quantization of the feature maps
- quantization_start = time.time()
- quantized_feature_maps = quantized_numpy_module.quantize_input(feature_maps.numpy())
- quantization_end = time.time()
- quantization_time = quantization_end - quantization_start
-
- # Encryption of the feature maps
- encryption_start = time.time()
-        encrypted_feature_maps = quantized_numpy_module.fhe_circuit.encrypt(quantized_feature_maps)
- encryption_end = time.time()
- encryption_time = encryption_end - encryption_start
-
- # FHE computation
- fhe_start = time.time()
-        encrypted_output = quantized_numpy_module.fhe_circuit.run(encrypted_feature_maps)
- fhe_end = time.time()
- fhe_time = fhe_end - fhe_start
-
- # Decryption of the output
- decryption_start = time.time()
- quantized_output = quantized_numpy_module.fhe_circuit.decrypt(encrypted_output)
- decryption_end = time.time()
- decryption_time = decryption_end - decryption_start
-
- # De-quantization of the output
- dequantization_start = time.time()
- output = quantized_numpy_module.dequantize_output(quantized_output)
- dequantization_end = time.time()
- dequantization_time = dequantization_end - dequantization_start
-
- inference_time = dequantization_end - feature_extraction_start
-
- # Torch reference
- torch_start = time.time()
- with torch.no_grad():
- torch_output = model.encrypted_module(feature_maps).numpy()
- torch_end = time.time()
- torch_time = torch_end - torch_start
-
- # Dump everything in a csv
- to_dump = {
- "image_index": image_index,
- # Timings
- "feature_extraction_time": feature_extraction_time,
- "quantization_time": quantization_time,
- "encryption_time": encryption_time,
- "fhe_time": fhe_time,
- "decryption_time": decryption_time,
- "dequantization_time": dequantization_time,
- "inference_time": inference_time,
- "torch_time": torch_time,
- "label": label,
- }
-
- for prediction_index, prediction in enumerate(quantized_output[0]):
- to_dump[f"quantized_prediction_{prediction_index}"] = prediction
- for prediction_index, prediction in enumerate(output[0]):
- to_dump[f"prediction_{prediction_index}"] = prediction
- for prediction_index, prediction in enumerate(torch_output[0]):
- to_dump[f"torch_prediction_{prediction_index}"] = prediction
-
- # Write to file
- with open(inference_file, "a", encoding="utf-8") as file:
- if image_index == 0:
- columns = list(to_dump.keys())
- file.write(",".join(columns) + "\n")
- file.write(",".join(str(to_dump[column]) for column in columns) + "\n")
-
- print("Output:", output)
- print(f"FHE computation finished in {fhe_time:.2f} seconds")
- print(f"Full inference finished in {inference_time:.2f} seconds")
-
-
-if __name__ == "__main__":
- main()
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/infer_fhe_simulation.py b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/infer_fhe_simulation.py
deleted file mode 100644
index b537b9837..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/infer_fhe_simulation.py
+++ /dev/null
@@ -1,149 +0,0 @@
-"""Run the inference of the model using the FHE simulation mode."""
-import csv
-import random
-import time
-from pathlib import Path
-
-import numpy as np
-import torch
-import torchvision
-from brevitas import config
-from model import CNV
-from scipy.special import softmax
-from torch.backends import cudnn
-from torch.utils import data as torch_data
-from torchvision import transforms
-from tqdm import tqdm
-
-from concrete.ml.torch.compile import compile_brevitas_qat_model
-
-
-def seed_worker(worker_id: int):
- worker_seed = torch.initial_seed() % 2**32
- np.random.seed(worker_seed)
- random.seed(worker_seed)
-
-
-def main():
- config.IGNORE_MISSING_KEYS = True
- g = torch.Generator()
- g.manual_seed(0)
- np.random.seed(0)
- torch.manual_seed(0)
- random.seed(0)
- torch.use_deterministic_algorithms(True)
- cudnn.deterministic = True
-
- batch_size = 4
-
- IMAGE_TRANSFORM = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
- )
-
- try:
- train_set = torchvision.datasets.CIFAR10(
- root=".data/",
- train=True,
- download=False,
- transform=IMAGE_TRANSFORM,
- target_transform=None,
- )
- except RuntimeError:
- train_set = torchvision.datasets.CIFAR10(
- root=".data/",
- train=True,
- download=True,
- transform=IMAGE_TRANSFORM,
- target_transform=None,
- )
-
- test_set = torchvision.datasets.CIFAR10(
- root=".data/",
- train=False,
- download=False,
- transform=IMAGE_TRANSFORM,
- target_transform=None,
- )
-
- num_samples = 1000
- train_sub_set = torch.stack(
- [train_set[index][0] for index in range(min(num_samples, len(train_set)))]
- )
-
- testloader = torch_data.DataLoader(
- test_set,
- batch_size=batch_size,
- shuffle=False,
- num_workers=2,
- worker_init_fn=seed_worker,
- generator=g,
- )
-
- classes = (
- "plane",
- "car",
- "bird",
- "cat",
- "deer",
- "dog",
- "frog",
- "horse",
- "ship",
- "truck",
- )
-
- nb_steps = len(test_set) // batch_size
-
- checkpoint_path = Path(__file__).parent
- model_path = checkpoint_path / "8_bit_model.pt"
- loaded = torch.load(model_path)
-
- net = CNV(
- num_classes=len(classes), weight_bit_width=2, act_bit_width=2, in_bit_width=3, in_ch=3
- )
- net.load_state_dict(loaded["model_state_dict"])
- net.eval()
-
-    # Create a representative input-set that will be used for both computing quantization
- # parameters and compiling the model
- with torch.no_grad():
- train_features_sub_set = net.clear_module(train_sub_set)
-
- compilation_onnx_path = "compilation_model.onnx"
- print("Compiling the model")
- start_compile = time.time()
-
- # Compile the quantized model
- quantized_numpy_module = compile_brevitas_qat_model(
- torch_model=net.encrypted_module,
- torch_inputset=train_features_sub_set,
- output_onnx_file=compilation_onnx_path,
- )
-
- end_compile = time.time()
- print(f"Compilation finished in {end_compile - start_compile:.2f} seconds")
-
- prediction_file = checkpoint_path / "predictions_fhe_simulation.csv"
- with open(prediction_file, "w", newline="") as csv_file:
- csv_writer = csv.writer(csv_file, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL)
- csv_writer.writerow([f"{elt}_prob" for elt in classes] + ["label"])
-
- for _, data in (p_bar := tqdm(enumerate(testloader, 0), leave=False, total=nb_steps)):
- p_bar.set_description("Inference")
-
- # get the inputs; data is a list of [inputs, labels]
- inputs, labels = data
- with torch.no_grad():
-                # Forward pass through the clear module to compute the feature maps
-                feat_maps = net.clear_module(inputs)
-
-            outputs = softmax(quantized_numpy_module.forward(feat_maps.numpy(), fhe="simulate"), axis=1)
-
- for preds, label in zip(outputs, labels):
- csv_writer.writerow(preds.tolist() + [label.numpy().tolist()])
-
- print("Finished inference")
-
-
-if __name__ == "__main__":
- main()
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/infer_torch.py b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/infer_torch.py
deleted file mode 100644
index 5f4eb6bdc..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/infer_torch.py
+++ /dev/null
@@ -1,114 +0,0 @@
-"""Run the model using torch"""
-import csv
-import random
-from datetime import datetime
-from pathlib import Path
-
-import numpy as np
-import torch
-import torchvision
-from brevitas import config
-from model import CNV
-from torch.backends import cudnn
-from torch.utils import data as torch_data
-from torchvision import transforms
-from tqdm import tqdm
-
-DATE_FORMAT = "%Y_%m_%d_%H_%M_%S"
-
-
-def seed_worker(worker_id: int):
- worker_seed = torch.initial_seed() % 2**32
- np.random.seed(worker_seed)
- random.seed(worker_seed)
-
-
-def main():
- config.IGNORE_MISSING_KEYS = True
- g = torch.Generator()
- g.manual_seed(0)
- np.random.seed(0)
- torch.manual_seed(0)
- random.seed(0)
- torch.use_deterministic_algorithms(True)
- cudnn.deterministic = True
-
- batch_size = 4
-
- IMAGE_TRANSFORM = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
- )
-
- try:
- test_set = torchvision.datasets.CIFAR10(
- root=".data/",
- train=False,
- download=False,
- transform=IMAGE_TRANSFORM,
- target_transform=None,
- )
- except RuntimeError:
- test_set = torchvision.datasets.CIFAR10(
- root=".data/",
- train=False,
- download=True,
- transform=IMAGE_TRANSFORM,
- target_transform=None,
- )
-
- testloader = torch_data.DataLoader(
- test_set,
- batch_size=batch_size,
- shuffle=False,
- num_workers=2,
- worker_init_fn=seed_worker,
- generator=g,
- )
-
- classes = (
- "plane",
- "car",
- "bird",
- "cat",
- "deer",
- "dog",
- "frog",
- "horse",
- "ship",
- "truck",
- )
-
- nb_steps = len(test_set) // batch_size
-
- checkpoint_path = Path(__file__).parent
- model_path = checkpoint_path / "8_bit_model.pt"
- loaded = torch.load(model_path)
-
- net = CNV(
- num_classes=len(classes), weight_bit_width=2, act_bit_width=2, in_bit_width=3, in_ch=3
- )
- net.load_state_dict(loaded["model_state_dict"])
- net.eval()
-
- prediction_file = checkpoint_path / "predictions.csv"
- with open(prediction_file, "w", newline="") as csv_file:
- csv_writer = csv.writer(csv_file, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL)
- csv_writer.writerow([f"{elt}_prob" for elt in classes] + ["label"])
-
- for _, data in (p_bar := tqdm(enumerate(testloader, 0), leave=False, total=nb_steps)):
- p_bar.set_description("Inference")
-
- # get the inputs; data is a list of [inputs, labels]
- inputs, labels = data
- with torch.no_grad():
-                # Forward pass only (inference)
- outputs = net(inputs)
- outputs = torch.nn.functional.softmax(outputs, dim=1)
- for preds, label in zip(outputs, labels):
- csv_writer.writerow(preds.numpy().tolist() + [label.numpy().tolist()])
-
- print("Finished inference")
-
-
-if __name__ == "__main__":
- main()
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/model.py b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/model.py
deleted file mode 100644
index b8323f355..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/model.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# https://github.com/Xilinx/brevitas/blob/8c3d9de0113528cf6693c6474a13d802a66682c6/src/brevitas_examples/bnn_pynq/models/CNV.py
-import torch
-from clear_module import ClearModule
-from encrypted_module import EncryptedModule
-
-CNV_OUT_CH_POOL = [
- (64, False),
- (64, True),
- (128, False),
- (128, True),
- (256, False),
- (256, False),
-]
-
-INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)]
-LAST_FC_IN_FEATURES = 512
-LAST_FC_PER_OUT_CH_SCALING = False
-POOL_SIZE = 2
-KERNEL_SIZE = 3
-EPSILON_VALUE = 0.5
-SPLIT_INDEX = 1
-
-
-# The model combining both modules
-class CNV(torch.nn.Module):
- def __init__(
- self,
- num_classes: int,
- weight_bit_width: int,
- act_bit_width: int,
- in_bit_width: int,
- in_ch: int,
- ) -> None:
- super().__init__()
- self.num_classes = num_classes
- self.weight_bit_width = weight_bit_width
- self.act_bit_width = act_bit_width
- self.in_bit_width = in_bit_width
- self.in_ch = in_ch
-
- self.clear_module = ClearModule(in_ch=in_ch, out_bit_width=in_bit_width)
- self.encrypted_module = EncryptedModule(
- num_classes=num_classes,
- weight_bit_width=weight_bit_width,
- act_bit_width=act_bit_width,
- in_bit_width=in_bit_width,
- in_ch=self.clear_module.out_ch,
- )
-
- def forward(self, x):
- x = self.clear_module(x)
- return self.encrypted_module(x)
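A minimal sketch of exercising the split model end to end, mirroring what the client (`clear_module`) and server (`encrypted_module`) each run; this uses random weights and a random input, whereas the real scripts load `8_bit_model.pt`:

```python
import torch

from model import CNV

net = CNV(num_classes=10, weight_bit_width=2, act_bit_width=2, in_bit_width=3, in_ch=3)
net.eval()

x = torch.randn(1, 3, 32, 32)  # one CIFAR-10-shaped input
with torch.no_grad():
    features = net.clear_module(x)           # client side, float, in the clear
    logits = net.encrypted_module(features)  # server side (in FHE once compiled)
```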
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/p_error_search.py b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/p_error_search.py
deleted file mode 100644
index 8f66dea04..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/p_error_search.py
+++ /dev/null
@@ -1,154 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# Run p_error search for CIFAR-10 8-bits network.
-
-# Steps:
-# 1. Load CIFAR-10 8-bits model
-
-# 2. Load CIFAR-10 data-set
-
-# 3. Pre-process the data-set, in our case, that means:
-# - Reducing the data-set size (which we will call calibration data in our experiments)
-# - Computing the feature maps of the model on the client side
-
-# 4. Run the search for a given set of hyper-parameters
-# - The objective is to look for the largest `p_error = i`, with i ∈ ]0,0.9[ ∩ ℝ,
-# which gives a model_i that has `accuracy_i`, such that:
-#   |accuracy_i - accuracy_0| <= Threshold, where:
-# - Threshold is given by the user and
-# - `accuracy_0` refers to original model with `p_error ~ 0`
-
-# - Define our objective:
-# - If the objective is matched -> update the lower bound to be the current p-error
-# - Else, update the upper bound to be the current p-error
-# - Update the current p-error with the mean of the bounds
-
-# - The search terminates once it reaches the maximum number of iterations
-
-# - The inference is performed via the FHE simulation mode
-
-# `p_error` is bounded between 0 and 0.9
-# - `p_error ~ 0.0`, refers to the original model in clear, that gives an accuracy
-# that we note as `accuracy_0`
-# - By default, `lower = 0.0` and the `upper` bound is 0.9.
-
-import argparse
-
-import torch
-from model import CNV
-from sklearn.metrics import top_k_accuracy_score
-from torchvision import datasets, transforms
-
-from concrete.ml.pytest.utils import (
- data_calibration_processing,
- get_torchvision_dataset,
- load_torch_model,
-)
-from concrete.ml.search_parameters import BinarySearch
-
-DATASETS_ARGS = {
- "CIFAR10": {
- "dataset": datasets.CIFAR10,
- "train_transform": transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
- ]
- ),
- "test_transform": transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
- ),
- }
-}
-
-
-MODELS_ARGS = {
- "CIFAR10_8b": {
- "model_class": CNV,
- "path": "./use_case_examples/cifar_brevitas_with_model_splitting/8_bit_model.pt",
- "params": {
- "num_classes": 10,
- "weight_bit_width": 2,
- "act_bit_width": 2,
- "in_bit_width": 3,
- "in_ch": 3,
- },
- },
-}
-
-
-def main(args):
-
- if args.verbose:
- print(f"** Download `{args.dataset_name=}` dataset")
- dataset = get_torchvision_dataset(DATASETS_ARGS[args.dataset_name], train_set=True)
- x_calib, y = data_calibration_processing(dataset, n_sample=args.n_sample)
-
- if args.verbose:
- print(f"** Load `{args.model_name=}` network")
-
- checkpoint = torch.load(MODELS_ARGS[args.model_name]["path"], map_location=args.device)
- state_dict = checkpoint["model_state_dict"]
-
- model = load_torch_model(
- MODELS_ARGS[args.model_name]["model_class"],
- state_dict,
- MODELS_ARGS[args.model_name]["params"],
- )
- model.eval()
-
- with torch.no_grad():
- x_calib = model.clear_module(torch.tensor(x_calib)).numpy()
-
- model = model.encrypted_module
-
- if args.verbose:
- print("** `p_error` search")
-
- search = BinarySearch(
- estimator=model, predict="predict", metric=top_k_accuracy_score, verbose=args.verbose
- )
-
- p_error = search.run(x=x_calib, ground_truth=y, strategy=all)
-
- if args.verbose:
- print(f"Optimal p_error: {p_error}")
-
-
-if __name__ == "__main__":
-
- parser = argparse.ArgumentParser()
-
- parser.add_argument("--seed", type=int, default=42, help="Seed")
- parser.add_argument("--verbose", type=bool, default=True, help="Verbose")
- parser.add_argument("--n_sample", type=int, default=500, help="n_sample")
- parser.add_argument(
- "--dataset_name",
- type=str,
- default="CIFAR10",
- choices=["CIFAR10"],
- help="The selected dataset",
- )
- parser.add_argument(
- "--model_name",
- type=str,
- default="CIFAR10_8b",
- choices=["CIFAR10_8b"],
- help="The selected model",
- )
-
- args = parser.parse_args()
-
- # Add MPS (for macOS with Apple Silicon or AMD GPUs) support when error is fixed. For now, we
- # observe a decrease in torch's top1 accuracy when using MPS devices
- # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3953
- args.device = "cuda" if torch.cuda.is_available() else "cpu"
-
- main(args)
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/requirements.txt b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/requirements.txt
deleted file mode 100644
index 559968300..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/requirements.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-concrete-ml
-jupyter
-torchvision
-pandas
diff --git a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/split_model.py b/use_case_examples/cifar/cifar_brevitas_with_model_splitting/split_model.py
deleted file mode 100644
index eb01a415c..000000000
--- a/use_case_examples/cifar/cifar_brevitas_with_model_splitting/split_model.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from pathlib import Path
-
-import torch
-from model import CNV
-
-
-def main():
- checkpoint_path = Path(__file__).parent
- model_path = checkpoint_path / "8_bit_model.pt"
- loaded = torch.load(model_path)
- net = CNV(num_classes=10, weight_bit_width=2, act_bit_width=2, in_bit_width=3, in_ch=3)
- net.load_state_dict(loaded["model_state_dict"])
- torch.save(net.clear_module.state_dict(), checkpoint_path / "clear_module.pt")
- torch.save(net.encrypted_module.state_dict(), checkpoint_path / "encrypted_module.pt")
-
-
-if __name__ == "__main__":
- main()
diff --git a/use_case_examples/deployment/cifar_8_bit/Dockerfile.client b/use_case_examples/deployment/cifar_8_bit/Dockerfile.client
index 2382c5dce..53ae6763a 100644
--- a/use_case_examples/deployment/cifar_8_bit/Dockerfile.client
+++ b/use_case_examples/deployment/cifar_8_bit/Dockerfile.client
@@ -1,11 +1,7 @@
FROM zamafhe/concrete-ml
WORKDIR /project
COPY client.py .
-COPY clear_module.py .
-COPY constants.py .
-COPY brevitas_utils.py .
COPY client_requirements.txt .
-COPY clear_module.pt .
RUN python -m pip install -r client_requirements.txt
RUN python -m pip install torchvision==0.14.1 --no-deps
ENTRYPOINT /bin/bash
diff --git a/use_case_examples/deployment/cifar_8_bit/Dockerfile.compile b/use_case_examples/deployment/cifar_8_bit/Dockerfile.compile
index e22513f6a..dc29700f9 100644
--- a/use_case_examples/deployment/cifar_8_bit/Dockerfile.compile
+++ b/use_case_examples/deployment/cifar_8_bit/Dockerfile.compile
@@ -1,12 +1,12 @@
FROM zamafhe/concrete-ml
WORKDIR /project
-RUN python -m pip install torchvision==0.14.1 --no-deps
+COPY requirements.txt requirements.txt
+RUN python -m pip install -r requirements.txt
RUN python -m pip install requests
-COPY ./compile.py ./compile.py
-COPY ./encrypted_module.py ./encrypted_module.py
-COPY ./model.py ./model.py
-COPY ./clear_module.py ./clear_module.py
-COPY ./brevitas_utils.py ./brevitas_utils.py
-COPY ./constants.py ./constants.py
-COPY ./8_bit_model.pt ./8_bit_model.pt
+
+COPY models/ models/
+COPY experiments/ experiments/
+COPY compile.py compile.py
+
ENTRYPOINT python compile.py
diff --git a/use_case_examples/deployment/cifar_8_bit/Makefile b/use_case_examples/deployment/cifar_8_bit/Makefile
index 5043905b0..ad1c98316 100644
--- a/use_case_examples/deployment/cifar_8_bit/Makefile
+++ b/use_case_examples/deployment/cifar_8_bit/Makefile
@@ -14,4 +14,4 @@ two: one
@python -m concrete.ml.deployment.deploy_to_docker --only-build
three: two
- @python build_docker_client_image.py
+ @docker build --tag cifar_client -f Dockerfile.client .
diff --git a/use_case_examples/deployment/cifar_8_bit/README.md b/use_case_examples/deployment/cifar_8_bit/README.md
index d6357df28..f7c0ef9e4 100644
--- a/use_case_examples/deployment/cifar_8_bit/README.md
+++ b/use_case_examples/deployment/cifar_8_bit/README.md
@@ -1,7 +1,7 @@
# Deployment
In this folder we show how to deploy a Concrete ML model that classifies images from CIFAR-10, either through Docker or Amazon Web Services.
-We use the model showcased in [cifar split model](../../cifar/cifar_brevitas_with_model_splitting/README.md).
+We use the model showcased in [CIFAR QAT training from scratch](../../cifar/cifar_brevitas_training/README.md).
## Get started
diff --git a/use_case_examples/deployment/cifar_8_bit/build_docker_client_image.py b/use_case_examples/deployment/cifar_8_bit/build_docker_client_image.py
deleted file mode 100644
index 67a092cf7..000000000
--- a/use_case_examples/deployment/cifar_8_bit/build_docker_client_image.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import os
-import shutil
-import subprocess
-from pathlib import Path
-
-import torchvision
-import torchvision.transforms as transforms
-
-
-def main():
- path_of_script = Path(__file__).parent.resolve()
- IMAGE_TRANSFORM = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
- )
-
- # Load data
- try:
- train_set = torchvision.datasets.CIFAR10(
- root=path_of_script / "data",
- train=True,
- download=False,
- transform=IMAGE_TRANSFORM,
- target_transform=None,
- )
- except:
- train_set = torchvision.datasets.CIFAR10(
- root=path_of_script / "data",
- train=True,
- download=True,
- transform=IMAGE_TRANSFORM,
- target_transform=None,
- )
- del train_set
-
- files = ["clear_module.py", "constants.py", "brevitas_utils.py", "clear_module.pt"]
-
- # Copy files
- for file_name in files:
- source = Path(
- path_of_script / f"../../cifar_brevitas_with_model_splitting/{file_name}"
- ).resolve()
- target = Path(path_of_script / file_name).resolve()
- if not target.exists():
- print(f"{source} -> {target}")
- shutil.copyfile(src=source, dst=target)
-
- # Build image
- os.chdir(path_of_script)
- command = f'docker build --tag cml_client_cifar_10_8_bit --file "{path_of_script}/Dockerfile.client" .'
- print(command)
- subprocess.check_output(command, shell=True)
-
- # Remove files
- for file_name in files:
- target = Path(path_of_script / file_name).resolve()
- target.unlink()
-
-
-if __name__ == "__main__":
- main()
diff --git a/use_case_examples/deployment/cifar_8_bit/client.py b/use_case_examples/deployment/cifar_8_bit/client.py
index d84175d83..2b2cdbf51 100644
--- a/use_case_examples/deployment/cifar_8_bit/client.py
+++ b/use_case_examples/deployment/cifar_8_bit/client.py
@@ -13,17 +13,12 @@
import sys
from pathlib import Path
-# Append CIFAR-10 8-bit example
-PATH_TO_CIFAR_MODEL = Path(__file__).parent / "../../cifar_brevitas_with_model_splitting/"
-sys.path.append(str(PATH_TO_CIFAR_MODEL.resolve()))
-
import grequests
import numpy
import requests
import torch
import torchvision
import torchvision.transforms as transforms
-from clear_module import ClearModule
from concrete.ml.deployment import FHEModelClient
@@ -35,12 +30,6 @@
def main():
- # Load clear part of the model
- model = ClearModule(out_bit_width=3, in_ch=3)
- loaded = torch.load(Path(__file__).parent / "clear_module.pt")
- model.load_state_dict(loaded)
- model = model.eval()
-
# Load data
IMAGE_TRANSFORM = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
@@ -68,10 +57,6 @@ def main():
[train_set[index][0] for index in range(min(num_samples, len(train_set)))]
)
- # Pre-processing -> images -> feature maps
- with torch.no_grad():
- train_features_sub_set = model(train_sub_set)
-
# Get the necessary data for the client
# client.zip
zip_response = requests.get(f"{URL}/get_client")
@@ -80,7 +65,7 @@ def main():
file.write(zip_response.content)
# Get the data to infer
- X = train_features_sub_set[:1]
+ X = train_sub_set[:1]
# Create the client
client = FHEModelClient(path_dir="./", key_dir="./keys")
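With the clear pre-processing module gone, the client now encrypts raw CIFAR-10 images directly. A minimal sketch of the client-side round trip, assuming the deployment server is reachable as in the example and `X` holds a batch of images as in the updated client.py; the request/response handling is elided:

    from concrete.ml.deployment import FHEModelClient

    # client.zip was fetched from the server and saved to the current directory
    client = FHEModelClient(path_dir="./", key_dir="./keys")

    # Serializing the evaluation keys also generates the private keys locally
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()

    # Quantize, encrypt and serialize a raw image batch (no clear pre-processing anymore)
    encrypted_input = client.quantize_encrypt_serialize(X.numpy())

    # After the server evaluates the model in FHE and returns `encrypted_result`:
    # prediction = client.deserialize_decrypt_dequantize(encrypted_result)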
diff --git a/use_case_examples/deployment/cifar_8_bit/compile.py b/use_case_examples/deployment/cifar_8_bit/compile.py
index c2aea0430..f7b5eac83 100644
--- a/use_case_examples/deployment/cifar_8_bit/compile.py
+++ b/use_case_examples/deployment/cifar_8_bit/compile.py
@@ -7,7 +7,7 @@
import torchvision
import torchvision.transforms as transforms
from concrete.fhe import Configuration
-from model import CNV
+from models import cnv_2w2a
from concrete.ml.deployment import FHEModelDev
from concrete.ml.torch.compile import compile_brevitas_qat_model
@@ -15,10 +15,19 @@
def main():
# Load model
- model = CNV(num_classes=10, weight_bit_width=2, act_bit_width=2, in_bit_width=3, in_ch=3)
- loaded = torch.load(Path(__file__).parent / "8_bit_model.pt")
- model.load_state_dict(loaded["model_state_dict"])
- model = model.eval()
+
+ # Instantiate the model
+ model = cnv_2w2a(pre_trained=False)
+ model.eval()
+ # Load the saved parameters using the available checkpoint
+ checkpoint = torch.load(
+ Path(__file__).parent / "experiments/CNV_2W2A_2W2A_20221114_131345/checkpoints/best.tar",
+ map_location=torch.device("cpu"),
+ )
+ model.load_state_dict(checkpoint["state_dict"], strict=False)
IMAGE_TRANSFORM = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
@@ -47,19 +56,14 @@ def main():
[train_set[index][0] for index in range(min(num_samples, len(train_set)))]
)
- # Create a representative input-set that will be used for both computing quantization
- # parameters and compiling the model
- with torch.no_grad():
- train_features_sub_set = model.clear_module(train_sub_set)
-
compilation_onnx_path = "compilation_model.onnx"
print("Compiling the model ...")
start_compile = time.time()
# Compile the quantized model
quantized_numpy_module = compile_brevitas_qat_model(
- torch_model=model.encrypted_module,
- torch_inputset=train_features_sub_set,
+ torch_model=model,
+ torch_inputset=train_sub_set,
p_error=0.05,
output_onnx_file=compilation_onnx_path,
n_bits=8,
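The compilation step now targets the whole QAT network on raw images instead of only the encrypted half on pre-computed feature maps. A condensed sketch of the updated flow, assuming `model` and `train_sub_set` are built as in compile.py above; the `./dev` output directory is a hypothetical choice for illustration:

    from concrete.ml.deployment import FHEModelDev
    from concrete.ml.torch.compile import compile_brevitas_qat_model

    # Compile the full Brevitas QAT model; the input-set is used to calibrate
    # input/output quantization and to trace the computation graph
    quantized_numpy_module = compile_brevitas_qat_model(
        torch_model=model,
        torch_inputset=train_sub_set,
        p_error=0.05,
        n_bits=8,
    )

    # Save the deployment artifacts (client.zip / server.zip) for the server
    FHEModelDev(path_dir="./dev", model=quantized_numpy_module).save()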
diff --git a/use_case_examples/deployment/cifar_8_bit/compile_with_docker.py b/use_case_examples/deployment/cifar_8_bit/compile_with_docker.py
index d6132adc8..899a261d6 100644
--- a/use_case_examples/deployment/cifar_8_bit/compile_with_docker.py
+++ b/use_case_examples/deployment/cifar_8_bit/compile_with_docker.py
@@ -7,22 +7,21 @@
def main():
path_of_script = Path(__file__).parent.resolve()
files = [
- "encrypted_module.py",
- "clear_module.py",
- "model.py",
- "brevitas_utils.py",
- "8_bit_model.pt",
- "constants.py",
+ "models/__init__.py",
+ "models/cnv_2w2a.ini",
+ "models/common.py",
+ "models/model.py",
+ "models/tensor_norm.py",
+ "experiments/CNV_2W2A_2W2A_20221114_131345/checkpoints/best.tar",
]
# Copy files
for file_name in files:
- source = Path(
- path_of_script / f"../../cifar_brevitas_with_model_splitting/{file_name}"
- ).resolve()
+ source = Path(path_of_script / f"../../cifar/cifar_brevitas_training/{file_name}").resolve()
target = Path(path_of_script / file_name).resolve()
if not target.exists():
print(f"{source} -> {target}")
+ target.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(src=source, dst=target)
# Build image
diff --git a/use_case_examples/deployment/cifar_8_bit/requirements.txt b/use_case_examples/deployment/cifar_8_bit/requirements.txt
index e35531e56..d5fa2d7c9 100644
--- a/use_case_examples/deployment/cifar_8_bit/requirements.txt
+++ b/use_case_examples/deployment/cifar_8_bit/requirements.txt
@@ -1 +1,2 @@
-torchvision
+torchvision==0.14.1
+Pillow