Skip to content

Commit

Permalink
docs: improve torch support explanations
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Jun 13, 2024
1 parent 346f997 commit becad10
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 64 deletions.
39 changes: 29 additions & 10 deletions docs/deep-learning/fhe_assistant.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,33 +52,52 @@ concrete_clf.fit(X, y)

concrete_clf.compile(X, debug_config)
```

## Common compilation errors

The most common compilation errors stem from the following causes:

1. TLU input maximum bit-width is exceeded
#### 1. TLU input maximum bit-width is exceeded

This error can occur when `rounding_threshold_bits` is not used and accumulated intermediate values in the computation exceed 16-bits. The most common approaches to fix this issue are:
- Reduce quantization `n_bits`. However, this may reduce accuracy. When quantization `n_bits` must be below 6, it is best to use [Quantization Aware Training](../deep-learning/fhe_friendly_models.md).
- Use `rounding_threshold_bits`. This feature is described [here](../explanations/advanced_features.md#rounded-activations-and-quantizers). It is recommended to use the [`fhe.Exactness.APPROXIMATE`](../references/api/concrete.ml.torch.compile.md#function-compile_torch_model) setting, and set the rounding bits to 1 or 2 bits higher than the quantization `n_bits`
- Use [pruning](../explanations/pruning.md)

1. No crypto-parameters can be found for the ML model: `RuntimeError: NoParametersFound` is raised by the compiler
- Reduce quantization `n_bits`. However, this may reduce accuracy. When quantization `n_bits` must be below 6, it is best to use [Quantization Aware Training](../deep-learning/fhe_friendly_models.md).
- Use `rounding_threshold_bits`. This feature is described [here](../explanations/advanced_features.md#rounded-activations-and-quantizers). It is recommended to use the [`fhe.Exactness.APPROXIMATE`](../references/api/concrete.ml.torch.compile.md#function-compile_torch_model) setting, and set the rounding bits to 1 or 2 bits higher than the quantization `n_bits`
- Use [pruning](../explanations/pruning.md)

#### 2. No crypto-parameters can be found for the ML model: `RuntimeError: NoParametersFound` is raised by the compiler

This error occurs when using `rounding_threshold_bits` in the `compile_torch_model` function. The solutions in this case are similar to the ones for the previous error.

1. Quantization failed with `Could not determine a unique scale for the quantization!`.
#### 3. Quantization import failed with

The error associated is `Error occurred during quantization aware training (QAT) import [...] Could not determine a unique scale for the quantization!`.

This error is a due to missing quantization operators in the model that is imported as a quantized aware training model. See [this guide](../deep-learning/fhe_friendly_models.md) for a guide on how to use Brevitas layers. This error message is generated when not all layers take inputs that are quantized through `QuantIdentity` layers.

A common example is related to the concatenation operator. Suppose two tensors `x` and `y` are produced by two layers and need to be concatenated:

<!--pytest-codeblocks:skip-->

```python
x = self.dense1(x)
y = self.dense2(y)
z = torch.cat([x, y])
```

This error is a related to the concatenation operator. When using quantization aware training with Brevitas the following approach will fix this error:
In the example above, the `x` and `y` layers need quantization before being concatenated. When using quantization aware training with Brevitas the following approach will fix this error:

1. Add a new `QuantIdentity` layer in your model. Suppose it is called `quant_concat`.
2. In the `forward` function, before concatenation of `x` and `y`, apply it to both tensors that are concatenated:
1. Add a new `QuantIdentity` layer in your model. Suppose it is called `quant_concat`.
1. In the `forward` function, before concatenation of `x` and `y`, apply it to both tensors that are concatenated:

<!--pytest-codeblocks:skip-->

```python
torch.cat([self.quant_concat(x), self.quant_concat(y)])
z = torch.cat([self.quant_concat(x), self.quant_concat(y)])
```

The usage of a common `Quantidentity` layer to quantize both tensors that are concatenated ensures that they have the same scale.

## Debugging compilation errors

Compilation errors due to FHE incompatible models, such as maximum bit-width exceeded or `NoParametersFound` can be debugged by examining the bit-widths associated with various intermediate values of the FHE computation.
Expand Down
99 changes: 46 additions & 53 deletions docs/deep-learning/torch_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@

In addition to the built-in models, Concrete ML supports generic machine learning models implemented with Torch, or [exported as ONNX graphs](onnx_support.md).

As [Quantization Aware Training (QAT)](../explanations/quantization.md) is the most appropriate method of training neural networks that are compatible with [FHE constraints](../getting-started/concepts.md#model-accuracy-considerations-under-fhe-constraints), Concrete ML works with [Brevitas](../explanations/inner-workings/external_libraries.md#brevitas), a library providing QAT support for PyTorch.
There are two approaches to build [FHE-compatible deep networks](../getting-started/concepts.md#model-accuracy-considerations-under-fhe-constraints):

The following example uses a simple QAT PyTorch model that implements a fully connected neural network with two hidden layers. Due to its small size, making this model respect FHE constraints is relatively easy.
- [Quantization Aware Training (QAT)](../explanations/quantization.md) requires using custom layers, but can quantize weights and activations to low bit-widths. Concrete ML works with [Brevitas](../explanations/inner-workings/external_libraries.md#brevitas), a library providing QAT support for PyTorch. To use this mode, compile models using `compile_brevitas_qat_model`
- Post-training Quantization: in this mode a vanilla PyTorch model can be compiled. However, when quantizing weights & activations to fewer than 7 bits the accuracy can decrease strongly. To use this mode, compile models with `compile_torch_model`.

Both approaches should be used with the `rounding_threshold_bits` parameter set accordingly. The best values for this parameter need to be determined through experimentation. A good initial value to try is `6`. See [here](../explanations/advanced_features.md#rounded-activations-and-quantizers) for more details.

{% hint style="info" %}
Converting neural networks to use FHE can be done with `compile_brevitas_qat_model` or with `compile_torch_model` for post-training quantization. If the model can not be converted to FHE two types of errors can be raised: (1) crypto-parameters can not be found and, (2) table look-up bit-width limit is exceeded. See the [debugging section](fhe_assistant.md#compilation-error-debugging) if you encounter these errors.
**See the [common compilation errors page](./fhe_assistant.md#common-compilation-errors) for an explanation of some error messages that the compilation function may raise.**
{% endhint %}

## Quantization-aware training

The following example uses a simple QAT PyTorch model that implements a fully connected neural network with two hidden layers. Due to its small size, making this model respect FHE constraints is relatively easy.

```python
import brevitas.nn as qnn
import torch.nn as nn
Expand Down Expand Up @@ -51,38 +58,52 @@ torch_model = QATSimpleNet(30)
quantized_module = compile_brevitas_qat_model(
torch_model, # our model
torch_input, # a representative input-set to be used for both quantization and compilation
rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
)

