chore: add gpu device to cifar and resnet18 use cases
kcelia authored Sep 25, 2024
1 parent 4556ed3 commit b37ceed
Showing 9 changed files with 144 additions and 23 deletions.
33 changes: 31 additions & 2 deletions use_case_examples/cifar/cifar_brevitas_finetuning/CifarInFhe.ipynb
@@ -24,6 +24,7 @@
"source": [
"import warnings\n",
"\n",
"import concrete.compiler\n",
"import torch\n",
"from cifar_utils import (\n",
" fhe_compatibility,\n",
@@ -67,6 +68,34 @@
"print(f\"Device Type: {device}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concrete ML also supports a CUDA-enabled backend. To set it up, follow the instructions in the official [guide](../../../docs/guides/using_gpu.md) for installing the GPU-enabled Concrete compiler."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Is GPU enabled: False\n",
"Is GPU available: False\n"
]
}
],
"source": [
"compilation_device = \"cuda\" if concrete.compiler.check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Is GPU enabled: {concrete.compiler.check_gpu_enabled()}\")\n",
"print(f\"Is GPU available: {concrete.compiler.check_gpu_available()}\")"
]
},
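A note on the two checks above, as we read them (hedged; confirm against the Concrete documentation): `check_gpu_enabled()` reports whether the installed Concrete compiler was built with GPU support, while `check_gpu_available()` additionally requires a usable CUDA device at runtime. Only the latter gates the compilation target:

```python
import concrete.compiler

# Target the GPU backend only when the compiler build supports it
# AND a CUDA device is actually usable at runtime.
compilation_device = "cuda" if concrete.compiler.check_gpu_available() else "cpu"
```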
{
"cell_type": "markdown",
"metadata": {},
@@ -206,7 +235,7 @@
"\n",
"data_calibration, _ = next(iter(train_loader_c10))\n",
"\n",
"qmodel_c10 = fhe_compatibility(quant_vgg_c10, data_calibration)\n",
"qmodel_c10 = fhe_compatibility(quant_vgg_c10, data_calibration, device=compilation_device)\n",
"\n",
"print(\n",
" f\"Maximum bit-width in the circuit: {qmodel_c10.fhe_circuit.graph.maximum_integer_bit_width()}\"\n",
@@ -394,7 +423,7 @@
"\n",
"data_calibration, _ = next(iter(train_loader_c100))\n",
"\n",
"qmodel_c100 = fhe_compatibility(quant_vgg_c100, data_calibration)\n",
"qmodel_c100 = fhe_compatibility(quant_vgg_c100, data_calibration, device=compilation_device)\n",
"\n",
"print(\n",
" f\"Maximum bit-width in the circuit: {qmodel_c100.fhe_circuit.graph.maximum_integer_bit_width()}\"\n",
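The same two-step pattern recurs in each file of this commit: select the compilation target from the Concrete compiler's GPU check, then forward it to the compile call. A minimal sketch using the names from these notebooks (`quant_model` and `X_calib` are assumed to be prepared beforehand):

```python
import concrete.compiler
from concrete.ml.torch.compile import compile_brevitas_qat_model

# Compile for the GPU backend only when the installed compiler supports it.
compilation_device = "cuda" if concrete.compiler.check_gpu_available() else "cpu"

# `quant_model` is a Brevitas QAT model, `X_calib` its calibration inputs.
qmodel = compile_brevitas_qat_model(
    torch_model=quant_model,
    torch_inputset=X_calib,
    device=compilation_device,
)
```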
@@ -30,6 +30,7 @@
"import warnings\n",
"from typing import Callable, List, Tuple\n",
"\n",
"import concrete.compiler\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from cifar_utils import fhe_simulation_inference, get_dataloader, torch_inference\n",
@@ -62,6 +63,25 @@
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concrete ML also supports a CUDA-enabled backend. To set it up, follow the instructions in the official [guide](../../../docs/guides/using_gpu.md) for installing the GPU-enabled Concrete compiler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"compilation_device = \"cuda\" if concrete.compiler.check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Is GPU enabled: {concrete.compiler.check_gpu_enabled()}\")\n",
"print(f\"Is GPU available: {concrete.compiler.check_gpu_available()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -91,6 +111,7 @@
" model.to(\"cpu\"),\n",
" torch_inputset=X_train,\n",
" rounding_threshold_bits=max_bitwidth,\n",
" device=compilation_device,\n",
" )\n",
"\n",
" acc_fhe_s = fhe_simulation_inference(qmodel, test_loader, True)\n",
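`rounding_threshold_bits` appears in two forms in this commit: a plain integer, as in the cell above, and a dict that also selects the rounding exactness, as in the CIFAR training script later in this diff. A sketch of both under the same assumed names (`model`, `X_train`, `max_bitwidth`, `compilation_device`):

```python
from concrete.fhe import Exactness
from concrete.ml.torch.compile import compile_brevitas_qat_model

# Integer form: cap accumulator precision at `max_bitwidth` bits.
qmodel = compile_brevitas_qat_model(
    model.to("cpu"),
    torch_inputset=X_train,
    rounding_threshold_bits=max_bitwidth,
    device=compilation_device,
)

# Dict form: same cap, with an explicit approximate rounding method,
# mirroring the evaluation script further down in this diff.
qmodel = compile_brevitas_qat_model(
    model.to("cpu"),
    torch_inputset=X_train,
    rounding_threshold_bits={"method": Exactness.APPROXIMATE, "n_bits": 6},
    device=compilation_device,
)
```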
@@ -30,6 +30,7 @@
"metadata": {},
"outputs": [],
"source": [
"import concrete.compiler\n",
"import torch\n",
"from cifar_utils import (\n",
" fhe_compatibility,\n",
@@ -93,6 +94,25 @@
"print(f\"Device Type: {device}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concrete ML also supports a CUDA-enabled backend. To set it up, follow the instructions in the official [guide](../../../docs/guides/using_gpu.md) for installing the GPU-enabled Concrete compiler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"compilation_device = \"cuda\" if concrete.compiler.check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Is GPU enabled: {concrete.compiler.check_gpu_enabled()}\")\n",
"print(f\"Is GPU available: {concrete.compiler.check_gpu_available()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -280,7 +300,7 @@
"\n",
"data_calibration, _ = next(iter(train_loader_c100))\n",
"\n",
"qmodel = fhe_compatibility(quant_vgg, data_calibration)\n",
"qmodel = fhe_compatibility(quant_vgg, data_calibration, device=compilation_device)\n",
"\n",
"print(\n",
" f\"With {param_c100['dataset_name']}, the maximum bit-width in the circuit = \"\n",
@@ -544,7 +564,7 @@
"# Check the FHE-compatibility.\n",
"data, _ = next(iter(train_loader_c10))\n",
"\n",
"qmodel = fhe_compatibility(quant_vgg, data)\n",
"qmodel = fhe_compatibility(quant_vgg, data, device=compilation_device)\n",
"\n",
"print(\n",
" f\"With {param_c10['dataset_name']}, the circuit has a maximum bit-width of \"\n",
@@ -32,6 +32,7 @@
"from itertools import chain\n",
"from time import time\n",
"\n",
"import concrete.compiler\n",
"import matplotlib.pylab as plt\n",
"import numpy\n",
"import torch\n",
@@ -78,6 +79,25 @@
"print(f\"Device Type: {device}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concrete ML also supports a CUDA-enabled backend. To set it up, follow the instructions in the official [guide](../../../docs/guides/using_gpu.md) for installing the GPU-enabled Concrete compiler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"compilation_device = \"cuda\" if concrete.compiler.check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Is GPU enabled: {concrete.compiler.check_gpu_enabled()}\")\n",
"print(f\"Is GPU available: {concrete.compiler.check_gpu_available()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -245,7 +265,10 @@
"\n",
" start_time = time()\n",
" qmodel = compile_brevitas_qat_model(\n",
" torch_model=quant_model, torch_inputset=X_calib, p_error=p_error\n",
" torch_model=quant_model,\n",
" torch_inputset=X_calib,\n",
" p_error=p_error,\n",
" device=compilation_device,\n",
" )\n",
" compilation_time.append((time() - start_time) / 60.0)\n",
"\n",
@@ -353,7 +376,10 @@
"\n",
"# Compile the model with the optimal `p_error`\n",
"qmodel = compile_brevitas_qat_model(\n",
" torch_model=quant_model, torch_inputset=X_calib, p_error=largest_p_error\n",
" torch_model=quant_model,\n",
" torch_inputset=X_calib,\n",
" p_error=largest_p_error,\n",
" device=compilation_device,\n",
")\n",
"\n",
"# Key Generation\n",
@@ -4,7 +4,6 @@
import warnings
from collections import OrderedDict
from pathlib import Path
-from time import time
from typing import Callable, Dict, Optional, Tuple

import matplotlib.pyplot as plt
@@ -14,7 +13,6 @@
from brevitas import config
from concrete.fhe.compilation import Configuration
from models import Fp32VGG11
-from sklearn.metrics import top_k_accuracy_score
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
@@ -441,12 +439,13 @@ def torch_inference(
return np.mean(np.vstack(correct), dtype="float64")


-def fhe_compatibility(model: Callable, data: DataLoader) -> Callable:
+def fhe_compatibility(model: Callable, data: DataLoader, device: str) -> Callable:
"""Test if the model is FHE-compatible.
Args:
model (Callable): The Brevitas model.
data (DataLoader): The data loader.
device (str): The device on which to run the compilation, either 'cpu' or 'cuda'.
Returns:
Callable: Quantized model.
@@ -458,6 +457,7 @@ def fhe_compatibility(model: Callable, data: DataLoader) -> Callable:
torch_inputset=data,
show_mlir=False,
output_onnx_file="test.onnx",
device=device,
)

return qmodel
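For reference, a call site matching the notebooks above (names taken from those notebooks; `compilation_device` comes from the GPU check):

```python
# Compile-check the quantized model, targeting the selected device.
qmodel = fhe_compatibility(quant_vgg, data_calibration, device=compilation_device)

print(f"Maximum bit-width in the circuit: {qmodel.fhe_circuit.graph.maximum_integer_bit_width()}")
```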
@@ -6,6 +6,7 @@
from pathlib import Path

import torch
from concrete.compiler import check_gpu_available
from concrete.fhe import Exactness
from concrete.fhe.compilation.configuration import Configuration
from models import cnv_2w2a
@@ -22,6 +23,8 @@
# observe a decrease in torch's top1 accuracy when using MPS devices
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3953
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
COMPILATION_DEVICE = "cuda" if check_gpu_available() else "cpu"

NUM_SAMPLES = int(os.environ.get("NUM_SAMPLES", 1))
P_ERROR = float(os.environ.get("P_ERROR", 0.01))

@@ -93,6 +96,7 @@ def wrapper(*args, **kwargs):
configuration=configuration,
rounding_threshold_bits={"method": Exactness.APPROXIMATE, "n_bits": 6},
p_error=P_ERROR,
device=COMPILATION_DEVICE,
)
assert isinstance(quantized_numpy_module, QuantizedModule)
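Note that the script keeps two independent device choices: `DEVICE`, from `torch.cuda.is_available()`, is where the torch model runs, while `COMPILATION_DEVICE`, from `check_gpu_available()`, is the backend the FHE circuit is compiled for. The two can legitimately differ, for example on a machine with a CUDA-enabled torch but a CPU-only Concrete compiler build. A minimal sketch of the pairing:

```python
import torch
from concrete.compiler import check_gpu_available

# Device for the torch model (calibration and clear inference).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Backend targeted by the Concrete compiler for FHE execution; "cuda"
# requires the GPU-enabled compiler build (see the GPU guide).
COMPILATION_DEVICE = "cuda" if check_gpu_available() else "cpu"
```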

@@ -1,6 +1,7 @@
import argparse
from pathlib import Path

import concrete.compiler
import numpy as np
import torch
from concrete.fhe import Configuration
@@ -74,8 +75,14 @@ def main(args):
# observe a decrease in torch's top1 accuracy when using MPS devices
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3953
device = "cuda" if torch.cuda.is_available() else "cpu"
compilation_device = "cuda" if concrete.compiler.check_gpu_available() else "cpu"

print("Device in use:", device)
print("Torch device in use:", device)
print(
"To leverage the CUDA backend, follow the GPU setup guide to install the Concrete ML compiler."
)
print("GPU Enabled:", concrete.compiler.check_gpu_enabled())
print("GPU Available:", concrete.compiler.check_gpu_available())

# Find relative path to this file
dir_path = Path(__file__).parent.absolute()
@@ -123,6 +130,7 @@ def main(args):
if rounding_threshold_bits is not None
else None
),
device=compilation_device,
)

# Print max bit-width in the circuit
16 changes: 8 additions & 8 deletions use_case_examples/resnet/README.md
@@ -104,14 +104,14 @@ GPU machine: 8xH100 GPU machine

Summary of the accuracy evaluation on ImageNet (100 images):

-| w&a bits | p_error | Accuracy | Top-5 Accuracy | Runtime* | Device |
-| -------- | ------- | -------- | -------------- | --------------- | ------ |
-| fp32 | - | 67% | 87% | - | - |
-| 6/6 | 0.05 | 55% | 78% | 56 min | GPU |
-| 6/6 | 0.05 | 55% | 78% | 1 h 31 min | CPU |
-| 7/7 | 0.05 | **66%** | **87%** | **2 h 12 min** | CPU |
-
-*Runtime reported to run the inference on a single image
+| w&a bits | p_error | Accuracy | Top-5 Accuracy | Runtime\* | Device |
+| -------- | ------- | -------- | -------------- | -------------- | ------ |
+| fp32 | - | 67% | 87% | - | - |
+| 6/6 | 0.05 | 55% | 78% | 56 min | GPU |
+| 6/6 | 0.05 | 55% | 78% | 1 h 31 min | CPU |
+| 7/7 | 0.05 | **66%** | **87%** | **2 h 12 min** | CPU |
+
+\*Runtime reported to run the inference on a single image

6/6 `n_bits` configuration: {"model_inputs": 8, "op_inputs": 6, "op_weights": 6, "model_outputs": 9}
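That `n_bits` dict plugs directly into the compile call's quantization argument. A hedged sketch (the `resnet18` model, `images` calibration tensor, and `compilation_device` are assumptions here, following the pattern this commit introduces):

```python
from concrete.ml.torch.compile import compile_torch_model

# 6/6 configuration from the table above, at the p_error it was measured with.
qmodel = compile_torch_model(
    resnet18,
    images,
    n_bits={"model_inputs": 8, "op_inputs": 6, "op_weights": 6, "model_outputs": 9},
    p_error=0.05,
    device=compilation_device,
)
```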
