Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: expose statuses to compile torch #261

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitleaksignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
3e2986af4179c38ad5103f571f93372358e4e74f:tests/deployment/test_deployment.py:generic-api-key:59
46d53ae370263367fc49a56638f361495a0ad5d0:tests/deployment/test_deployment.py:generic-api-key:59
2d3b4ca188efb338c03d8d2c921ef39ffc5537e3:tests/deployment/test_deployment.py:generic-api-key:59
198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59
198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59
41 changes: 39 additions & 2 deletions src/concrete/ml/quantization/quantized_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import copy
import re
from functools import partial
from typing import Any, Dict, Generator, Iterable, List, Optional, TextIO, Tuple, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, TextIO, Tuple, Union

import numpy
import onnx
Expand Down Expand Up @@ -577,6 +577,7 @@ def compile(
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
) -> Circuit:
"""Compile the module's forward function.

Expand All @@ -598,9 +599,15 @@ def compile(
error to a default value.
verbose (bool): Indicate if compilation information should be printed
during compilation. Default to False.
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would still mention that everything is encrypted if the parameter is set to None !

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's mentionned in all the user facing APIs, I don't think many users will use this one tbh

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean it was mostly for making the docstring consistent with what this parameters does, as we now provide a way to chose what inputs to encrypt, to avoid confusion that we don't chose which ones are encrypted or not based on whatever other rules. But sure if no one else pushes for this then go for it

'encrypted') for each input.

Returns:
Circuit: The compiled Circuit.

Raises:
ValueError: if inputs_encryption_status does not match with the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a bit more than that right ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well it's if the configuration is wrong, how would you formulate it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would usually very quickly describe what are bad configurations

parameters of the quantized module
"""
inputs = to_tuple(inputs)

Expand All @@ -621,9 +628,39 @@ def compile(
self._clear_forward, self.ordered_module_input_names
)

if inputs_encryption_status is None:
inputs_encryption_status = tuple(
"encrypted" for _ in orig_args_to_proxy_func_args.values()
)
else:
if len(inputs_encryption_status) < len(orig_args_to_proxy_func_args.values()):
raise ValueError(
f"Missing arguments from '{inputs_encryption_status}', expected "
f"{len(orig_args_to_proxy_func_args.values())} arguments."
)
if len(inputs_encryption_status) > len(orig_args_to_proxy_func_args.values()):
raise ValueError(
f"Too many arguments in '{inputs_encryption_status}', expected "
f"{len(orig_args_to_proxy_func_args.values())} arguments."
)
if not all(value in {"clear", "encrypted"} for value in inputs_encryption_status):
raise ValueError(
f"Unexpected status from '{inputs_encryption_status}',"
" expected 'clear' or 'encrypted'."
)
if not any(value == "encrypted" for value in inputs_encryption_status):
raise ValueError(
f"At least one input should be encrypted but got {inputs_encryption_status}"
)

assert inputs_encryption_status is not None # For mypy
inputs_encryption_status_dict = dict(
zip(orig_args_to_proxy_func_args.values(), inputs_encryption_status)
)

compiler = Compiler(
forward_proxy,
{arg_name: "encrypted" for arg_name in orig_args_to_proxy_func_args.values()},
parameter_encryption_statuses=inputs_encryption_status_dict,
)

# Quantize the inputs
Expand Down
24 changes: 20 additions & 4 deletions src/concrete/ml/torch/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tempfile
import warnings
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Optional, Sequence, Tuple, Union

import numpy
import onnx
Expand Down Expand Up @@ -118,6 +118,7 @@ def _compile_torch_or_onnx_model(
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
) -> QuantizedModule:
"""Compile a torch module or ONNX into an FHE equivalent.

Expand All @@ -142,6 +143,8 @@ def _compile_torch_or_onnx_model(
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
verbose (bool): whether to show compilation information
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same what? There is a comment on the fact that by default everything is encrypted

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as in "same comment as above" in order to not avoid making changes in both places if there are some

for each input. By default all arguments will be encrypted.

Returns:
QuantizedModule: The resulting compiled QuantizedModule.
Expand Down Expand Up @@ -186,6 +189,7 @@ def _compile_torch_or_onnx_model(
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
inputs_encryption_status=inputs_encryption_status,
)

return quantized_module
Expand All @@ -199,11 +203,12 @@ def compile_torch_model(
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
) -> QuantizedModule:
"""Compile a torch module into an FHE equivalent.