```

## Configuring quantization parameters
## Post-training quantization

The PyTorch/Brevitas models, created following the example above, require the user to configure quantization parameters such as `bit_width` (activation bit-width) and `weight_bit_width`. The quantization parameters, along with the number of neurons on each layer, will determine the accumulator bit-width of the network. Larger accumulator bit-widths result in higher accuracy but slower FHE inference time.
The following example uses a simple PyTorch model that implements a fully connected neural network with two hidden layers. The model is compiled to use FHE using `compile_torch_model`.

```python
import torch.nn as nn
import torch

The following configurations were determined through experimentation for convolutional and dense layers.
N_FEAT = 12
n_bits = 6

| target accumulator bit-width | activation bit-width | weight bit-width | number of active neurons |
| ---------------------------- | -------------------- | ---------------- | ------------------------ |
| 8 | 3 | 3 | 80 |
| 10 | 4 | 3 | 90 |
| 12 | 5 | 5 | 110 |
| 14 | 6 | 6 | 110 |
| 16 | 7 | 6 | 120 |
class PTQSimpleNet(nn.Module):
def __init__(self, n_hidden):
super().__init__()

Using the templates above, the probability of obtaining the target accumulator bit-width, for a single layer, was determined experimentally by training 10 models for each of the following data-sets.
self.fc1 = nn.Linear(N_FEAT, n_hidden)
self.fc2 = nn.Linear(n_hidden, n_hidden)
self.fc3 = nn.Linear(n_hidden, 2)

| <p>probability of obtaining<br>the accumulator bit-width</p> | 8 | 10 | 12 | 14 | 16 |
| ------------------------------------------------------------ | --- | ---- | --- | --- | ---- |
| mnist,fashion | 72% | 100% | 72% | 85% | 100% |
| cifar10 | 88% | 88% | 75% | 75% | 88% |
| cifar100 | 73% | 88% | 61% | 66% | 100% |
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x

