From 296506612145865146d7f9cd34c39d2a535785de Mon Sep 17 00:00:00 2001 From: Roman Bredehoft Date: Mon, 5 Feb 2024 14:16:25 +0100 Subject: [PATCH] feat: support conv1d operator --- README.md | 1 - src/concrete/ml/onnx/ops_impl.py | 32 ++++++++- src/concrete/ml/pytest/torch_models.py | 25 +++++++ src/concrete/ml/quantization/quantized_ops.py | 70 +++++++++++++------ tests/quantization/test_quantized_ops.py | 53 +++++++++----- tests/torch/test_compile_torch.py | 11 +-- 6 files changed, 148 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 73ff40347..133f1e248 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,6 @@ To cite Concrete ML, notably in academic papers, please use the following entry, - ## License. This software is distributed under the BSD-3-Clause-Clear license. If you have any questions, please contact us at hello@zama.ai. diff --git a/src/concrete/ml/onnx/ops_impl.py b/src/concrete/ml/onnx/ops_impl.py index ed0db0be3..d2d5f98bc 100644 --- a/src/concrete/ml/onnx/ops_impl.py +++ b/src/concrete/ml/onnx/ops_impl.py @@ -1279,7 +1279,10 @@ def numpy_conv( """ # Convert the inputs to tensors to compute conv using torch - assert_true(len(kernel_shape) == 2, "The convolution operator currently supports only 2-d") + assert_true( + len(kernel_shape) in (1, 2), + f"The convolution operator currently only supports 1d or 2d. Got {len(kernel_shape)}-d", + ) assert_true( bool(numpy.all(numpy.asarray(dilations) == 1)), "The convolution operator in Concrete does not support dilation", @@ -1300,8 +1303,33 @@ def numpy_conv( # Pad the input if needed x_pad = numpy_onnx_pad(x, pads) + is_conv1d = len(kernel_shape) == 1 + + # Workaround for handling torch's Conv1d operator until it is supported by Concrete Python + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/41 + if is_conv1d: + x_pad = numpy.expand_dims(x_pad, axis=-2) + w = numpy.expand_dims(w, axis=-2) + kernel_shape = (1, kernel_shape[0]) + strides = (1, strides[0]) + dilations = (1, dilations[0]) + # Compute the torch convolution - res = fhe_conv(x_pad, w, b, None, strides, dilations, None, group) + res = fhe_conv( + x=x_pad, + weight=w, + bias=b, + pads=None, + strides=strides, + dilations=dilations, + kernel_shape=kernel_shape, + group=group, + ) + + # Workaround for handling torch's Conv1d operator until it is supported by Concrete Python + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/41 + if is_conv1d: + res = numpy.squeeze(res, axis=-2) return (res,) diff --git a/src/concrete/ml/pytest/torch_models.py b/src/concrete/ml/pytest/torch_models.py index 1be176962..08e0d834f 100644 --- a/src/concrete/ml/pytest/torch_models.py +++ b/src/concrete/ml/pytest/torch_models.py @@ -1546,3 +1546,28 @@ def forward(self, x): x = self.input_quant(x) x = x.reshape(x.shape + (1,)) return x.expand(x.shape[:-1] + (4,)) + + +class Conv1dModel(nn.Module): + """Small model that uses a 1D convolution operator.""" + + def __init__(self, input_output, activation_function) -> None: + super().__init__() + + self.conv1 = nn.Conv1d(input_output, 2, 2, stride=1, padding=0) + self.act = activation_function() + self.fc1 = nn.Linear(input_output, 3) + + def forward(self, x): + """Forward pass. + + Args: + x (torch.Tensor): The model's input. + + Returns: + torch.Tensor: The model's output. + + """ + x = self.act(self.conv1(x)) + x = self.fc1(x) + return x diff --git a/src/concrete/ml/quantization/quantized_ops.py b/src/concrete/ml/quantization/quantized_ops.py index eb1ac1516..09427d941 100644 --- a/src/concrete/ml/quantization/quantized_ops.py +++ b/src/concrete/ml/quantization/quantized_ops.py @@ -777,8 +777,9 @@ def __init__( # Validate the parameters assert_true( - len(self.kernel_shape) == 2, - "The convolution operator currently supports only 2d", + len(self.kernel_shape) in (1, 2), + "The convolution operator currently only supports 1d or 2d. " + f"Got {len(self.kernel_shape)}-d", ) assert_true( len(self.kernel_shape) == len(self.strides), @@ -787,7 +788,7 @@ def __init__( ) assert_true( bool(numpy.all(numpy.asarray(self.dilations) == 1)), - "The convolution operator in Concrete does not suppport dilation", + "The convolution operator in Concrete does not support dilation", ) assert_true( len(self.pads) == 2 * len(self.kernel_shape), @@ -796,7 +797,7 @@ def __init__( " standard", ) - # pylint: disable-next=too-many-statements + # pylint: disable-next=too-many-statements, too-many-locals def q_impl( self, *q_inputs: ONNXOpInputOutputType, @@ -847,9 +848,6 @@ def q_impl( f"group ({self.group}).", ) - # Prepare a constant tensor to compute the sum of the inputs - q_weights_1 = numpy.ones_like(q_weights.qvalues) - assert q_weights.quantizer.scale is not None assert q_weights.quantizer.zero_point is not None @@ -862,6 +860,25 @@ def q_impl( pad_value = int(q_input.quantizer.zero_point) q_input_pad = numpy_onnx_pad(q_input.qvalues, self.pads, pad_value, True) + is_conv1d = len(self.kernel_shape) == 1 + + q_weights_values = q_weights.qvalues + kernel_shape = self.kernel_shape + strides = self.strides + dilations = self.dilations + + # Workaround for handling torch's Conv1d operator until it is supported by Concrete Python + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4117 + if is_conv1d: + q_input_pad = numpy.expand_dims(q_input_pad, axis=-2) + q_weights_values = numpy.expand_dims(q_weights_values, axis=-2) + kernel_shape = (1, kernel_shape[0]) + strides = (1, strides[0]) + dilations = (1, dilations[0]) + + # Prepare a constant tensor to compute the sum of the inputs + q_weights_1 = numpy.ones_like(q_weights_values) + # We follow the Quantized Gemm implementation # which in turn follows Eq.7 in https://arxiv.org/abs/1712.05877 # to split the core computation from the zero points and scales. @@ -869,21 +886,22 @@ def q_impl( # Compute the first encrypted term that convolves weights and inputs # Force padding to 0 as padding needs to use a custom padding initializer # and is thus manually performed in the code above - fake_pads = [0] * len(self.pads) + fake_pads = [0, 0] * len(kernel_shape) with tag(self.op_instance_name + ".conv"): conv_wx = fhe_conv( q_input_pad, - q_weights.qvalues, + q_weights_values, bias=None, pads=fake_pads, - strides=self.strides, - dilations=self.dilations, + kernel_shape=kernel_shape, + strides=strides, + dilations=dilations, group=self.group, ) # The total number of elements that are convolved by the application of a single kernel - n_weights = numpy.prod(q_weights.qvalues.shape[1:]) + n_weights = numpy.prod(q_weights_values.shape[1:]) # If the weights have symmetric quantization, their zero point will be 0 # The following check avoids the computation of the sum of the inputs, which may have @@ -900,9 +918,10 @@ def q_impl( q_input_pad, q_weights_1, bias=None, - pads=[0, 0, 0, 0], - strides=self.strides, - dilations=self.dilations, + pads=fake_pads, + kernel_shape=kernel_shape, + strides=strides, + dilations=dilations, group=self.group, ) @@ -911,14 +930,22 @@ def q_impl( else: numpy_q_out = conv_wx + # Workaround for handling torch's Conv1d operator until it is supported by Concrete Python + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4117 + if is_conv1d: + numpy_q_out = numpy.squeeze(numpy_q_out, axis=-2) + if self.debug_value_tracker is not None: # pylint: disable-next=unsubscriptable-object self.debug_value_tracker[self.op_instance_name]["output"] = numpy_q_out + weight_sum_axes = (1, 2) if is_conv1d else (1, 2, 3) + weight_transpose_axes = (1, 0, 2) if is_conv1d else (1, 0, 2, 3) + # Compute the third term, the sum of the weights which is a constant sum_weights = q_input.quantizer.zero_point * numpy.sum( - q_weights.qvalues, axis=(1, 2, 3), keepdims=True - ).transpose(1, 0, 2, 3) + q_weights.qvalues, axis=weight_sum_axes, keepdims=True + ).transpose(*weight_transpose_axes) # Compute the forth term which is a constant final_term = n_weights * q_input.quantizer.zero_point * q_weights.quantizer.zero_point @@ -929,6 +956,8 @@ def q_impl( # any Gemm/Add/Conv layers that follow m_matmul = q_input.quantizer.scale * q_weights.quantizer.scale + bias_shape = (1, -1, 1) if is_conv1d else (1, -1, 1, 1) + # If this operation's result are network outputs, return # directly the integer values and an appropriate quantization parameters that # allow direct in-the-clear de-quantization, including the bias @@ -944,7 +973,7 @@ def q_impl( out_zp: Union[int, numpy.ndarray] = sum_weights - final_term if q_bias is not None: # Reshape the biases to broadcast them to each channel - out_zp = out_zp - q_bias.values.reshape((1, -1, 1, 1)) / m_matmul + out_zp = out_zp - q_bias.values.reshape(bias_shape) / m_matmul # We identify terms in the above equation to determine what # the scale/zero-point of the in-the-clear quantizer should be @@ -956,7 +985,8 @@ def q_impl( # The bias scale should be the same scale as the one of the weights * inputs assert q_bias.quantizer.scale is not None assert numpy.isclose(q_bias.quantizer.scale, m_matmul) - numpy_q_out += q_bias.qvalues.reshape((1, -1, 1, 1)) + + numpy_q_out += q_bias.qvalues.reshape(bias_shape) with tag(self.op_instance_name + ".conv_rounding"): # Apply Concrete rounding (if relevant) @@ -973,7 +1003,7 @@ def q_impl( if q_bias is not None and not q_bias.quantizer.is_precomputed_qat: # The bias addition is handled in float and will be fused into a TLU # Reshape the biases to broadcast them to each channel - numpy_q_out = numpy_q_out + q_bias.values.reshape((1, -1, 1, 1)) # bias_part + numpy_q_out = numpy_q_out + q_bias.values.reshape(bias_shape) # bias_part # And return as a QuantizedArray initialized from the float data, keeping # track of the quantization parameters diff --git a/tests/quantization/test_quantized_ops.py b/tests/quantization/test_quantized_ops.py index 4fa475b27..b592fea9a 100644 --- a/tests/quantization/test_quantized_ops.py +++ b/tests/quantization/test_quantized_ops.py @@ -698,9 +698,13 @@ def test_identity_op(x, n_bits): ), ], ) -@pytest.mark.parametrize("produces_output", [True, False]) +@pytest.mark.parametrize("produces_output", [True, False], ids=["produces_output", ""]) +@pytest.mark.parametrize("is_conv1d", [True, False], ids=["is_conv1d", "is_conv2d"]) +# @pytest.mark.parametrize("is_conv1d", [True], ids=["is_conv1d"]) # pylint: disable-next=too-many-locals -def test_quantized_conv(params, n_bits, produces_output, check_r2_score, check_float_array_equal): +def test_quantized_conv( + params, n_bits, produces_output, is_conv1d, check_r2_score, check_float_array_equal +): """Test the quantized convolution operator.""" # Retrieve arguments @@ -717,6 +721,19 @@ def test_quantized_conv(params, n_bits, produces_output, check_r2_score, check_f group, ) = params + # If testing the conv1d operator, make the parameters represent 1D inputs + if is_conv1d: + size_input = size_input[:3] + size_weights = size_weights[:3] + strides = strides[:1] + pads = pads[:2] + dilations = (1,) + conv_torch_op = torch.conv1d + + else: + dilations = (1, 1) # type: ignore[assignment] + conv_torch_op = torch.conv2d + net_input = numpy.random.uniform(size=size_input) * scale_input weights = numpy.random.randn(*size_weights) * scale_weights biases = numpy.random.uniform(size=size_bias) * scale_bias + offset_bias @@ -734,8 +751,8 @@ def test_quantized_conv(params, n_bits, produces_output, check_r2_score, check_f constant_inputs={1: q_weights, 2: q_bias}, strides=strides, pads=pads, - kernel_shape=(weights.shape[2], weights.shape[3]), - dilations=(1, 1), + kernel_shape=weights.shape[2:], + dilations=dilations, group=group, ) q_op.produces_graph_output = produces_output @@ -743,24 +760,26 @@ def test_quantized_conv(params, n_bits, produces_output, check_r2_score, check_f # Compute the result in floating point expected_result = q_op.calibrate(net_input) - # Compute the reference result - - # Pad the input if needed + # For Conv1d, torch and ONNX both follow the same padding convention + if is_conv1d: + input_padded = torch.nn.functional.pad(torch.Tensor(net_input.copy()), pads) - # Torch uses padding (padding_left,padding_right, padding_top,padding_bottom) + # For Conv2d, torch uses padding (padding_left, padding_right, padding_top, padding_bottom) # While ONNX and Concrete ML use (padding_top, padding_left, padding_bottom, padding_right) - tx_pad = torch.nn.functional.pad( - torch.Tensor(net_input.copy()), (pads[1], pads[3], pads[0], pads[2]) - ) + else: + input_padded = torch.nn.functional.pad( + torch.Tensor(net_input.copy()), (pads[1], pads[3], pads[0], pads[2]) + ) - # Compute the torch convolution - torch_res = torch.conv2d( - tx_pad, - torch.Tensor(weights.copy()), - torch.Tensor(biases.squeeze().copy()) if biases is not None else None, - strides, + # Compute the reference result using the torch convolution operator + torch_res = conv_torch_op( + input=input_padded, + weight=torch.Tensor(weights.copy()), + bias=torch.Tensor(biases.squeeze().copy()) if biases is not None else None, + stride=strides, groups=group, ).numpy() + check_float_array_equal(torch_res, expected_result) # Compute the quantized result diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py index 5dd5d0a9e..41e7846ec 100644 --- a/tests/torch/test_compile_torch.py +++ b/tests/torch/test_compile_torch.py @@ -30,6 +30,7 @@ CNNGrouped, CNNOther, ConcatFancyIndexing, + Conv1dModel, DoubleQuantQATMixNet, EncryptedMatrixMultiplicationModel, ExpandModel, @@ -507,16 +508,18 @@ def test_compile_torch_or_onnx_networks( ], ) @pytest.mark.parametrize( - "model", + "model, is_1d", [ - pytest.param(CNNOther), - pytest.param(partial(CNNGrouped, groups=3)), + pytest.param(CNNOther, False, id="CNN"), + pytest.param(partial(CNNGrouped, groups=3), False, id="CNN_grouped"), + pytest.param(Conv1dModel, True, id="CNN_conv1d"), ], ) @pytest.mark.parametrize("simulate", [True, False]) @pytest.mark.parametrize("is_onnx", [True, False]) def test_compile_torch_or_onnx_conv_networks( # pylint: disable=unused-argument model, + is_1d, activation_function, default_configuration, simulate, @@ -530,7 +533,7 @@ def test_compile_torch_or_onnx_conv_networks( # pylint: disable=unused-argument # The QAT bits is set to 0 in order to signal that the network is not using QAT qat_bits = 0 - input_shape = (6, 7, 7) + input_shape = (6, 7) if is_1d else (6, 7, 7) input_output = input_shape[0] q_module = compile_and_test_torch_or_onnx(