Commit

feat: expose statuses to compile torch
fd0r committed Sep 21, 2023
1 parent f338ac3 commit 782f399
Showing 4 changed files with 165 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .gitleaksignore
@@ -2,4 +2,4 @@
3e2986af4179c38ad5103f571f93372358e4e74f:tests/deployment/test_deployment.py:generic-api-key:59
46d53ae370263367fc49a56638f361495a0ad5d0:tests/deployment/test_deployment.py:generic-api-key:59
2d3b4ca188efb338c03d8d2c921ef39ffc5537e3:tests/deployment/test_deployment.py:generic-api-key:59
198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59
198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59
41 changes: 39 additions & 2 deletions src/concrete/ml/quantization/quantized_module.py
@@ -2,7 +2,7 @@
import copy
import re
from functools import partial
from typing import Any, Dict, Generator, Iterable, List, Optional, TextIO, Tuple, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, TextIO, Tuple, Union

import numpy
import onnx
@@ -577,6 +577,7 @@ def compile(
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
) -> Circuit:
"""Compile the module's forward function.
@@ -598,9 +599,15 @@
error to a default value.
verbose (bool): Indicate if compilation information should be printed
during compilation. Default to False.
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear',
'encrypted') for each input.
Returns:
Circuit: The compiled Circuit.
Raises:
ValueError: if inputs_encryption_status does not match the
parameters of the quantized module.
"""
inputs = to_tuple(inputs)

@@ -621,9 +628,39 @@
self._clear_forward, self.ordered_module_input_names
)

if inputs_encryption_status is None:
inputs_encryption_status = tuple(
"encrypted" for _ in orig_args_to_proxy_func_args.values()
)
else:
if len(inputs_encryption_status) < len(orig_args_to_proxy_func_args.values()):
raise ValueError(
f"Missing arguments from '{inputs_encryption_status}', expected "
f"{len(orig_args_to_proxy_func_args.values())} arguments."
)
if len(inputs_encryption_status) > len(orig_args_to_proxy_func_args.values()):
raise ValueError(
f"Too many arguments in '{inputs_encryption_status}', expected "
f"{len(orig_args_to_proxy_func_args.values())} arguments."
)
if not all(value in {"clear", "encrypted"} for value in inputs_encryption_status):
raise ValueError(
f"Unexpected status from '{inputs_encryption_status}',"
" expected 'clear' or 'encrypted'."
)
if not any(value == "encrypted" for value in inputs_encryption_status):
raise ValueError(
f"At least one input should be encrypted but got {inputs_encryption_status}"
)

assert inputs_encryption_status is not None # For mypy
inputs_encryption_status_dict = dict(
zip(orig_args_to_proxy_func_args.values(), inputs_encryption_status)
)

compiler = Compiler(
forward_proxy,
{arg_name: "encrypted" for arg_name in orig_args_to_proxy_func_args.values()},
parameter_encryption_statuses=inputs_encryption_status_dict,
)

# Quantize the inputs
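For orientation, here is a minimal usage sketch of the new argument (not part of this commit; quantized_module and the calibration arrays are hypothetical placeholders):

    # Hypothetical sketch: keep the second input in the clear while the first
    # stays encrypted. Passing None (or omitting the argument) defaults every
    # input to "encrypted", and at least one input must stay encrypted.
    circuit = quantized_module.compile(
        (x_calib_1, x_calib_2),
        inputs_encryption_status=("encrypted", "clear"),
    )
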
24 changes: 20 additions & 4 deletions src/concrete/ml/torch/compile.py
@@ -3,7 +3,7 @@
import tempfile
import warnings
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Optional, Sequence, Tuple, Union

import numpy
import onnx
@@ -118,6 +118,7 @@ def _compile_torch_or_onnx_model(
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
) -> QuantizedModule:
"""Compile a torch module or ONNX into an FHE equivalent.
@@ -142,6 +143,8 @@
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
verbose (bool): whether to show compilation information
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted')
for each input.
Returns:
QuantizedModule: The resulting compiled QuantizedModule.
@@ -186,6 +189,7 @@ def _compile_torch_or_onnx_model(
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
inputs_encryption_status=inputs_encryption_status,
)

return quantized_module
@@ -199,11 +203,12 @@ def compile_torch_model(
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
) -> QuantizedModule:
"""Compile a torch module into an FHE equivalent.
@@ -229,6 +234,8 @@
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
verbose (bool): whether to show compilation information
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted')
for each input.
Returns:
QuantizedModule: The resulting compiled QuantizedModule.
@@ -261,6 +268,7 @@
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
inputs_encryption_status=inputs_encryption_status,
)


@@ -272,11 +280,12 @@ def compile_onnx_model(
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits: Union[int, Dict] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
) -> QuantizedModule:
"""Compile a torch module into an FHE equivalent.
@@ -302,6 +311,8 @@
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
verbose (bool): whether to show compilation information
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted')
for each input.
Returns:
QuantizedModule: The resulting compiled QuantizedModule.
@@ -326,14 +337,15 @@
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
inputs_encryption_status=inputs_encryption_status,
)


# pylint: disable-next=too-many-arguments
def compile_brevitas_qat_model(
torch_model: torch.nn.Module,
torch_inputset: Dataset,
n_bits: Optional[Union[int, dict]] = None,
n_bits: Optional[Union[int, Dict[str, int]]] = None,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
@@ -342,6 +354,7 @@
global_p_error: Optional[float] = None,
output_onnx_file: Union[None, Path, str] = None,
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
) -> QuantizedModule:
"""Compile a Brevitas Quantization Aware Training model.
@@ -374,6 +387,8 @@
output_onnx_file (Union[None, Path, str]): temporary file to store the ONNX model. If None,
a temporary file is generated.
verbose (bool): whether to show compilation information
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted')
for each input.
Returns:
QuantizedModule: The resulting compiled QuantizedModule.
@@ -476,6 +491,7 @@
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
inputs_encryption_status=inputs_encryption_status,
)

# Remove the tempfile if we used one
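For orientation, a hedged sketch of the same keyword at the public entry points (not part of this commit; torch_model and torch_inputset are hypothetical placeholders). The argument is accepted by compile_torch_model, compile_onnx_model, and compile_brevitas_qat_model, and each forwards it to QuantizedModule.compile:

    # Hypothetical sketch: a single-input model whose only input stays
    # encrypted, which matches the default obtained by passing None.
    quantized_module = compile_torch_model(
        torch_model,
        torch_inputset,
        n_bits=2,
        inputs_encryption_status=("encrypted",),
    )
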
109 changes: 105 additions & 4 deletions tests/quantization/test_quantized_module.py
@@ -305,10 +305,14 @@ def test_bitwidth_report(model_class, input_shape, activation_function, default_
)

# Ensure the values are the expected ones
# Disable mypy here as it does not seem to understand that the `range` value is
# a tuple in both the report dictionaries
assert op_report["range"][0] == expected_report["range"][0] # type: ignore[index]
assert op_report["range"][1] == expected_report["range"][1] # type: ignore[index]
assert isinstance(expected_report, dict)
assert isinstance(op_report, dict)
op_report_range = op_report["range"]
expected_report_range = expected_report["range"]
assert isinstance(op_report_range, (tuple, list))
assert isinstance(expected_report_range, (tuple, list))
assert op_report_range[0] == expected_report_range[0]
assert op_report_range[1] == expected_report_range[1]
assert op_report["bitwidth"] == expected_report["bitwidth"]


@@ -351,6 +355,103 @@ def test_quantized_module_rounding_fhe(model_class, input_shape, default_configu
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3800


# TODO: extend this test with multi-input encryption status
@pytest.mark.parametrize("model_class, input_shape", [pytest.param(FC, (100, 32 * 32 * 3))])
def test_inputs_encryption_status(model_class, input_shape, default_configuration):
"""Check that giving inputs_encryption_status work properly."""

torch_fc_model = model_class(activation_function=nn.ReLU)
torch_fc_model.eval()

# Create random input
numpy_input = numpy.random.uniform(size=input_shape)
torch_input = torch.from_numpy(numpy_input).float()
assert isinstance(torch_input, torch.Tensor)
assert isinstance(numpy_input, numpy.ndarray)

# Empty status sequence
with pytest.raises(ValueError, match="Missing arguments from '.*', expected 1 arguments."):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=[],
)

# Wrong encryption status
with pytest.raises(
ValueError, match="Unexpected status from '.*', expected 'clear' or 'encrypted'."
):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=("random",),
)

# Additional argument (too many statuses)
with pytest.raises(ValueError, match="Too many arguments in '.*', expected 1 arguments."):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=(
"encrypted",
"encrypted",
),
)

# No encrypted value
with pytest.raises(ValueError, match="At least one input should be encrypted but got .*"):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=("clear",),
)

# Correct
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=("encrypted",),
)

# Default (redundant with other test)
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
inputs_encryption_status=None,
)


def quantized_module_predictions_are_equal(
quantized_module_1: QuantizedModule,
quantized_module_2: QuantizedModule,