from concrete.ml.torch.compile import compile_torch_model
import numpy

torch_input = torch.randn(100, N_FEAT)
torch_model = PTQSimpleNet(5)
quantized_module = compile_torch_model(
torch_model, # our model
torch_input, # a representative input-set to be used for both quantization and compilation
n_bits=6,
rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
)
```

Note that the accuracy on larger data-sets, when the accumulator size is low, is also reduced strongly.
## Configuring quantization parameters

| <p>accuracy for target<br>accumulator bit-width</p> | 8 | 10 | 12 | 14 | 16 |
| --------------------------------------------------- | --- | --- | --- | --- | --- |
| cifar10 | 20% | 37% | 89% | 90% | 90% |
| cifar100 | 6% | 30% | 67% | 69% | 69% |
The PyTorch/Brevitas models, created following the example above, require the user to configure quantization parameters such as `bit_width` (activation bit-width) and `weight_bit_width`. The quantization parameters, along with the number of neurons on each layer, will determine the accumulator bit-width of the network. Larger accumulator bit-widths result in higher accuracy but slower FHE inference time.

## Running encrypted inference

Expand All @@ -100,7 +121,7 @@ In this example, the input values `x_test` and the predicted values `y_pred` are

## Simulated FHE Inference in the clear

The user can also perform the inference on clear data. Two approaches exist:
One can perform the inference on clear data in order to evaluate the impact of quantization and of FHE computation on the accuracy of their model. See [this section](../deep-learning/fhe_assistant.md#simulation) for more details. Two approaches exist:

- `quantized_module.forward(quantized_x, fhe="simulate")`: simulates FHE execution taking into account Table Lookup errors.\
De-quantization must be done in a second step as for actual FHE execution. Simulation takes into account the `p_error`/`global_p_error` parameters
Expand All @@ -110,34 +131,6 @@ The user can also perform the inference on clear data. Two approaches exist:
FHE simulation allows to measure the impact of the Table Lookup error on the model accuracy. The Table Lookup error can be adjusted using `p_error`/`global_p_error`, as described in the [approximate computation ](../explanations/advanced_features.md#approximate-computations)section.
{% endhint %}

## Generic Quantization Aware Training import

While the example above shows how to import a Brevitas/PyTorch model, Concrete ML also provides an option to import generic QAT models implemented in PyTorch or through ONNX. Deep learning models made with TensorFlow or Keras should be usable by preliminary converting them to ONNX.

QAT models contain quantizers in the PyTorch graph. These quantizers ensure that the inputs to the Linear/Dense and Conv layers are quantized.

Suppose that `n_bits_qat` is the bit-width of activations and weights during the QAT process. To import a PyTorch QAT network, you can use the [`compile_torch_model`](../references/api/concrete.ml.torch.compile.md#function-compile_torch_model) library function, passing `import_qat=True`:

<!--pytest-codeblocks:skip-->

```python
from concrete.ml.torch.compile import compile_torch_model
n_bits_qat = 3

quantized_module = compile_torch_model(
torch_model,
torch_input,
import_qat=True,
n_bits=n_bits_qat,
)
```

Alternatively, if you want to import an ONNX model directly, please see [the ONNX guide](onnx_support.md). The [`compile_onnx_model`](../references/api/concrete.ml.torch.compile.md#function-compile_onnx_model) also supports the `import_qat` parameter.

{% hint style="warning" %}
When importing QAT models using this generic pipeline, a representative calibration set should be given as quantization parameters in the model need to be inferred from the statistics of the values encountered during inference.
{% endhint %}

## Supported operators and activations

Concrete ML supports a variety of PyTorch operators that can be used to build fully connected or convolutional neural networks, with normalization and activation layers. Moreover, many element-wise operators are supported.
Expand Down
4 changes: 3 additions & 1 deletion src/concrete/ml/onnx/ops_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@


class RawOpOutput(numpy.ndarray):
"""Type construct that marks an ndarray as a raw output of a quantized op."""
"""Type construct that marks an ndarray as a raw output of a quantized op.
A raw output is an output that is a clear constant such as a shape, a constant float, an index..
"""


# This function is only used for comparison operators that return boolean values by default.
Expand Down

0 comments on commit becad10

Please sign in to comment.