feat: support conv1d operator
RomanBredehoft committed Feb 5, 2024
1 parent fa3ef88 commit 9d30ef6
Showing 6 changed files with 148 additions and 44 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -193,7 +193,6 @@ To cite Concrete ML, notably in academic papers, please use the following entry,
<img src="https://github.com/zama-ai/concrete-ml/assets/157474013/8ef18a7e-671b-495c-8346-fa75227d0af3">
</a>


## License.

This software is distributed under the BSD-3-Clause-Clear license. If you have any questions, please contact us at [email protected].
32 changes: 30 additions & 2 deletions src/concrete/ml/onnx/ops_impl.py
@@ -1279,7 +1279,10 @@ def numpy_conv(
"""

# Convert the inputs to tensors to compute conv using torch
assert_true(len(kernel_shape) == 2, "The convolution operator currently supports only 2-d")
assert_true(
len(kernel_shape) in (1, 2),
f"The convolution operator currently only supports 1-d or 2-d. Got {len(kernel_shape)}-d",
)
assert_true(
bool(numpy.all(numpy.asarray(dilations) == 1)),
"The convolution operator in Concrete does not support dilation",
@@ -1300,8 +1303,33 @@ def numpy_conv(
# Pad the input if needed
x_pad = numpy_onnx_pad(x, pads)

is_conv1d = len(kernel_shape) == 1

# Workaround for handling torch's Conv1d operator until it is supported by Concrete Python
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4117
if is_conv1d:
x_pad = numpy.expand_dims(x_pad, axis=-2)
w = numpy.expand_dims(w, axis=-2)
kernel_shape = (1, kernel_shape[0])
strides = (1, strides[0])
dilations = (1, dilations[0])

# Compute the torch convolution
res = fhe_conv(x_pad, w, b, None, strides, dilations, None, group)
res = fhe_conv(
x=x_pad,
weight=w,
bias=b,
pads=None,
strides=strides,
dilations=dilations,
kernel_shape=kernel_shape,
group=group,
)

# Workaround for handling torch's Conv1d operator until it is supported by Concrete Python
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4117
if is_conv1d:
res = numpy.squeeze(res, axis=-2)

return (res,)

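The change above emulates a 1D convolution through the existing 2D path by inserting a spatial dimension of size 1 before the call to fhe_conv and removing it afterwards. A minimal sketch of the same idea using plain torch operators (illustrative shapes, not Concrete Python's fhe_conv):

import torch

# Illustrative shapes: batch 2, 3 input channels, length 8, kernel size 2
x = torch.randn(2, 3, 8)
w = torch.randn(4, 3, 2)  # 4 output channels

# Reference 1D convolution
ref = torch.conv1d(x, w, stride=1)

# Same computation through the 2D operator: add a height dimension of size 1
# to both the input and the kernel, convolve, then squeeze the extra axis
x_2d = x.unsqueeze(-2)  # (2, 3, 1, 8)
w_2d = w.unsqueeze(-2)  # (4, 3, 1, 2)
res = torch.conv2d(x_2d, w_2d, stride=(1, 1)).squeeze(-2)

assert torch.allclose(ref, res)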
25 changes: 25 additions & 0 deletions src/concrete/ml/pytest/torch_models.py
@@ -1546,3 +1546,28 @@ def forward(self, x):
x = self.input_quant(x)
x = x.reshape(x.shape + (1,))
return x.expand(x.shape[:-1] + (4,))


class Conv1dModel(nn.Module):
"""Small model that uses a 1D convolution operator."""

def __init__(self, input_output, activation_function) -> None:
super().__init__()

self.conv1 = nn.Conv1d(input_output, 2, 2, stride=1, padding=0)
self.act = activation_function()
self.fc1 = nn.Linear(input_output, 3)

def forward(self, x):
"""Forward pass.
Args:
x (torch.Tensor): The model's input.
Returns:
torch.Tensor: The model's output.
"""
x = self.act(self.conv1(x))
x = self.fc1(x)
return x
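As a rough usage sketch for the new model (the input shape and the ReLU activation are chosen to match the compile test further down and are not the only valid choice), the convolution with kernel size 2 and no padding shortens the sequence by one, so a length of input_output + 1 makes the result line up with the Linear layer:

import torch
from torch import nn

# Assumes the Conv1dModel class defined above is in scope
model = Conv1dModel(input_output=6, activation_function=nn.ReLU)

# Input of shape (batch, channels, length)
x = torch.randn(4, 6, 7)
y = model(x)
print(y.shape)  # torch.Size([4, 2, 3])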
70 changes: 50 additions & 20 deletions src/concrete/ml/quantization/quantized_ops.py
@@ -777,8 +777,9 @@ def __init__(

# Validate the parameters
assert_true(
len(self.kernel_shape) == 2,
"The convolution operator currently supports only 2d",
len(self.kernel_shape) in (1, 2),
"The convolution operator currently only supports 1-d or 2-d. "
f"Got {len(self.kernel_shape)}-d",
)
assert_true(
len(self.kernel_shape) == len(self.strides),
@@ -787,7 +788,7 @@ def __init__(
)
assert_true(
bool(numpy.all(numpy.asarray(self.dilations) == 1)),
"The convolution operator in Concrete does not suppport dilation",
"The convolution operator in Concrete does not support dilation",
)
assert_true(
len(self.pads) == 2 * len(self.kernel_shape),
@@ -796,7 +797,7 @@ def __init__(
" standard",
)

# pylint: disable-next=too-many-statements
# pylint: disable-next=too-many-statements, too-many-locals
def q_impl(
self,
*q_inputs: ONNXOpInputOutputType,
@@ -847,9 +848,6 @@ def q_impl(
f"group ({self.group}).",
)

# Prepare a constant tensor to compute the sum of the inputs
q_weights_1 = numpy.ones_like(q_weights.qvalues)

assert q_weights.quantizer.scale is not None
assert q_weights.quantizer.zero_point is not None

@@ -862,28 +860,48 @@ def q_impl(
pad_value = int(q_input.quantizer.zero_point)
q_input_pad = numpy_onnx_pad(q_input.qvalues, self.pads, pad_value, True)

is_conv1d = len(self.kernel_shape) == 1

q_weights_values = q_weights.qvalues
kernel_shape = self.kernel_shape
strides = self.strides
dilations = self.dilations

# Workaround for handling torch's Conv1d operator until it is supported by Concrete Python
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4117
if is_conv1d:
q_input_pad = numpy.expand_dims(q_input_pad, axis=-2)
q_weights_values = numpy.expand_dims(q_weights_values, axis=-2)
kernel_shape = (1, kernel_shape[0])
strides = (1, strides[0])
dilations = (1, dilations[0])

# Prepare a constant tensor to compute the sum of the inputs
q_weights_1 = numpy.ones_like(q_weights_values)

# We follow the Quantized Gemm implementation
# which in turn follows Eq.7 in https://arxiv.org/abs/1712.05877
# to split the core computation from the zero points and scales.

# Compute the first encrypted term that convolves weights and inputs
# Force padding to 0 as padding needs to use a custom padding initializer
# and is thus manually performed in the code above
fake_pads = [0] * len(self.pads)
fake_pads = [0, 0] * len(kernel_shape)

with tag(self.op_instance_name + ".conv"):
conv_wx = fhe_conv(
q_input_pad,
q_weights.qvalues,
q_weights_values,
bias=None,
pads=fake_pads,
strides=self.strides,
dilations=self.dilations,
kernel_shape=kernel_shape,
strides=strides,
dilations=dilations,
group=self.group,
)

# The total number of elements that are convolved by the application of a single kernel
n_weights = numpy.prod(q_weights.qvalues.shape[1:])
n_weights = numpy.prod(q_weights_values.shape[1:])

# If the weights have symmetric quantization, their zero point will be 0
# The following check avoids the computation of the sum of the inputs, which may have
@@ -900,9 +918,10 @@ def q_impl(
q_input_pad,
q_weights_1,
bias=None,
pads=[0, 0, 0, 0],
strides=self.strides,
dilations=self.dilations,
pads=fake_pads,
kernel_shape=kernel_shape,
strides=strides,
dilations=dilations,
group=self.group,
)

@@ -911,14 +930,22 @@ def q_impl(
else:
numpy_q_out = conv_wx

# Workaround for handling torch's Conv1d operator until it is supported by Concrete Python
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4117
if is_conv1d:
numpy_q_out = numpy.squeeze(numpy_q_out, axis=-2)

if self.debug_value_tracker is not None:
# pylint: disable-next=unsubscriptable-object
self.debug_value_tracker[self.op_instance_name]["output"] = numpy_q_out

weight_sum_axes = (1, 2) if is_conv1d else (1, 2, 3)
weight_transpose_axes = (1, 0, 2) if is_conv1d else (1, 0, 2, 3)

# Compute the third term, the sum of the weights which is a constant
sum_weights = q_input.quantizer.zero_point * numpy.sum(
q_weights.qvalues, axis=(1, 2, 3), keepdims=True
).transpose(1, 0, 2, 3)
q_weights.qvalues, axis=weight_sum_axes, keepdims=True
).transpose(*weight_transpose_axes)

# Compute the fourth term which is a constant
final_term = n_weights * q_input.quantizer.zero_point * q_weights.quantizer.zero_point
@@ -929,6 +956,8 @@ def q_impl(
# any Gemm/Add/Conv layers that follow
m_matmul = q_input.quantizer.scale * q_weights.quantizer.scale

bias_shape = (1, -1, 1) if is_conv1d else (1, -1, 1, 1)

# If this operation's results are network outputs, return
# directly the integer values and appropriate quantization parameters that
# allow direct in-the-clear de-quantization, including the bias
@@ -944,7 +973,7 @@ def q_impl(
out_zp: Union[int, numpy.ndarray] = sum_weights - final_term
if q_bias is not None:
# Reshape the biases to broadcast them to each channel
out_zp = out_zp - q_bias.values.reshape((1, -1, 1, 1)) / m_matmul
out_zp = out_zp - q_bias.values.reshape(bias_shape) / m_matmul

# We identify terms in the above equation to determine what
# the scale/zero-point of the in-the-clear quantizer should be
@@ -956,7 +985,8 @@ def q_impl(
# The bias scale should be the same scale as the one of the weights * inputs
assert q_bias.quantizer.scale is not None
assert numpy.isclose(q_bias.quantizer.scale, m_matmul)
numpy_q_out += q_bias.qvalues.reshape((1, -1, 1, 1))

numpy_q_out += q_bias.qvalues.reshape(bias_shape)

with tag(self.op_instance_name + ".conv_rounding"):
# Apply Concrete rounding (if relevant)
@@ -973,7 +1003,7 @@ def q_impl(
if q_bias is not None and not q_bias.quantizer.is_precomputed_qat:
# The bias addition is handled in float and will be fused into a TLU
# Reshape the biases to broadcast them to each channel
numpy_q_out = numpy_q_out + q_bias.values.reshape((1, -1, 1, 1)) # bias_part
numpy_q_out = numpy_q_out + q_bias.values.reshape(bias_shape) # bias_part

# And return as a QuantizedArray initialized from the float data, keeping
# track of the quantization parameters
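The quantized implementation splits the convolution of zero-point-shifted values into one term computed on encrypted data (conv_wx) plus correction terms (the sum of the inputs, sum_weights and final_term), following Eq. 7 in https://arxiv.org/abs/1712.05877. A minimal numpy sketch of that identity for a single dot product, with illustrative values:

import numpy

rng = numpy.random.default_rng(0)

# Illustrative quantized values and zero points
q_x = rng.integers(0, 16, size=10)
q_w = rng.integers(0, 16, size=10)
z_x, z_w = 3, 5
n = q_x.size

# Direct computation on the de-offset values
direct = numpy.sum((q_x - z_x) * (q_w - z_w))

# Split form: one integer term on encrypted data plus three correction terms
split = (
    numpy.sum(q_x * q_w)    # encrypted term (conv_wx above)
    - z_w * numpy.sum(q_x)  # weight zero point times the sum of the inputs
    - z_x * numpy.sum(q_w)  # input zero point times the sum of the weights (sum_weights)
    + n * z_x * z_w         # constant term (final_term)
)

assert direct == split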
53 changes: 36 additions & 17 deletions tests/quantization/test_quantized_ops.py
@@ -698,9 +698,13 @@ def test_identity_op(x, n_bits):
),
],
)
@pytest.mark.parametrize("produces_output", [True, False])
@pytest.mark.parametrize("produces_output", [True, False], ids=["produces_output", ""])
@pytest.mark.parametrize("is_conv1d", [True, False], ids=["is_conv1d", "is_conv2d"])
# @pytest.mark.parametrize("is_conv1d", [True], ids=["is_conv1d"])
# pylint: disable-next=too-many-locals
def test_quantized_conv(params, n_bits, produces_output, check_r2_score, check_float_array_equal):
def test_quantized_conv(
params, n_bits, produces_output, is_conv1d, check_r2_score, check_float_array_equal
):
"""Test the quantized convolution operator."""

# Retrieve arguments
@@ -717,6 +721,19 @@ def test_quantized_conv(params, n_bits, produces_output, check_r2_score, check_f
group,
) = params

# If testing the conv1d operator, make the parameters represent 1D inputs
if is_conv1d:
size_input = size_input[:3]
size_weights = size_weights[:3]
strides = strides[:1]
pads = pads[:2]
dilations = (1,)
conv_torch_op = torch.conv1d

else:
dilations = (1, 1) # type: ignore[assignment]
conv_torch_op = torch.conv2d

net_input = numpy.random.uniform(size=size_input) * scale_input
weights = numpy.random.randn(*size_weights) * scale_weights
biases = numpy.random.uniform(size=size_bias) * scale_bias + offset_bias
@@ -734,33 +751,35 @@ def test_quantized_conv(params, n_bits, produces_output, check_r2_score, check_f
constant_inputs={1: q_weights, 2: q_bias},
strides=strides,
pads=pads,
kernel_shape=(weights.shape[2], weights.shape[3]),
dilations=(1, 1),
kernel_shape=weights.shape[2:],
dilations=dilations,
group=group,
)
q_op.produces_graph_output = produces_output

# Compute the result in floating point
expected_result = q_op.calibrate(net_input)

# Compute the reference result

# Pad the input if needed
# For Conv1d, torch and ONNX both follow the same padding convention
if is_conv1d:
input_padded = torch.nn.functional.pad(torch.Tensor(net_input.copy()), pads)

# Torch uses padding (padding_left,padding_right, padding_top,padding_bottom)
# For Conv2d, torch uses padding (padding_left, padding_right, padding_top, padding_bottom)
# While ONNX and Concrete ML use (padding_top, padding_left, padding_bottom, padding_right)
tx_pad = torch.nn.functional.pad(
torch.Tensor(net_input.copy()), (pads[1], pads[3], pads[0], pads[2])
)
else:
input_padded = torch.nn.functional.pad(
torch.Tensor(net_input.copy()), (pads[1], pads[3], pads[0], pads[2])
)

# Compute the torch convolution
torch_res = torch.conv2d(
tx_pad,
torch.Tensor(weights.copy()),
torch.Tensor(biases.squeeze().copy()) if biases is not None else None,
strides,
# Compute the reference result using the torch convolution operator
torch_res = conv_torch_op(
input=input_padded,
weight=torch.Tensor(weights.copy()),
bias=torch.Tensor(biases.squeeze().copy()) if biases is not None else None,
stride=strides,
groups=group,
).numpy()

check_float_array_equal(torch_res, expected_result)

# Compute the quantized result
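The 2D branch above reorders the padding values because torch.nn.functional.pad and ONNX use different conventions. A small sketch of the conversion with illustrative values:

import torch

x = torch.randn(1, 1, 5, 5)

# ONNX / Concrete ML convention: (pad_top, pad_left, pad_bottom, pad_right)
onnx_pads = (1, 2, 3, 4)

# torch.nn.functional.pad convention, starting from the last dimension:
# (pad_left, pad_right, pad_top, pad_bottom)
torch_pads = (onnx_pads[1], onnx_pads[3], onnx_pads[0], onnx_pads[2])

padded = torch.nn.functional.pad(x, torch_pads)
print(padded.shape)  # torch.Size([1, 1, 9, 11])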
11 changes: 7 additions & 4 deletions tests/torch/test_compile_torch.py
@@ -30,6 +30,7 @@
CNNGrouped,
CNNOther,
ConcatFancyIndexing,
Conv1dModel,
DoubleQuantQATMixNet,
EncryptedMatrixMultiplicationModel,
ExpandModel,
@@ -507,16 +508,18 @@ def test_compile_torch_or_onnx_networks(
],
)
@pytest.mark.parametrize(
"model",
"model, is_1d",
[
pytest.param(CNNOther),
pytest.param(partial(CNNGrouped, groups=3)),
pytest.param(CNNOther, False, id="CNN"),
pytest.param(partial(CNNGrouped, groups=3), False, id="CNN_grouped"),
pytest.param(Conv1dModel, True, id="CNN_conv1d"),
],
)
@pytest.mark.parametrize("simulate", [True, False])
@pytest.mark.parametrize("is_onnx", [True, False])
def test_compile_torch_or_onnx_conv_networks( # pylint: disable=unused-argument
model,
is_1d,
activation_function,
default_configuration,
simulate,
@@ -530,7 +533,7 @@ def test_compile_torch_or_onnx_conv_networks( # pylint: disable=unused-argument
# The QAT bits is set to 0 in order to signal that the network is not using QAT
qat_bits = 0

input_shape = (6, 7, 7)
input_shape = (6, 7) if is_1d else (6, 7, 7)
input_output = input_shape[0]

q_module = compile_and_test_torch_or_onnx(
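Outside of the test harness, the new model could be compiled roughly as follows. This is a sketch based on the public compile_torch_model API; the n_bits value, the import path for Conv1dModel and the fhe="simulate" flag are assumptions and may differ between Concrete ML versions:

import numpy
from torch import nn

from concrete.ml.pytest.torch_models import Conv1dModel
from concrete.ml.torch.compile import compile_torch_model

model = Conv1dModel(input_output=6, activation_function=nn.ReLU)

# Representative input set of shape (n_samples, channels, length)
inputset = numpy.random.uniform(-1, 1, size=(100, 6, 7))

# Compile the torch model into an FHE-friendly quantized module
quantized_module = compile_torch_model(model, inputset, n_bits=6)

# Run in simulation to check accuracy before a real FHE execution
y_sim = quantized_module.forward(inputset[:1], fhe="simulate")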
