feat: expose statuses to compile torch
fd0r committed Sep 20, 2023
1 parent 4a485fc commit 02d8c10
Showing 5 changed files with 134 additions and 5 deletions.
.gitleaksignore: 2 changes (1 addition, 1 deletion)
@@ -2,4 +2,4 @@
3e2986af4179c38ad5103f571f93372358e4e74f:tests/deployment/test_deployment.py:generic-api-key:59
46d53ae370263367fc49a56638f361495a0ad5d0:tests/deployment/test_deployment.py:generic-api-key:59
2d3b4ca188efb338c03d8d2c921ef39ffc5537e3:tests/deployment/test_deployment.py:generic-api-key:59
-198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59
+198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59
script/make_utils/setup_os_deps.sh: 4 changes (2 additions, 2 deletions)
@@ -29,8 +29,8 @@ done


linux_install_gitleaks () {
-GITLEAKS_VERSION=8.5.2
-GITLEAKS_LINUX_X64_SHA256=d83e4721c58638d5a2128ca70341c87fe78b6275483e7dc769a9ca6fe4d25dfd
+GITLEAKS_VERSION=8.17.0
+GITLEAKS_LINUX_X64_SHA256=e0e1d641cc55bcf3c0ecc1abcfc6b432e86611a53121d87ce40eacd9467f98c3

GITLEAKS_ARCHIVE_LINK="https://github.com/zricethezav/gitleaks/releases/download/v${GITLEAKS_VERSION}/gitleaks_${GITLEAKS_VERSION}_linux_x64.tar.gz"

src/concrete/ml/quantization/quantized_module.py: 21 changes (20 additions, 1 deletion)
@@ -577,6 +577,7 @@ def compile(
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
statuses: Optional[Dict[str, str]] = None,
) -> Circuit:
"""Compile the module's forward function.
@@ -598,9 +599,13 @@
error to a default value.
verbose (bool): Indicate if compilation information should be printed
during compilation. Defaults to False.
statuses (Optional[Dict[str, str]]): Concrete Python encryption statuses, mapping
each input of the forward function to either "clear" or "encrypted". If None,
all inputs are marked "encrypted".

Returns:
Circuit: The compiled Circuit.

Raises:
ValueError: if statuses does not match the parameters of the quantized module
"""
inputs = to_tuple(inputs)

Expand All @@ -621,9 +626,23 @@ def compile(
self._clear_forward, self.ordered_module_input_names
)

if statuses is None:
statuses = {arg_name: "encrypted" for arg_name in orig_args_to_proxy_func_args.values()}
else:
if not all(arg_name in statuses for arg_name in orig_args_to_proxy_func_args.values()):
raise ValueError(
f"Missing arguments from '{statuses}', expected: "
f"'{list(orig_args_to_proxy_func_args.values())}'."
)
if not all(value in {"clear", "encrypted"} for value in statuses.values()):
raise ValueError(
"Unexpected status from '{statuses}', expected clear or encrypted."
)

assert statuses is not None # For mypy
compiler = Compiler(
forward_proxy,
{arg_name: "encrypted" for arg_name in orig_args_to_proxy_func_args.values()},
parameter_encryption_statuses=statuses,
)

# Quantize the inputs
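To make the semantics of the new parameter concrete, here is a small illustrative sketch; the input name "_onnx__Gemm_0" is borrowed from the tests added below, and the comments restate the behavior of the validation block above:

# Each input of the forward function must be listed with one of the two
# Concrete Python encryption statuses:
statuses = {"_onnx__Gemm_0": "encrypted"}  # keep this input encrypted

# Passing statuses=None reproduces the previous hard-coded behavior:
# every input is marked "encrypted".

# A missing input name, or a status other than "clear" or "encrypted",
# raises a ValueError before compilation. Marking every input "clear"
# leaves no encrypted input at all; the tests below expect Concrete
# Python to reject that case with a RuntimeError.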
src/concrete/ml/torch/compile.py: 8 changes (7 additions, 1 deletion)
@@ -3,7 +3,7 @@
import tempfile
import warnings
from pathlib import Path
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

import numpy
import onnx
@@ -118,6 +118,7 @@ def _compile_torch_or_onnx_model(
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
statuses: Optional[Dict[str, str]] = None,
) -> QuantizedModule:
"""Compile a torch module or ONNX into an FHE equivalent.
@@ -142,6 +143,7 @@
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
verbose (bool): whether to show compilation information
statuses (Optional[Dict[str, str]]): Concrete Python encryption statuses, mapping
each input name to either "clear" or "encrypted". If None, all inputs are
marked "encrypted".

Returns:
QuantizedModule: The resulting compiled QuantizedModule.
@@ -186,6 +188,7 @@ def _compile_torch_or_onnx_model(
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
statuses=statuses,
)

return quantized_module
@@ -204,6 +207,7 @@
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
statuses: Optional[Dict[str, str]] = None,
) -> QuantizedModule:
"""Compile a torch module into an FHE equivalent.
@@ -229,6 +233,7 @@
global_p_error (Optional[float]): probability of error of the full circuit. In FHE
simulation `global_p_error` is set to 0
verbose (bool): whether to show compilation information
statuses (Optional[Dict[str, str]]): Concrete Python encryption statuses, mapping
each input name to either "clear" or "encrypted". If None, all inputs are
marked "encrypted".

Returns:
QuantizedModule: The resulting compiled QuantizedModule.
@@ -261,6 +266,7 @@
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
statuses=statuses,
)


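For reference, a usage sketch of compile_torch_model with the new parameter, modeled on the tests added below; the model, the inputs, and the ONNX input name "_onnx__Gemm_0" are stand-ins taken from or inspired by those tests, not a prescribed setup:

import numpy
import torch
from torch import nn

from concrete.ml.torch.compile import compile_torch_model

# Hypothetical fully-connected model standing in for the FC test model
torch_fc_model = nn.Sequential(nn.Linear(32 * 32 * 3, 64), nn.ReLU(), nn.Linear(64, 10))
torch_fc_model.eval()

# Representative inputs, used for calibration during compilation
torch_input = torch.from_numpy(numpy.random.uniform(size=(100, 32 * 32 * 3))).float()

quantized_module = compile_torch_model(
    torch_fc_model,
    torch_input,
    n_bits=2,
    p_error=0.01,
    rounding_threshold_bits=6,
    statuses={"_onnx__Gemm_0": "encrypted"},
)

Omitting statuses (or passing None) keeps the previous behavior: every input is compiled as encrypted.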
tests/quantization/test_quantized_module.py: 104 changes (104 additions, 0 deletions)
@@ -351,6 +351,110 @@ def test_quantized_module_rounding_fhe(model_class, input_shape, default_configu
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3800


@pytest.mark.parametrize("model_class, input_shape", [pytest.param(FC, (100, 32 * 32 * 3))])
def test_statuses(model_class, input_shape, default_configuration):
"""Check that giving statuses work properly."""

torch_fc_model = model_class(activation_function=nn.ReLU)
torch_fc_model.eval()

# Create random input
numpy_input = numpy.random.uniform(size=input_shape)
torch_input = torch.from_numpy(numpy_input).float()
assert isinstance(torch_input, torch.Tensor)
assert isinstance(numpy_input, numpy.ndarray)

# Compile with rounding activated and check that invalid statuses raise errors
with pytest.raises(ValueError):
# Empty status
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
statuses={},
)

# Wrong name status
with pytest.raises(ValueError):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
statuses={"x": "encrypted"},
)

# Wrong encryption status
with pytest.raises(ValueError):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
statuses={"_onnx__Gemm_0": "random"},
)

# Additional argument (error from Concrete Python)
with pytest.raises(ValueError):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
statuses={"_onnx__Gemm_0": "encrypted", "test": "encrypted"},
)

# No encrypted value
with pytest.raises(RuntimeError):
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
statuses={"_onnx__Gemm_0": "clear"},
)

# Correct
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
statuses={"_onnx__Gemm_0": "encrypted"},
)

# Default (redundant with other test)
compile_torch_model(
torch_fc_model,
torch_input,
False,
default_configuration,
n_bits=2,
p_error=0.01,
rounding_threshold_bits=6,
statuses=None,
)


def quantized_module_predictions_are_equal(
quantized_module_1: QuantizedModule,
quantized_module_2: QuantizedModule,