feat: support conv1d operator #484

Merged · 1 commit · merged on Feb 5, 2024

1 change: 0 additions & 1 deletion README.md
@@ -193,7 +193,6 @@ To cite Concrete ML, notably in academic papers, please use the following entry,
<img src="https://github.com/zama-ai/concrete-ml/assets/157474013/8ef18a7e-671b-495c-8346-fa75227d0af3">
</a>


## License.

This software is distributed under the BSD-3-Clause-Clear license. If you have any questions, please contact us at [email protected].
32 changes: 30 additions & 2 deletions src/concrete/ml/onnx/ops_impl.py
@@ -1279,7 +1279,10 @@ def numpy_conv(
"""

# Convert the inputs to tensors to compute conv using torch
assert_true(len(kernel_shape) == 2, "The convolution operator currently supports only 2-d")
assert_true(
len(kernel_shape) in (1, 2),
f"The convolution operator currently only supports 1d or 2d. Got {len(kernel_shape)}-d",
)
assert_true(
bool(numpy.all(numpy.asarray(dilations) == 1)),
"The convolution operator in Concrete does not support dilation",
@@ -1300,8 +1303,33 @@
# Pad the input if needed
x_pad = numpy_onnx_pad(x, pads)

is_conv1d = len(kernel_shape) == 1

# Workaround for handling torch's Conv1d operator until it is supported by Concrete Python
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4117
if is_conv1d:
x_pad = numpy.expand_dims(x_pad, axis=-2)
w = numpy.expand_dims(w, axis=-2)
kernel_shape = (1, kernel_shape[0])
strides = (1, strides[0])
dilations = (1, dilations[0])

# Compute the torch convolution
res = fhe_conv(x_pad, w, b, None, strides, dilations, None, group)
res = fhe_conv(
x=x_pad,
weight=w,
bias=b,
pads=None,
strides=strides,
dilations=dilations,
kernel_shape=kernel_shape,
group=group,
)

# Workaround for handling torch's Conv1d operator until it is supported by Concrete Python
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4117
if is_conv1d:
res = numpy.squeeze(res, axis=-2)

return (res,)

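The workaround above relies on a standard identity: a 1D convolution over an input of shape (N, C, L) is exactly a 2D convolution over (N, C, 1, L) with a kernel of height 1, so the 1D case can be routed through the existing 2D path and the dummy axis squeezed away afterwards. Below is a minimal sketch of the same trick in plain numpy/torch, outside of Concrete ML; the helper name conv1d_via_conv2d is ours, for illustration only.

    import numpy
    import torch

    def conv1d_via_conv2d(x, w, stride=1):
        """Emulate torch.conv1d with torch.conv2d, as the code above does.

        x: input of shape (N, C_in, L); w: kernel of shape (C_out, C_in, K).
        """
        # Insert a dummy spatial axis of size 1 before the length axis
        x_2d = numpy.expand_dims(x, axis=-2)  # (N, C_in, 1, L)
        w_2d = numpy.expand_dims(w, axis=-2)  # (C_out, C_in, 1, K)
        res = torch.conv2d(
            torch.from_numpy(x_2d), torch.from_numpy(w_2d), stride=(1, stride)
        ).numpy()
        # Drop the dummy axis to recover the 1D output shape (N, C_out, L_out)
        return numpy.squeeze(res, axis=-2)

    x = numpy.random.randn(1, 3, 8)
    w = numpy.random.randn(2, 3, 2)
    expected = torch.conv1d(torch.from_numpy(x), torch.from_numpy(w)).numpy()
    assert numpy.allclose(conv1d_via_conv2d(x, w), expected)
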
25 changes: 25 additions & 0 deletions src/concrete/ml/pytest/torch_models.py
@@ -1546,3 +1546,28 @@ def forward(self, x):
x = self.input_quant(x)
x = x.reshape(x.shape + (1,))
return x.expand(x.shape[:-1] + (4,))


class Conv1dModel(nn.Module):
"""Small model that uses a 1D convolution operator."""

def __init__(self, input_output, activation_function) -> None:
super().__init__()

self.conv1 = nn.Conv1d(input_output, 2, 2, stride=1, padding=0)
self.act = activation_function()
self.fc1 = nn.Linear(input_output, 3)

def forward(self, x):
"""Forward pass.

Args:
x (torch.Tensor): The model's input.

Returns:
torch.Tensor: The model's output.

"""
x = self.act(self.conv1(x))
x = self.fc1(x)
return x
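
As a usage sketch (not part of the diff): such a model compiles like any other torch model in Concrete ML. The n_bits value and the (6, 7) input shape below are illustrative assumptions taken from the test further down, not requirements of the model itself.

    import numpy
    from torch import nn
    from concrete.ml.pytest.torch_models import Conv1dModel
    from concrete.ml.torch.compile import compile_torch_model

    # 100 random calibration samples of shape (channels=6, length=7)
    inputset = numpy.random.uniform(-1, 1, size=(100, 6, 7))

    model = Conv1dModel(input_output=6, activation_function=nn.ReLU)
    quantized_module = compile_torch_model(model, inputset, n_bits=4)

    # Run one inference in FHE simulation mode
    y = quantized_module.forward(inputset[:1], fhe="simulate")
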
70 changes: 50 additions & 20 deletions src/concrete/ml/quantization/quantized_ops.py
@@ -777,8 +777,9 @@ def __init__(

# Validate the parameters
assert_true(
len(self.kernel_shape) == 2,
"The convolution operator currently supports only 2d",
len(self.kernel_shape) in (1, 2),
"The convolution operator currently only supports 1d or 2d. "
f"Got {len(self.kernel_shape)}-d",
)
assert_true(
len(self.kernel_shape) == len(self.strides),
@@ -787,7 +788,7 @@
)
assert_true(
bool(numpy.all(numpy.asarray(self.dilations) == 1)),
"The convolution operator in Concrete does not suppport dilation",
"The convolution operator in Concrete does not support dilation",
)
assert_true(
len(self.pads) == 2 * len(self.kernel_shape),
@@ -796,7 +797,7 @@
" standard",
)

# pylint: disable-next=too-many-statements
# pylint: disable-next=too-many-statements, too-many-locals
def q_impl(
self,
*q_inputs: ONNXOpInputOutputType,
@@ -847,9 +848,6 @@ def q_impl(
f"group ({self.group}).",
)

# Prepare a constant tensor to compute the sum of the inputs
q_weights_1 = numpy.ones_like(q_weights.qvalues)

assert q_weights.quantizer.scale is not None
assert q_weights.quantizer.zero_point is not None

@@ -862,28 +860,48 @@
pad_value = int(q_input.quantizer.zero_point)
q_input_pad = numpy_onnx_pad(q_input.qvalues, self.pads, pad_value, True)

is_conv1d = len(self.kernel_shape) == 1

q_weights_values = q_weights.qvalues
kernel_shape = self.kernel_shape
strides = self.strides
dilations = self.dilations

# Workaround for handling torch's Conv1d operator until it is supported by Concrete Python
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4117
if is_conv1d:
q_input_pad = numpy.expand_dims(q_input_pad, axis=-2)
q_weights_values = numpy.expand_dims(q_weights_values, axis=-2)
kernel_shape = (1, kernel_shape[0])
strides = (1, strides[0])
dilations = (1, dilations[0])

# Prepare a constant tensor to compute the sum of the inputs
q_weights_1 = numpy.ones_like(q_weights_values)

# We follow the Quantized Gemm implementation
# which in turn follows Eq.7 in https://arxiv.org/abs/1712.05877
# to split the core computation from the zero points and scales.

# Compute the first encrypted term that convolves weights and inputs
# Force padding to 0 as padding needs to use a custom padding initializer
# and is thus manually performed in the code above
fake_pads = [0] * len(self.pads)
fake_pads = [0, 0] * len(kernel_shape)

with tag(self.op_instance_name + ".conv"):
conv_wx = fhe_conv(
q_input_pad,
q_weights.qvalues,
q_weights_values,
bias=None,
pads=fake_pads,
strides=self.strides,
dilations=self.dilations,
kernel_shape=kernel_shape,
strides=strides,
dilations=dilations,
group=self.group,
)

# The total number of elements that are convolved by the application of a single kernel
n_weights = numpy.prod(q_weights.qvalues.shape[1:])
n_weights = numpy.prod(q_weights_values.shape[1:])

# If the weights have symmetric quantization, their zero point will be 0
# The following check avoids the computation of the sum of the inputs, which may have
@@ -900,9 +918,10 @@
q_input_pad,
q_weights_1,
bias=None,
pads=[0, 0, 0, 0],
strides=self.strides,
dilations=self.dilations,
pads=fake_pads,
kernel_shape=kernel_shape,
strides=strides,
dilations=dilations,
group=self.group,
)

@@ -911,14 +930,22 @@
else:
numpy_q_out = conv_wx

# Workaround for handling torch's Conv1d operator until it is supported by Concrete Python
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4117
if is_conv1d:
numpy_q_out = numpy.squeeze(numpy_q_out, axis=-2)

if self.debug_value_tracker is not None:
# pylint: disable-next=unsubscriptable-object
self.debug_value_tracker[self.op_instance_name]["output"] = numpy_q_out

weight_sum_axes = (1, 2) if is_conv1d else (1, 2, 3)
weight_transpose_axes = (1, 0, 2) if is_conv1d else (1, 0, 2, 3)

# Compute the third term, the sum of the weights which is a constant
sum_weights = q_input.quantizer.zero_point * numpy.sum(
q_weights.qvalues, axis=(1, 2, 3), keepdims=True
).transpose(1, 0, 2, 3)
q_weights.qvalues, axis=weight_sum_axes, keepdims=True
).transpose(*weight_transpose_axes)

# Compute the fourth term which is a constant
final_term = n_weights * q_input.quantizer.zero_point * q_weights.quantizer.zero_point
@@ -929,6 +956,8 @@
# any Gemm/Add/Conv layers that follow
m_matmul = q_input.quantizer.scale * q_weights.quantizer.scale

bias_shape = (1, -1, 1) if is_conv1d else (1, -1, 1, 1)

# If this operation's result are network outputs, return
# directly the integer values and an appropriate quantization parameters that
# allow direct in-the-clear de-quantization, including the bias
@@ -944,7 +973,7 @@
out_zp: Union[int, numpy.ndarray] = sum_weights - final_term
if q_bias is not None:
# Reshape the biases to broadcast them to each channel
out_zp = out_zp - q_bias.values.reshape((1, -1, 1, 1)) / m_matmul
out_zp = out_zp - q_bias.values.reshape(bias_shape) / m_matmul

# We identify terms in the above equation to determine what
# the scale/zero-point of the in-the-clear quantizer should be
@@ -956,7 +985,8 @@
# The bias scale should be the same scale as the one of the weights * inputs
assert q_bias.quantizer.scale is not None
assert numpy.isclose(q_bias.quantizer.scale, m_matmul)
numpy_q_out += q_bias.qvalues.reshape((1, -1, 1, 1))

numpy_q_out += q_bias.qvalues.reshape(bias_shape)

with tag(self.op_instance_name + ".conv_rounding"):
# Apply Concrete rounding (if relevant)
@@ -973,7 +1003,7 @@
if q_bias is not None and not q_bias.quantizer.is_precomputed_qat:
# The bias addition is handled in float and will be fused into a TLU
# Reshape the biases to broadcast them to each channel
numpy_q_out = numpy_q_out + q_bias.values.reshape((1, -1, 1, 1)) # bias_part
numpy_q_out = numpy_q_out + q_bias.values.reshape(bias_shape) # bias_part

# And return as a QuantizedArray initialized from the float data, keeping
# track of the quantization parameters
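For reference, the four terms assembled in q_impl come from expanding the quantized convolution exactly as Eq. 7 of https://arxiv.org/abs/1712.05877 expands a quantized matmul. Writing the de-quantized inputs as x = s_x (q_x - z_x) and the weights as w = s_w (q_w - z_w), each output value is:

$$
y = s_x s_w \left( \sum q_x q_w \;-\; z_w \sum q_x \;-\; z_x \sum q_w \;+\; N z_x z_w \right) + b
$$

where the sums run over the N elements covered by one kernel application (n_weights in the code) and s_x s_w is m_matmul. The first term is conv_wx and the second is zw_conv_1x; both involve encrypted data, which is why the zw_conv_1x sum is skipped when the weights are symmetrically quantized (z_w = 0). The third and fourth terms are the constants sum_weights and final_term, folded into the output zero point.
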
53 changes: 36 additions & 17 deletions tests/quantization/test_quantized_ops.py
@@ -698,9 +698,13 @@ def test_identity_op(x, n_bits):
),
],
)
@pytest.mark.parametrize("produces_output", [True, False])
@pytest.mark.parametrize("produces_output", [True, False], ids=["produces_output", ""])
@pytest.mark.parametrize("is_conv1d", [True, False], ids=["is_conv1d", "is_conv2d"])
# pylint: disable-next=too-many-locals
def test_quantized_conv(params, n_bits, produces_output, check_r2_score, check_float_array_equal):
def test_quantized_conv(
params, n_bits, produces_output, is_conv1d, check_r2_score, check_float_array_equal
):
"""Test the quantized convolution operator."""

# Retrieve arguments
@@ -717,6 +721,19 @@ def test_quantized_conv(params, n_bits, produces_output, check_r2_score, check_f
group,
) = params

# If testing the conv1d operator, make the parameters represent 1D inputs
if is_conv1d:
size_input = size_input[:3]
size_weights = size_weights[:3]
strides = strides[:1]
pads = pads[:2]
dilations = (1,)
conv_torch_op = torch.conv1d

else:
dilations = (1, 1) # type: ignore[assignment]
conv_torch_op = torch.conv2d

net_input = numpy.random.uniform(size=size_input) * scale_input
weights = numpy.random.randn(*size_weights) * scale_weights
biases = numpy.random.uniform(size=size_bias) * scale_bias + offset_bias
@@ -734,33 +751,35 @@ def test_quantized_conv(params, n_bits, produces_output, check_r2_score, check_f
constant_inputs={1: q_weights, 2: q_bias},
strides=strides,
pads=pads,
kernel_shape=(weights.shape[2], weights.shape[3]),
dilations=(1, 1),
kernel_shape=weights.shape[2:],
dilations=dilations,
group=group,
)
q_op.produces_graph_output = produces_output

# Compute the result in floating point
expected_result = q_op.calibrate(net_input)

# Compute the reference result

# Pad the input if needed
# For Conv1d, torch and ONNX both follow the same padding convention
if is_conv1d:
input_padded = torch.nn.functional.pad(torch.Tensor(net_input.copy()), pads)

# Torch uses padding (padding_left,padding_right, padding_top,padding_bottom)
# For Conv2d, torch uses padding (padding_left, padding_right, padding_top, padding_bottom)
# While ONNX and Concrete ML use (padding_top, padding_left, padding_bottom, padding_right)
tx_pad = torch.nn.functional.pad(
torch.Tensor(net_input.copy()), (pads[1], pads[3], pads[0], pads[2])
)
else:
input_padded = torch.nn.functional.pad(
torch.Tensor(net_input.copy()), (pads[1], pads[3], pads[0], pads[2])
)

# Compute the torch convolution
torch_res = torch.conv2d(
tx_pad,
torch.Tensor(weights.copy()),
torch.Tensor(biases.squeeze().copy()) if biases is not None else None,
strides,
# Compute the reference result using the torch convolution operator
torch_res = conv_torch_op(
input=input_padded,
weight=torch.Tensor(weights.copy()),
bias=torch.Tensor(biases.squeeze().copy()) if biases is not None else None,
stride=strides,
groups=group,
).numpy()

check_float_array_equal(torch_res, expected_result)

# Compute the quantized result
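The padding shuffle in this test is easy to get wrong, so it is worth stating the two conventions side by side: ONNX (and Concrete ML) order 2D pads as (top, left, bottom, right), while torch.nn.functional.pad takes the innermost axis first, i.e. (left, right, top, bottom). A small self-contained sketch of the conversion follows; the helper name is ours, for illustration only.

    import torch

    def onnx_pads_to_torch(pads):
        """Convert ONNX 2D pads (top, left, bottom, right) to the
        torch.nn.functional.pad order (left, right, top, bottom)."""
        top, left, bottom, right = pads
        return (left, right, top, bottom)

    x = torch.zeros(1, 1, 2, 2)
    padded = torch.nn.functional.pad(x, onnx_pads_to_torch((1, 2, 3, 4)))
    # 1 row on top, 3 rows at the bottom, 2 columns left, 4 columns right
    assert padded.shape == (1, 1, 6, 8)
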
11 changes: 7 additions & 4 deletions tests/torch/test_compile_torch.py
@@ -30,6 +30,7 @@
CNNGrouped,
CNNOther,
ConcatFancyIndexing,
Conv1dModel,
DoubleQuantQATMixNet,
EncryptedMatrixMultiplicationModel,
ExpandModel,
@@ -507,16 +508,18 @@ def test_compile_torch_or_onnx_networks(
],
)
@pytest.mark.parametrize(
"model",
"model, is_1d",
[
pytest.param(CNNOther),
pytest.param(partial(CNNGrouped, groups=3)),
pytest.param(CNNOther, False, id="CNN"),
pytest.param(partial(CNNGrouped, groups=3), False, id="CNN_grouped"),
pytest.param(Conv1dModel, True, id="CNN_conv1d"),
],
)
@pytest.mark.parametrize("simulate", [True, False])
@pytest.mark.parametrize("is_onnx", [True, False])
def test_compile_torch_or_onnx_conv_networks( # pylint: disable=unused-argument
model,
is_1d,
activation_function,
default_configuration,
simulate,
@@ -530,7 +533,7 @@ def test_compile_torch_or_onnx_conv_networks( # pylint: disable=unused-argument
# The QAT bits is set to 0 in order to signal that the network is not using QAT
qat_bits = 0

input_shape = (6, 7, 7)
input_shape = (6, 7) if is_1d else (6, 7, 7)
input_output = input_shape[0]

q_module = compile_and_test_torch_or_onnx(
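
With the test ids added above, the new 1D case can be selected on its own in a standard pytest setup, e.g.:

    pytest tests/torch/test_compile_torch.py -k "CNN_conv1d"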