feat: add new approx rounding
jfrery authored Apr 2, 2024
1 parent 45e491d commit 9ef890e
Showing 10 changed files with 261 additions and 180 deletions.
3 changes: 2 additions & 1 deletion conftest.py
@@ -147,7 +147,8 @@ def default_configuration():
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location="ConcreteNumpyKeyCache",
fhe_simulation=False, # Simulation compilation is done lazilly on circuit.simulate
# Simulation compilation is done lazily on circuit.simulate
fhe_simulation=False,
fhe_execution=True,
)

2 changes: 0 additions & 2 deletions src/concrete/ml/common/utils.py
@@ -36,8 +36,6 @@

# Indicate if the old virtual library method should be used instead of the compiler simulation
# when simulating FHE executions
# Set 'USE_OLD_VL' to False by default once the new simulation is fixed
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4091
USE_OLD_VL = False

# Debug option for testing round PBS optimization
45 changes: 37 additions & 8 deletions src/concrete/ml/quantization/base_quantized_op.py
@@ -20,6 +20,8 @@
UniformQuantizationParameters,
)

# pylint: disable=too-many-lines

ONNXOpInputOutputType = Union[
numpy.ndarray,
QuantizedArray,
@@ -873,13 +875,22 @@ class QuantizedMixingOp(QuantizedOp, is_utility=True):
"""

lsbs_to_remove: Optional[Union[int, dict]] = None
rounding_threshold_bits: Optional[int] = None
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None

def __init__(self, *args, rounding_threshold_bits: Optional[int] = None, **kwargs) -> None:
def __init__(
self,
*args,
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
**kwargs,
) -> None:
"""Initialize quantized ops parameters plus specific parameters.
Args:
rounding_threshold_bits (Optional[int]): Number of bits to round to.
rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): if not None,
every accumulator in the model is rounded down to the given bits of precision.
Can be an int or a dictionary with keys 'method' and 'n_bits', where 'method' is
either fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE, and 'n_bits' is either
'auto' or an int.
*args: positional argument to pass to the parent class.
**kwargs: named argument to pass to the parent class.
"""
@@ -959,15 +970,27 @@ def cnp_round(
Returns:
numpy.ndarray: The rounded array.
"""

# Ensure lsbs_to_remove is initialized as a dictionary
if not hasattr(self, "lsbs_to_remove") or not isinstance(self.lsbs_to_remove, dict):
self.lsbs_to_remove = {}

if self.rounding_threshold_bits is not None and calibrate_rounding:
n_bits = None
exactness = fhe.Exactness.EXACT

if isinstance(self.rounding_threshold_bits, dict):
n_bits = self.rounding_threshold_bits.get("n_bits", None)
exactness = self.rounding_threshold_bits.get("method", exactness)
# Power-of-two (PoT) scaling replaces rounding_threshold_bits in place with an int
elif isinstance(self.rounding_threshold_bits, int):
n_bits = self.rounding_threshold_bits

if n_bits is not None and calibrate_rounding:
# Compute lsbs_to_remove only when calibration is True
current_n_bits_accumulator = compute_bits_precision(x)
computed_lsbs_to_remove = current_n_bits_accumulator - self.rounding_threshold_bits

# mypy
assert isinstance(n_bits, int)
computed_lsbs_to_remove = current_n_bits_accumulator - n_bits

assert_true(
not isinstance(x, fhe.tracing.Tracer),
@@ -987,6 +1010,12 @@
assert isinstance(lsbs_value, int)

if lsbs_value > 0:
x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_value)

# Rounding to a low bit-width with approximate rounding can cause issues with overflow protection
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4345
if exactness == fhe.Exactness.APPROXIMATE:
x = fhe.round_bit_pattern(
x, lsbs_to_remove=lsbs_value, exactness=exactness, overflow_protection=False
)
else:
x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_value, exactness=exactness)
return x
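
As a rough clear-text model of the rounding applied above (illustrative only; `fhe.round_bit_pattern`'s exactness and overflow-protection semantics differ in detail):

import numpy

def sketch_round_lsbs(x: numpy.ndarray, lsbs_to_remove: int) -> numpy.ndarray:
    # Removing the `lsbs_to_remove` least-significant bits amounts to rounding
    # x to a multiple of 2**lsbs_to_remove
    step = 2**lsbs_to_remove
    return numpy.round(x / step).astype(numpy.int64) * step

# A 12-bit accumulator rounded to keep 6 bits drops 12 - 6 = 6 LSBs,
# mirroring computed_lsbs_to_remove = current_n_bits_accumulator - n_bits
acc = numpy.array([117, -53, 2047])
print(sketch_round_lsbs(acc, 6))  # -> [ 128  -64 2048]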
18 changes: 12 additions & 6 deletions src/concrete/ml/quantization/post_training.py
@@ -212,21 +212,23 @@ class ONNXConverter:
of the network's inputs. "op_inputs" and "op_weights" both control the quantization for
inputs and weights of all layers.
numpy_model (NumpyModule): Model in numpy.
rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
to the given bits of precision
rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
rounding for model accumulators. Accepts None, an int, or a dict.
The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
and 'n_bits' ('auto' or int)
"""

quant_ops_dict: Dict[str, Tuple[Tuple[str, ...], QuantizedOp]]
n_bits: Dict[str, int]
quant_params: Dict[str, numpy.ndarray]
numpy_model: NumpyModule
rounding_threshold_bits: Optional[int]
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]]

def __init__(
self,
n_bits: Union[int, Dict],
numpy_model: NumpyModule,
rounding_threshold_bits: Optional[int] = None,
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
):
self.quant_ops_dict = {}

@@ -834,8 +836,12 @@ class PostTrainingAffineQuantization(ONNXConverter):
- op_weights: learned parameters or constants in the network
- model_outputs: final model output quantization bits
numpy_model (NumpyModule): Model in numpy.
rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
to the given bits of precision
rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): if not None, every
accumulator in the model is rounded down to the given
bits of precision. Can be an int or a dictionary with keys
'method' and 'n_bits', where 'method' is either
fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE, and
'n_bits' is either 'auto' or an int.
is_signed: Whether the weights of the layers can be signed.
Currently, only the weights can be signed.
69 changes: 53 additions & 16 deletions src/concrete/ml/torch/compile.py
@@ -11,7 +11,7 @@
from brevitas.export.onnx.qonnx.manager import QONNXManager as BrevitasONNXManager
from brevitas.nn.quant_layer import QuantInputOutputLayer as QNNMixingLayer
from brevitas.nn.quant_layer import QuantNonLinearActLayer as QNNUnivariateLayer
from concrete.fhe import ParameterSelectionStrategy
from concrete.fhe import Exactness, ParameterSelectionStrategy
from concrete.fhe.compilation.artifacts import DebugArtifacts
from concrete.fhe.compilation.configuration import Configuration

@@ -72,7 +72,7 @@ def build_quantized_module(
torch_inputset: Dataset,
import_qat: bool = False,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
reduce_sum_copy=False,
) -> QuantizedModule:
"""Build a quantized module from a Torch or ONNX model.
@@ -88,8 +88,10 @@
import_qat (bool): Flag to signal that the network being imported contains quantizers in
its computation graph and that Concrete ML should not re-quantize it
n_bits: the number of bits for the quantization
rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
to the given bits of precision
rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
rounding for model accumulators. Accepts None, an int, or a dict.
The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
and 'n_bits' ('auto' or int)
reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
bit-width propagation
@@ -135,7 +137,7 @@ def _compile_torch_or_onnx_model(
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
@@ -164,8 +166,10 @@
- "model_inputs" and "model_outputs" (optional, default to 5 bits).
When using a single integer for n_bits, its value is assigned to "op_inputs" and
"op_weights" bits. Default is 8 bits.
rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
to the given bits of precision
rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
rounding for model accumulators. Accepts None, an int, or a dict.
The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
and 'n_bits' ('auto' or int)
p_error (Optional[float]): probability of error of a single PBS
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
@@ -177,7 +181,34 @@
Returns:
QuantizedModule: The resulting compiled QuantizedModule.
Raises:
NotImplementedError: If 'auto' rounding is specified but not implemented.
ValueError: If an invalid type or value is provided for rounding_threshold_bits.
"""
n_bits_rounding: Union[None, str, int] = None
method: Exactness = Exactness.EXACT

# Only process if rounding_threshold_bits is not None
if rounding_threshold_bits is not None:
if isinstance(rounding_threshold_bits, int):
n_bits_rounding = rounding_threshold_bits
elif isinstance(rounding_threshold_bits, dict):
n_bits_rounding = rounding_threshold_bits.get("n_bits")
if n_bits_rounding == "auto":
raise NotImplementedError("Automatic rounding is not implemented yet.")
method_str = rounding_threshold_bits.get("method", method).upper()
if method_str in ["EXACT", "APPROXIMATE"]:
method = Exactness[method_str]
else:
raise ValueError(
f"{method_str} is not a valid method. Must be one of EXACT, APPROXIMATE."
)
else:
raise ValueError("Invalid type for rounding_threshold_bits. Must be int or dict.")

assert n_bits_rounding is not None, "n_bits_rounding cannot be None"
rounding_threshold_bits = {"n_bits": n_bits_rounding, "method": method}

inputset_as_numpy_tuple = tuple(
convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
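
For illustration, how the validation above normalizes user input before compilation (a summary of this hunk's behavior; the string method form matches the `.upper()` call above):

# rounding_threshold_bits=6
#     -> {"n_bits": 6, "method": Exactness.EXACT}
# rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
#     -> {"n_bits": 6, "method": Exactness.APPROXIMATE}
# rounding_threshold_bits={"n_bits": "auto", "method": "exact"}
#     -> NotImplementedError (automatic rounding is not implemented yet)
# rounding_threshold_bits={"n_bits": 6, "method": "fuzzy"}
#     -> ValueError (method must be EXACT or APPROXIMATE)
# rounding_threshold_bits=6.5
#     -> ValueError (must be an int or a dict)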
@@ -234,7 +265,7 @@ def compile_torch_model(
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
@@ -264,8 +295,10 @@
- "model_inputs" and "model_outputs" (optional, default to 5 bits).
When using a single integer for n_bits, its value is assigned to "op_inputs" and
"op_weights" bits. Default is 8 bits.
rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
to the given bits of precision
rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
rounding for model accumulators. Accepts None, an int, or a dict.
The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
and 'n_bits' ('auto' or int)
p_error (Optional[float]): probability of error of a single PBS
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
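
A hedged end-to-end usage sketch of the new parameter (the model and input set below are placeholders, not taken from this commit; the string "approximate" follows the parsing logic in `_compile_torch_or_onnx_model`):

import torch

from concrete.ml.torch.compile import compile_torch_model

# Placeholder model and representative input set
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
x_train = torch.randn(100, 10)

quantized_module = compile_torch_model(
    model,
    x_train,
    n_bits=4,
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"},
)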
@@ -316,7 +349,7 @@ def compile_onnx_model(
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
@@ -346,8 +379,10 @@
- "model_inputs" and "model_outputs" (optional, default to 5 bits).
When using a single integer for n_bits, its value is assigned to "op_inputs" and
"op_weights" bits. Default is 8 bits.
rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
to the given bits of precision
rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
rounding for model accumulators. Accepts None, an int, or a dict.
The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
and 'n_bits' ('auto' or int)
p_error (Optional[float]): probability of error of a single PBS
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
@@ -393,7 +428,7 @@ def compile_brevitas_qat_model(
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
rounding_threshold_bits: Optional[int] = None,
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
output_onnx_file: Union[None, Path, str] = None,
@@ -424,8 +459,10 @@
during compilation
show_mlir (bool): if set, the MLIR produced by the converter and which is going
to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
to the given bits of precision
rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
rounding for model accumulators. Accepts None, an int, or a dict.
The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
and 'n_bits' ('auto' or int)
p_error (Optional[float]): probability of error of a single PBS
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
55 changes: 27 additions & 28 deletions tests/sklearn/test_sklearn_models.py
@@ -46,7 +46,6 @@
from concrete.ml.common.serialization.dumpers import dump, dumps
from concrete.ml.common.serialization.loaders import load, loads
from concrete.ml.common.utils import (
USE_OLD_VL,
array_allclose_and_same_shape,
get_model_class,
get_model_name,
@@ -1689,14 +1688,15 @@ def test_fitted_compiled_error_raises(
@pytest.mark.flaky
@pytest.mark.parametrize("model_class, parameters", MODELS_AND_DATASETS)
@pytest.mark.parametrize(
"error_param",
[{"p_error": 0.9999999999990905}], # 1 - 2**-40
ids=["p_error"],
"error_param, expected_diff",
[({"p_error": 1 - 2**-40}, True), ({"p_error": 2**-40}, False)],
ids=["p_error_high", "p_error_low"],
)
def test_p_error_simulation(
model_class,
parameters,
error_param,
expected_diff,
load_data,
is_weekly_option,
):
@@ -1711,17 +1711,12 @@
# Get data-set, initialize and fit the model
model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option)

# Check if model is linear
is_linear_model = is_model_class_in_a_list(model_class, _get_sklearn_linear_models())

# Do not run the test for linear models since there is no PBS (i.e. p_error has no impact)
if is_linear_model:
pytest.skip("Linear models do not have PBS")

# Compile with a large p_error to be sure the result is random.
# Compile with the specified p_error.
model.compile(x, **error_param)

def check_for_divergent_predictions(x, model, fhe, max_iterations=N_ALLOWED_FHE_RUN):
def check_for_divergent_predictions(
x, model, fhe, max_iterations=N_ALLOWED_FHE_RUN, tolerance=1e-5
):
"""Detect divergence between simulated/FHE execution and clear run."""

# KNeighborsClassifier does not provide a predict_proba method for now
@@ -1736,31 +1731,35 @@ def check_for_divergent_predictions(x, model, fhe, max_iterations=N_ALLOWED_FHE_
y_expected = predict_function(x, fhe="disable")
for i in range(max_iterations):
y_pred = predict_function(x[i : i + 1], fhe=fhe).ravel()
if not numpy.array_equal(y_pred, y_expected[i : i + 1].ravel()):
if not numpy.allclose(y_pred, y_expected[i : i + 1].ravel(), atol=tolerance):
return True
return False

simulation_diff_found = check_for_divergent_predictions(x, model, fhe="simulate")
fhe_diff_found = check_for_divergent_predictions(x, model, fhe="execute")

# Check for differences in predictions
# Remark that, with the old VL, linear models (or, more generally, circuits without PBS) were
# badly simulated. It has been fixed in the new simulation.
if is_linear_model and USE_OLD_VL:

# In FHE, high p_error affect the crypto parameters which
# makes the predictions slightly different
assert fhe_diff_found, "FHE predictions should be different for linear models"
# Check if model is linear
is_linear_model = is_model_class_in_a_list(model_class, _get_sklearn_linear_models())

# linear models p_error is not simulated
assert not simulation_diff_found, "SIMULATE predictions not the same for linear models"
# Skip the following if the model is linear
# Simulation and FHE differ with very high p_error on leveled circuits
# FIXME https://github.com/zama-ai/concrete-ml-internal/issues/4343
if is_linear_model:
pytest.skip("Skipping test for linear models")

# Check for differences in predictions based on expected_diff
if expected_diff:
assert_msg = (
"With high p_error, predictions should differ in both FHE and simulation."
f" Found differences: FHE={fhe_diff_found}, Simulation={simulation_diff_found}"
)
assert fhe_diff_found and simulation_diff_found, assert_msg
else:
assert fhe_diff_found and simulation_diff_found, (
f"Predictions not different in at least one run.\n"
f"FHE predictions differ: {fhe_diff_found}\n"
f"SIMULATE predictions differ: {simulation_diff_found}"
assert_msg = (
"With low p_error, predictions should not differ in FHE or simulation."
f" Found differences: FHE={fhe_diff_found}, Simulation={simulation_diff_found}"
)
assert not (fhe_diff_found or simulation_diff_found), assert_msg


# This test is only relevant for classifier models
