Commit

docs: update cifar_brevitas_training accuracy using representative calibration set
RomanBredehoft committed Sep 12, 2023
1 parent 65d55fb commit 39480ef
Showing 2 changed files with 11 additions and 10 deletions.
11 changes: 5 additions & 6 deletions use_case_examples/cifar_brevitas_training/README.md
@@ -69,7 +69,7 @@ Testing with different rounding_threshold_bits values can help you understand th
python3 evaluate_torch_cml.py --rounding_threshold_bits 1 2 3 4 5 6 7 8
```

-Using rounding with 6 bits for all accumulators provides a significant speedup for FHE, with only a 1.3% loss in accuracy compared to the original model. More details can be found in the Accuracy and Performance section below.
+Using rounding with 6 bits for all accumulators provides a significant speedup for FHE, with only a 2.7% loss in accuracy compared to the original model. More details can be found in the Accuracy and Performance section below.
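
As a rough illustration of how this rounding option is applied at compile time, the sketch below builds a toy Brevitas model and compiles it with `rounding_threshold_bits=6`. The toy network, calibration data, and variable names are placeholders, not the actual VGG training code from this use case.

```python
import numpy as np
import torch.nn as nn
import brevitas.nn as qnn

from concrete.ml.torch.compile import compile_brevitas_qat_model

# Toy QAT network standing in for the VGG-style model trained in this use case
model = nn.Sequential(
    nn.Flatten(),
    qnn.QuantIdentity(bit_width=3, return_quant_tensor=True),
    qnn.QuantLinear(3 * 32 * 32, 10, bias=True, weight_bit_width=3),
)

# Small random calibration set shaped like CIFAR-10 inputs
calib_data = np.random.rand(50, 3, 32, 32).astype(np.float32)

# Compile while rounding all accumulators to 6 bits
quantized_module = compile_brevitas_qat_model(
    model,
    calib_data,
    rounding_threshold_bits=6,
)
```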

## Fully Homomorphic Encryption (FHE)

@@ -93,11 +93,10 @@ Experiments were conducted on an m6i.metal machine offering 128 CPU cores and 51
| ---------------------- | -------- | -------- |
| VGG Torch | None | 88.7 |
| VGG FHE (simulation\*) | None | 88.7 |
-| VGG FHE (simulation\*) | 8 bits | 88.3 |
-| VGG FHE (simulation\*) | 7 bits | 88.3 |
-| VGG FHE (simulation\*) | 6 bits | 87.5 |
-| VGG FHE (simulation\*) | 5 bits | 84.9 |
-| VGG FHE | 6 bits | 87.5\*\* |
+| VGG FHE (simulation\*) | 8 bits | 88.0 |
+| VGG FHE (simulation\*) | 7 bits | 87.2 |
+| VGG FHE (simulation\*) | 6 bits | 86.0 |
+| VGG FHE | 6 bits | 86.0\*\* |

We ran the FHE inference over 10 examples and the FHE predictions matched the simulation on all of them. The overall accuracy for the entire dataset is expected to match the simulation. The original model, with a maximum of 13 bits of precision, ran in around 9 hours on the specified hardware. Using the rounding approach, the final model ran in **31 minutes**, providing a speedup factor of 18x while largely preserving accuracy. This significant performance improvement demonstrates the benefits of the rounding operator in the FHE setting.
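
To see how simulation and real FHE execution are compared, a minimal continuation of the sketch above is shown below; it assumes the compiled module's `forward` accepts an `fhe` keyword (`"simulate"` / `"execute"`), as in recent Concrete ML releases.

```python
# Compare simulated and actual FHE predictions on a few samples
# (the README does this over 10 test images)
samples = calib_data[:3]

y_sim = quantized_module.forward(samples, fhe="simulate")
y_fhe = quantized_module.forward(samples, fhe="execute")  # slow: real FHE execution

agreement = (y_sim.argmax(axis=1) == y_fhe.argmax(axis=1)).mean()
print(f"Top-1 agreement between simulation and FHE: {agreement:.0%}")
```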

10 changes: 6 additions & 4 deletions use_case_examples/cifar_brevitas_training/evaluate_torch_cml.py
@@ -1,5 +1,5 @@
import argparse
-import pathlib
+from pathlib import Path

import numpy as np
import torch
@@ -11,11 +11,13 @@

from concrete.ml.torch.compile import compile_brevitas_qat_model

+CURRENT_DIR = Path(__file__).resolve().parent


def evaluate(torch_model, cml_model, device, num_workers):

# Import and load the CIFAR test dataset (following bnn_pynq_train.py)
-test_set = get_test_set(dataset="CIFAR10", datadir=".datasets/")
+test_set = get_test_set(dataset="CIFAR10", datadir=CURRENT_DIR / ".datasets/")
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=num_workers)
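
The switch to `CURRENT_DIR / ".datasets/"` makes the dataset location independent of the working directory the script is launched from. A standalone sketch of the pattern (illustrative, not the repository's code):

```python
from pathlib import Path

# Resolve data paths relative to this file rather than the current working
# directory, so the script finds .datasets/ no matter where it is launched from
CURRENT_DIR = Path(__file__).resolve().parent
data_dir = CURRENT_DIR / ".datasets"
print(data_dir)
```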

torch_top_1_batches = []
@@ -76,7 +78,7 @@ def main(args):
print("Device in use:", device)

# Find relative path to this file
-dir_path = pathlib.Path(__file__).parent.absolute()
+dir_path = Path(__file__).parent.absolute()

# Load checkpoint
checkpoint = torch.load(
@@ -88,7 +90,7 @@
model.load_state_dict(checkpoint["state_dict"], strict=False)

# Load the training set
-train_set = get_train_set(dataset="CIFAR10", datadir=".datasets/")
+train_set = get_train_set(dataset="CIFAR10", datadir=CURRENT_DIR / ".datasets/")

# Create a representative input-set from the training set that will be used for both
# computing quantization parameters and compiling the model
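
As a sketch of how such a representative input-set can be drawn from the training set (the sampling helper, subset size, and variable names below are illustrative assumptions, not this script's code):

```python
import numpy as np
import torch

# Hypothetical helper: sample a small, representative calibration set from a
# torchvision-style dataset whose items are (image_tensor, label) pairs
def make_calibration_set(train_set, num_samples=100, seed=0):
    rng = np.random.default_rng(seed)
    indices = rng.choice(len(train_set), size=num_samples, replace=False)
    images = torch.stack([train_set[int(i)][0] for i in indices])
    return images.numpy()

# calibration_data = make_calibration_set(train_set)
# quantized_module = compile_brevitas_qat_model(model, calibration_data, ...)
```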
