diff --git a/docs/deep-learning/fhe_assistant.md b/docs/deep-learning/fhe_assistant.md
index 93d198675..63bba03ec 100644
--- a/docs/deep-learning/fhe_assistant.md
+++ b/docs/deep-learning/fhe_assistant.md
@@ -53,12 +53,60 @@ concrete_clf.fit(X, y)
 concrete_clf.compile(X, debug_config)
 ```

-## Compilation error debugging
+## Common compilation errors

-Compilation errors that signal that the ML model is not FHE compatible are usually of two types:
-
-1. TLU input maximum bit-width is exceeded
-1. No crypto-parameters can be found for the ML model: `RuntimeError: NoParametersFound` is raised by the compiler
+#### 1. TLU input maximum bit-width is exceeded
+
+**Error message**: `this [N]-bit value is used as an input to a table lookup`
+
+**Cause**: This error can occur when `rounding_threshold_bits` is not used and accumulated intermediate values in the computation exceed 16 bits.
+
+**Possible solutions**:
+
+- Reduce quantization `n_bits`. However, this may reduce accuracy. When quantization `n_bits` must be below 6, it is best to use [Quantization Aware Training](../deep-learning/fhe_friendly_models.md).
+- Use `rounding_threshold_bits`. This feature is described [here](../explanations/advanced_features.md#rounded-activations-and-quantizers). It is recommended to use the [`fhe.Exactness.APPROXIMATE`](../references/api/concrete.ml.torch.compile.md#function-compile_torch_model) setting and to set the rounding bits to 1 or 2 bits higher than the quantization `n_bits`.
+- Use [pruning](../explanations/pruning.md).
+
+#### 2. No crypto-parameters can be found
+
+**Error message**: `RuntimeError: NoParametersFound`
+
+**Cause**: This error occurs when using `rounding_threshold_bits` in the `compile_torch_model` function.
+
+**Possible solutions**: The solutions in this case are similar to the ones for the previous error.
+
+#### 3. Quantization import failed
+
+**Error message**: `Error occurred during quantization aware training (QAT) import [...] Could not determine a unique scale for the quantization!`
+
+**Cause**: This error occurs when a model imported as a quantization-aware training (QAT) model lacks quantization operators. See [this guide](../deep-learning/fhe_friendly_models.md) on how to use Brevitas layers. This error message indicates that some layers do not take inputs quantized through `QuantIdentity` layers.
+
+A common example is related to the concatenation operator. Suppose two tensors `x` and `y` are produced by two layers and need to be concatenated:
+
+```python
+x = self.dense1(x)
+y = self.dense2(y)
+z = torch.cat([x, y])
+```
+
+In the example above, the tensors `x` and `y` need to be quantized before being concatenated.
+
+**Possible solutions**:
+
+1. If the error occurs for the first layer of the model: Add a `QuantIdentity` layer in your model and apply it to the input of the `forward` function, before the first layer is computed.
+1. If the error occurs for a concatenation or addition layer: Add a new `QuantIdentity` layer in your model. Suppose it is called `quant_concat`. In the `forward` function, apply it to both `x` and `y` before they are concatenated. Using a common `QuantIdentity` layer for both tensors ensures that they share the same quantization scale:
+
+```python
+z = torch.cat([self.quant_concat(x), self.quant_concat(y)])
+```
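+
+For illustration, here is a minimal sketch of a model that applies both fixes. The class and layer names (`QATConcatNet`, `quant_input`, `dense1`, `dense2`, `quant_concat`) and the bit-widths are purely illustrative:
+
+```python
+import brevitas.nn as qnn
+import torch
+import torch.nn as nn
+
+
+class QATConcatNet(nn.Module):
+    def __init__(self, n_feat, n_hidden, n_bits=4):
+        super().__init__()
+        # Fix 1: quantize the raw input of the forward function
+        self.quant_input = qnn.QuantIdentity(bit_width=n_bits)
+        self.dense1 = qnn.QuantLinear(n_feat, n_hidden, bias=True, weight_bit_width=n_bits)
+        self.dense2 = qnn.QuantLinear(n_feat, n_hidden, bias=True, weight_bit_width=n_bits)
+        # Fix 2: a single QuantIdentity shared by both inputs of the concatenation
+        self.quant_concat = qnn.QuantIdentity(bit_width=n_bits)
+
+    def forward(self, x):
+        x = self.quant_input(x)
+        y = self.dense1(x)
+        z = self.dense2(x)
+        # Both tensors go through the same quantizer, so they share the same scale
+        return torch.cat([self.quant_concat(y), self.quant_concat(z)], dim=1)
+```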
+
+## Debugging compilation errors
+
+Compilation errors due to FHE-incompatible models, such as the maximum bit-width being exceeded or `NoParametersFound`, can be debugged by examining the bit-widths associated with the intermediate values of the FHE computation.

 The following produces a neural network that is not FHE-compatible:

@@ -116,6 +164,8 @@ Function you are trying to compile cannot be compiled:

 The error `this 17-bit value is used as an input to a table lookup` indicates that the 16-bit limit on the input of the Table Lookup (TLU) operation has been exceeded. To pinpoint the model layer that causes the error, Concrete ML provides the [bitwidth_and_range_report](../references/api/concrete.ml.quantization.quantized_module.md#method-bitwidth_and_range_report) helper function. First, the model must be compiled so that it can be [simulated](fhe_assistant.md#simulation).

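+For illustration, a minimal sketch of this debugging flow is shown below. The names `torch_model` and `torch_input` stand for the model and representative input-set being debugged, and the `n_bits` value is only an example:
+
+```python
+from concrete.ml.torch.compile import compile_torch_model
+
+# Compile with a low n_bits so that compilation succeeds and the module can be simulated
+quantized_module = compile_torch_model(torch_model, torch_input, n_bits=3)
+
+# Print the bit-widths and value ranges of each layer to locate the one that overflows
+print(quantized_module.bitwidth_and_range_report())
+```
+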
+`NoParametersFound`, on the other hand, is encountered when using `rounding_threshold_bits`. With this setting, the 16-bit accumulator limit is relaxed, but the compiler may still fail to find suitable crypto-parameters. Reducing the quantization bit-width, lowering `rounding_threshold_bits`, or using the [`fhe.Exactness.APPROXIMATE`](../references/api/concrete.ml.torch.compile.md#function-compile_torch_model) rounding method can help.
+
 ### Fixing compilation errors

 To make this network FHE-compatible one can apply several techniques:
diff --git a/docs/deep-learning/torch_support.md b/docs/deep-learning/torch_support.md
index d63754114..8c6497715 100644
--- a/docs/deep-learning/torch_support.md
+++ b/docs/deep-learning/torch_support.md
@@ -2,14 +2,21 @@

 In addition to the built-in models, Concrete ML supports generic machine learning models implemented with Torch, or [exported as ONNX graphs](onnx_support.md).

-As [Quantization Aware Training (QAT)](../explanations/quantization.md) is the most appropriate method of training neural networks that are compatible with [FHE constraints](../getting-started/concepts.md#model-accuracy-considerations-under-fhe-constraints), Concrete ML works with [Brevitas](../explanations/inner-workings/external_libraries.md#brevitas), a library providing QAT support for PyTorch.
+There are two approaches to building [FHE-compatible deep networks](../getting-started/concepts.md#model-accuracy-considerations-under-fhe-constraints):

-The following example uses a simple QAT PyTorch model that implements a fully connected neural network with two hidden layers. Due to its small size, making this model respect FHE constraints is relatively easy.
+- [Quantization Aware Training (QAT)](../explanations/quantization.md) requires using custom layers, but can quantize weights and activations to low bit-widths. Concrete ML works with [Brevitas](../explanations/inner-workings/external_libraries.md#brevitas), a library providing QAT support for PyTorch. To use this mode, compile models using `compile_brevitas_qat_model`.
+- **Post-training Quantization**: This mode allows a vanilla PyTorch model to be compiled. However, when quantizing weights and activations to fewer than 7 bits, accuracy can drop significantly. Conversely, depending on the model size, quantizing with 6-8 bits can be incompatible with FHE constraints. To use this mode, compile models with `compile_torch_model`.
+
+Both approaches require the `rounding_threshold_bits` parameter to be set appropriately. The best value for this parameter needs to be determined through experimentation; a good initial value to try is `6`. See [here](../explanations/advanced_features.md#rounded-activations-and-quantizers) for more details.

 {% hint style="info" %}
-Converting neural networks to use FHE can be done with `compile_brevitas_qat_model` or with `compile_torch_model` for post-training quantization. If the model can not be converted to FHE two types of errors can be raised: (1) crypto-parameters can not be found and, (2) table look-up bit-width limit is exceeded. See the [debugging section](fhe_assistant.md#compilation-error-debugging) if you encounter these errors.
+**See the [common compilation errors page](./fhe_assistant.md#common-compilation-errors) for an explanation of some error messages that the compilation function may raise.**
 {% endhint %}

+## Quantization-aware training
+
+The following example uses a simple QAT PyTorch model that implements a fully connected neural network with two hidden layers. Due to its small size, making this model respect FHE constraints is relatively easy. To use QAT, Brevitas `QuantIdentity` nodes must be inserted in the PyTorch model, including one that quantizes the input of the `forward` function.
+
 ```python
 import brevitas.nn as qnn
 import torch.nn as nn
@@ -51,38 +58,60 @@ torch_model = QATSimpleNet(30)
 quantized_module = compile_brevitas_qat_model(
     torch_model, # our model
     torch_input, # a representative input-set to be used for both quantization and compilation
+    rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
 )
 ```

-## Configuring quantization parameters
+{% hint style="warning" %}
+If `QuantIdentity` layers are missing for any input or intermediate value, the compile function will raise an error. See the [common compilation errors page](./fhe_assistant.md#common-compilation-errors) for an explanation.
+{% endhint %}

-The PyTorch/Brevitas models, created following the example above, require the user to configure quantization parameters such as `bit_width` (activation bit-width) and `weight_bit_width`. The quantization parameters, along with the number of neurons on each layer, will determine the accumulator bit-width of the network. Larger accumulator bit-widths result in higher accuracy but slower FHE inference time.
+## Post-training quantization

-The following configurations were determined through experimentation for convolutional and dense layers.
+The following example uses a simple PyTorch model that implements a fully connected neural network with two hidden layers. The model is compiled for FHE using `compile_torch_model`.
+
+```python
+import torch
+import torch.nn as nn
+
+from concrete.ml.torch.compile import compile_torch_model
+
+N_FEAT = 12
+n_bits = 6
+
+class PTQSimpleNet(nn.Module):
+    def __init__(self, n_hidden):
+        super().__init__()
+
+        self.fc1 = nn.Linear(N_FEAT, n_hidden)
+        self.fc2 = nn.Linear(n_hidden, n_hidden)
+        self.fc3 = nn.Linear(n_hidden, 2)
+
+    def forward(self, x):
+        x = torch.relu(self.fc1(x))
+        x = torch.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+

-| target accumulator bit-width | activation bit-width | weight bit-width | number of active neurons |
-| ---------------------------- | -------------------- | ---------------- | ------------------------ |
-| 8 | 3 | 3 | 80 |
-| 10 | 4 | 3 | 90 |
-| 12 | 5 | 5 | 110 |
-| 14 | 6 | 6 | 110 |
-| 16 | 7 | 6 | 120 |
+
+torch_input = torch.randn(100, N_FEAT)
+torch_model = PTQSimpleNet(5)
+quantized_module = compile_torch_model(
+    torch_model, # our model
+    torch_input, # a representative input-set to be used for both quantization and compilation
+    n_bits=n_bits,
+    rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
+)
+```

-Using the templates above, the probability of obtaining the target accumulator bit-width, for a single layer, was determined experimentally by training 10 models for each of the following data-sets.
+## Configuring quantization parameters

-| probability of obtaining<br>the accumulator bit-width | accuracy for target<br>accumulator bit-width |