Expand All @@ -229,6 +234,8 @@ def compile_torch_model(
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
verbose (bool): whether to show compilation information
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted')
for each input. By default all arguments will be encrypted.

Returns:
QuantizedModule: The resulting compiled QuantizedModule.
Expand Down Expand Up @@ -261,6 +268,7 @@ def compile_torch_model(
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
inputs_encryption_status=inputs_encryption_status,
)


Expand All @@ -272,11 +280,12 @@ def compile_onnx_model(
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits: Union[int, Dict] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
) -> QuantizedModule:
"""Compile a torch module into an FHE equivalent.

Expand All @@ -302,6 +311,8 @@ def compile_onnx_model(
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
verbose (bool): whether to show compilation information
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted')
for each input. By default all arguments will be encrypted.

Returns:
QuantizedModule: The resulting compiled QuantizedModule.
Expand All @@ -326,14 +337,15 @@ def compile_onnx_model(
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
inputs_encryption_status=inputs_encryption_status,
)


# pylint: disable-next=too-many-arguments
def compile_brevitas_qat_model(
torch_model: torch.nn.Module,
torch_inputset: Dataset,
n_bits: Optional[Union[int, dict]] = None,
n_bits: Optional[Union[int, Dict[str, int]]] = None,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
Expand All @@ -342,6 +354,7 @@ def compile_brevitas_qat_model(
global_p_error: Optional[float] = None,
output_onnx_file: Union[None, Path, str] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
) -> QuantizedModule:
"""Compile a Brevitas Quantization Aware Training model.

Expand Down Expand Up @@ -374,6 +387,8 @@ def compile_brevitas_qat_model(
output_onnx_file (str): temporary file to store ONNX model. If None a temporary file
is generated
verbose (bool): whether to show compilation information
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted')
for each input. By default all arguments will be encrypted.

Returns:
QuantizedModule: The resulting compiled QuantizedModule.
Expand Down Expand Up @@ -476,6 +491,7 @@ def compile_brevitas_qat_model(
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
inputs_encryption_status=inputs_encryption_status,
)

# Remove the tempfile if we used one
Expand Down
112 changes: 107 additions & 5 deletions tests/quantization/test_quantized_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,14 @@ def test_bitwidth_report(model_class, input_shape, activation_function, default_
)

# Ensure the value are the expected ones
# Disable mypy here as it does not seem to understand that the `range` value is
# a tuple in both the report dictionaries
assert op_report["range"][0] == expected_report["range"][0] # type: ignore[index]
assert op_report["range"][1] == expected_report["range"][1] # type: ignore[index]
assert isinstance(expected_report, dict)
assert isinstance(op_report, dict)
op_report_range = op_report["range"]
expected_report_range = expected_report["range"]
assert isinstance(op_report_range, (tuple, list))
assert isinstance(expected_report_range, (tuple, list))
assert op_report_range[0] == expected_report_range[0]
assert op_report_range[1] == expected_report_range[1]
assert op_report["bitwidth"] == expected_report["bitwidth"]


Expand Down Expand Up @@ -348,7 +352,105 @@ def test_quantized_module_rounding_fhe(model_class, input_shape, default_configu
# Execute the model with rounding in FHE execution mode
quantized_model.forward(numpy_test, fhe="execute")

# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3800
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3800


# Extend this test with multi-input encryption status
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4011
@pytest.mark.parametrize("model_class, input_shape", [pytest.param(FC, (100, 32 * 32 * 3))])
def test_inputs_encryption_status(model_class, input_shape, default_configuration):
"""Check that giving inputs_encryption_status work properly."""

torch_fc_model = model_class(activation_function=nn.ReLU)
torch_fc_model.eval()

# Create random input
numpy_input = numpy.random.uniform(size=input_shape)
torch_input = torch.from_numpy(numpy_input).float()
assert isinstance(torch_input, torch.Tensor)
assert isinstance(numpy_input, numpy.ndarray)

# Compile with rounding activated
with pytest.raises(ValueError, match="Missing arguments from '.*', expected 1 arguments."):
# Empty status
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=[],
)

# Wrong encryption status
with pytest.raises(
ValueError, match="Unexpected status from '.*', expected 'clear' or 'encrypted'."
):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=("random",),
)

# Additional argument (error from Concrete Python)
with pytest.raises(ValueError, match="Too many arguments in '.*', expected 1 arguments."):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=(
"encrypted",
"encrypted",
),
)

# No encrypted value
with pytest.raises(ValueError, match="At least one input should be encrypted but got .*"):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=("clear",),
)

# Correct
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=("encrypted",),
)

# Default (redundant with other test)
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=None,
)


def quantized_module_predictions_are_equal(
Expand Down