diff --git a/use_case_examples/cifar_brevitas_training/README.md b/use_case_examples/cifar_brevitas_training/README.md
index 3c2c52b35..c83d1f26d 100644
--- a/use_case_examples/cifar_brevitas_training/README.md
+++ b/use_case_examples/cifar_brevitas_training/README.md
@@ -69,7 +69,7 @@ Testing with different rounding_threshold_bits values can help you understand th
 python3 evaluate_torch_cml.py --rounding_threshold_bits 1 2 3 4 5 6 7 8
 ```
 
-Using rounding with 6 bits for all accumulators provides a significant speedup for FHE, with only a 1.3% loss in accuracy compared to the original model. More details can be found in the Accuracy and Performance section below.
+Using rounding with 6 bits for all accumulators provides a significant speedup for FHE, with only a 2.7% loss in accuracy compared to the original model. More details can be found in the Accuracy and Performance section below.
 
 ## Fully Homomorphic Encryption (FHE)
 
@@ -93,11 +93,10 @@ Experiments were conducted on an m6i.metal machine offering 128 CPU cores and 51
 | ---------------------- | -------- | -------- |
 | VGG Torch              | None     | 88.7     |
 | VGG FHE (simulation\*) | None     | 88.7     |
-| VGG FHE (simulation\*) | 8 bits   | 88.3     |
-| VGG FHE (simulation\*) | 7 bits   | 88.3     |
-| VGG FHE (simulation\*) | 6 bits   | 87.5     |
-| VGG FHE (simulation\*) | 5 bits   | 84.9     |
-| VGG FHE                | 6 bits   | 87.5\*\* |
+| VGG FHE (simulation\*) | 8 bits   | 88.0     |
+| VGG FHE (simulation\*) | 7 bits   | 87.2     |
+| VGG FHE (simulation\*) | 6 bits   | 86.0     |
+| VGG FHE                | 6 bits   | 86.0\*\* |
 
 We ran the FHE inference over 10 examples and achieved 100% similar predictions between the simulation and FHE. The overall accuracy for the entire data-set is expected to match the simulation. The original model with a maximum of 13 bits of precision ran in around 9 hours on the specified hardware. Using the rounding approach, the final model ran in **31 minutes**, providing a speedup factor of 18x while preserving accuracy. This significant performance improvement demonstrates the benefits of the rounding operator in the FHE setting.
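For context on the accuracy numbers updated above: the rounding behaviour described in the README is controlled by the `rounding_threshold_bits` argument of Concrete ML's `compile_brevitas_qat_model`. Below is a minimal sketch of how a run such as `evaluate_torch_cml.py --rounding_threshold_bits 6` maps onto that compilation call; the model and input-set construction are illustrative placeholders, not the exact code from the script.

```python
# Illustrative sketch only: compiling a Brevitas QAT model with accumulator
# rounding in Concrete ML. `model` and `train_set` are placeholders.
import numpy as np

from concrete.ml.torch.compile import compile_brevitas_qat_model


def compile_with_rounding(model, train_set, rounding_threshold_bits=6):
    # A small representative input-set (here: 100 CIFAR-10 images) is used to
    # calibrate quantization parameters and compile the FHE circuit.
    input_set = np.stack([train_set[i][0].numpy() for i in range(100)])

    # rounding_threshold_bits rounds intermediate accumulators to the requested
    # bit-width; lower values speed up FHE at a small cost in accuracy.
    return compile_brevitas_qat_model(
        model,
        input_set,
        rounding_threshold_bits=rounding_threshold_bits,
    )
```

The simulation rows of the table above are typically obtained by running the compiled module with `fhe="simulate"`, while the final "VGG FHE" row corresponds to actual FHE execution.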
diff --git a/use_case_examples/cifar_brevitas_training/evaluate_torch_cml.py b/use_case_examples/cifar_brevitas_training/evaluate_torch_cml.py
index c55c2e230..95d6136d2 100644
--- a/use_case_examples/cifar_brevitas_training/evaluate_torch_cml.py
+++ b/use_case_examples/cifar_brevitas_training/evaluate_torch_cml.py
@@ -1,5 +1,5 @@
 import argparse
-import pathlib
+from pathlib import Path
 
 import numpy as np
 import torch
@@ -11,11 +11,13 @@
 from concrete.ml.torch.compile import compile_brevitas_qat_model
 
+CURRENT_DIR = Path(__file__).resolve().parent
+
 
 def evaluate(torch_model, cml_model, device, num_workers):
 
     # Import and load the CIFAR test dataset (following bnn_pynq_train.py)
-    test_set = get_test_set(dataset="CIFAR10", datadir=".datasets/")
+    test_set = get_test_set(dataset="CIFAR10", datadir=CURRENT_DIR / ".datasets/")
     test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=num_workers)
 
     torch_top_1_batches = []
@@ -76,7 +78,7 @@ def main(args):
     print("Device in use:", device)
 
     # Find relative path to this file
-    dir_path = pathlib.Path(__file__).parent.absolute()
+    dir_path = Path(__file__).parent.absolute()
 
     # Load checkpoint
     checkpoint = torch.load(
@@ -88,7 +90,7 @@ def main(args):
     model.load_state_dict(checkpoint["state_dict"], strict=False)
 
     # Load the training set
-    train_set = get_train_set(dataset="CIFAR10", datadir=".datasets/")
+    train_set = get_train_set(dataset="CIFAR10", datadir=CURRENT_DIR / ".datasets/")
 
     # Create a representative input-set from the training set that will be used used for both
     # computing quantization parameters and compiling the model
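The script changes above replace the working-directory-relative `.datasets/` path with one anchored to the script file itself, so the evaluation can be launched from any directory. A minimal sketch of that pattern is shown below; the `.datasets` name is simply the one used in this example, and `dataset_dir` is a hypothetical helper for illustration.

```python
# Sketch of the path-anchoring pattern introduced above: data directories are
# resolved relative to this file rather than the current working directory.
from pathlib import Path

CURRENT_DIR = Path(__file__).resolve().parent


def dataset_dir(name: str = ".datasets") -> Path:
    # Same result whether the script is launched from the repo root or elsewhere.
    return CURRENT_DIR / name


if __name__ == "__main__":
    print(dataset_dir())
```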