diff --git a/.gitleaksignore b/.gitleaksignore index 66e5189fe..8f212aff2 100644 --- a/.gitleaksignore +++ b/.gitleaksignore @@ -4,3 +4,4 @@ 2d3b4ca188efb338c03d8d2c921ef39ffc5537e3:tests/deployment/test_deployment.py:generic-api-key:59 198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59 5abc7e86bb192e1f9f829bb2f22173c9d663e1d1:use_case_examples/credit_scoring/CreditScoringWithGraphics.ipynb:easypost-test-api-token:1414 +e2904473898ddd325f245f4faca526a0e9520f49:builders/Dockerfile.zamalang-env:generic-api-key:5 \ No newline at end of file diff --git a/src/concrete/ml/pytest/torch_models.py b/src/concrete/ml/pytest/torch_models.py index 2963ebe22..a9c14140f 100644 --- a/src/concrete/ml/pytest/torch_models.py +++ b/src/concrete/ml/pytest/torch_models.py @@ -1488,3 +1488,24 @@ def predict(x, weights, bias): outputs = torch.sigmoid(torch.bmm(x, weights_expanded) + bias_expanded) return outputs.squeeze() + + +class AddNet(nn.Module): + """Torch model that performs a simple addition between two inputs.""" + + def __init__(self, use_conv, use_qat, input_output, n_bits): # pylint: disable=unused-argument + super().__init__() + # No initialization needed for simple addition + + @staticmethod + def forward(x, y): + """Forward pass. + + Args: + x: First input tensor. + y: Second input tensor. + + Returns: + Result of adding x and y. + """ + return x + y diff --git a/src/concrete/ml/quantization/post_training.py b/src/concrete/ml/quantization/post_training.py index 8d15b41ac..9389ab05f 100644 --- a/src/concrete/ml/quantization/post_training.py +++ b/src/concrete/ml/quantization/post_training.py @@ -423,10 +423,6 @@ def _quantize_layers(self, *input_calibration_data: numpy.ndarray): quantized_op_class = ONNX_OPS_TO_QUANTIZED_IMPL[op_type] - # Add rounding_threshold_bits to the attributes if available in quantized_op_class - if issubclass(quantized_op_class, QuantizedMixingOp): - attributes.update({"rounding_threshold_bits": self.rounding_threshold_bits}) - # All inputs, allow optional constants (they become None) # Note that input of a node can be duplicated, e.g., (%a, %a, %b) curr_inputs = [ @@ -479,6 +475,12 @@ def _quantize_layers(self, *input_calibration_data: numpy.ndarray): # If we depend on a variable input use the quantized version of the operator if has_variable_inputs: + # Add rounding_threshold_bits to the attributes if available in quantized_op_class + # rounding_thresholds_bits only applies to QuantizedOp for now so we can't use them + # if we use the original operator on float (ops_impl.py) + if issubclass(quantized_op_class, QuantizedMixingOp): + attributes.update({"rounding_threshold_bits": self.rounding_threshold_bits}) + assert_true( op_type in ONNX_OPS_TO_QUANTIZED_IMPL, f"{op_type} can't be found in {ONNX_OPS_TO_QUANTIZED_IMPL}", diff --git a/src/concrete/ml/quantization/quantized_ops.py b/src/concrete/ml/quantization/quantized_ops.py index 67478ef72..09590550b 100644 --- a/src/concrete/ml/quantization/quantized_ops.py +++ b/src/concrete/ml/quantization/quantized_ops.py @@ -26,7 +26,12 @@ QuantizedOp, QuantizedOpUnivariateOfEncrypted, ) -from .quantizers import QuantizationOptions, QuantizedArray, UniformQuantizationParameters +from .quantizers import ( + QuantizationOptions, + QuantizedArray, + UniformQuantizationParameters, + UniformQuantizer, +) def _check_op_input_zero_point(zero_point: Any, op_name: Optional[str]): @@ -492,7 +497,7 @@ class QuantizedMatMul(QuantizedGemm): _impl_for_op_named: str = "MatMul" -class QuantizedAdd(QuantizedOp): +class QuantizedAdd(QuantizedMixingOp): """Quantized Addition operator. Can add either two variables (both encrypted) or a variable and a constant @@ -554,22 +559,32 @@ def q_impl( assert q_input_1.quantizer.scale is not None assert q_input_1.quantizer.zero_point is not None - # De-quantize with input params and re-quantize with output parameters - # This will use TLUs over each element of the two inputs - # We do the de-quantization directly, instead of q_inputs[0].dequant(), - # So that we do not lose precision in the computation + # Dequantize + input_0 = q_input_0.dequant() + input_1 = q_input_1.dequant() - rescale_q0 = numpy.rint( - q_input_0.quantizer.scale - / self.output_quant_params.scale - * (q_input_0.qvalues + (-q_input_0.quantizer.zero_point)) - ).astype(numpy.int64) + # If this operator is the last one in the graph, + # we rescale using the smallest scale to keep all information + if self.produces_graph_output: + common_scale = min(q_input_0.quantizer.scale, q_input_1.quantizer.scale) + # Otherwise we use the output op quantization scale + else: + common_scale = self.output_quant_params.scale - rescale_q1 = numpy.rint( - q_input_1.quantizer.scale - / self.output_quant_params.scale - * (q_input_1.qvalues + (-q_input_1.quantizer.zero_point)) - ).astype(numpy.int64) + common_zero_point = 0 + offset = 0 + + output_quant_params = UniformQuantizationParameters( + scale=common_scale, + zero_point=common_zero_point, + offset=offset, + ) + + quantizer = UniformQuantizer(params=output_quant_params, no_clipping=True) + + # Re-quantize using the common quantization paramaters + q_input_0_rescaled = quantizer.quant(input_0) + q_input_1_rescaled = quantizer.quant(input_1) # The sum of quantized encrypted integer values # This sum has << max(in_bits0, in_bits1) + 1 >> bits @@ -580,12 +595,15 @@ def q_impl( # sum_q = rescale_q0 + self.b_sign * rescale_q1 # when zama-ai/concrete-numpy-internal#1749 is done if self.b_sign == 1: - sum_q = rescale_q0 + rescale_q1 + sum_q = q_input_0_rescaled + q_input_1_rescaled elif self.b_sign == -1: - sum_q = rescale_q0 - rescale_q1 + sum_q = q_input_0_rescaled - q_input_1_rescaled + + if self.produces_graph_output: + return self.make_output_quant_parameters(sum_q, common_scale, common_zero_point) # But we would like the output to have n_bits, so we de-quantize - dequant_sum = self.output_quant_params.scale * sum_q + dequant_sum = quantizer.dequant(sum_q) # Return the raw float values without re-quantizing them to the new scale, as any # following Gemm/Add/Conv will quantize them with _prepare_inputs_with_constants(...) diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py index aad6f9199..592301208 100644 --- a/tests/torch/test_compile_torch.py +++ b/tests/torch/test_compile_torch.py @@ -20,6 +20,7 @@ from concrete.ml.onnx.convert import OPSET_VERSION_FOR_ONNX_EXPORT from concrete.ml.pytest.torch_models import ( FC, + AddNet, BranchingGemmModule, BranchingModule, CNNGrouped, @@ -1038,6 +1039,7 @@ def __init__(self, input_output, activation_function): (MultiInputNNConfigurable, (1, 8, 8), 2, False), (DoubleQuantQATMixNet, (1, 8, 8), 1, False), (DoubleQuantQATMixNet, 10, 1, False), + (AddNet, 10, 2, False), ], ) def test_net_has_no_tlu(