feat: let's have compressed keys as an option
For now, I make it the default, but we can change that before merging.
bcm-at-zama committed Feb 16, 2024
1 parent ce49988 commit 9f4702f
Showing 3 changed files with 20 additions and 0 deletions.
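
For context, the flag introduced below is threaded from the public compile entry points down to the Concrete compiler configuration. A minimal sketch of the underlying knob, assuming concrete-python exposes a Configuration field of the same name (an assumption inferred from this commit, not confirmed here):

    # Sketch only: the Configuration field name is assumed from this commit.
    from concrete import fhe

    configuration = fhe.Configuration(
        compress_evaluation_keys=True,  # produce smaller evaluation keys
    )

The trade-off is the usual one for key compression: smaller evaluation keys to generate, store, and send to the server, in exchange for a decompression step before they are used.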
src/concrete/ml/quantization/quantized_module.py (3 additions, 0 deletions)
@@ -632,6 +632,7 @@ def compile(
         global_p_error: Optional[float] = None,
         verbose: bool = False,
         inputs_encryption_status: Optional[Sequence[str]] = None,
+        compress_evaluation_keys: bool = True,
     ) -> Circuit:
         """Compile the module's forward function.
@@ -655,6 +656,7 @@ def compile(
                 during compilation. Default to False.
             inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear',
                 'encrypted') for each input.
+            compress_evaluation_keys (bool): Indicate if the evaluation keys should be compressed. Default to True.

         Returns:
             Circuit: The compiled Circuit.
@@ -745,6 +747,7 @@ def compile(
             single_precision=False,
             fhe_simulation=False,
             fhe_execution=True,
+            compress_evaluation_keys=compress_evaluation_keys,
         )

         self._is_compiled = True
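
A short usage sketch for this first change, calling the module-level compile with the new parameter; quantized_module and inputset are placeholders (for instance a module produced by the torch compilation path and its representative input set):

    # Hypothetical call: the names are placeholders, the keyword is the one
    # added in this commit (default True).
    circuit = quantized_module.compile(
        inputset,
        compress_evaluation_keys=True,
    )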
src/concrete/ml/sklearn/base.py (5 additions, 0 deletions)
@@ -503,6 +503,7 @@ def compile(
         p_error: Optional[float] = None,
         global_p_error: Optional[float] = None,
         verbose: bool = False,
+        compress_evaluation_keys: bool = True,
     ) -> Circuit:
         """Compile the model.
@@ -525,6 +526,7 @@ def compile(
                 currently set to 0. Default to None, which sets this error to a default value.
             verbose (bool): Indicate if compilation information should be printed
                 during compilation. Default to False.
+            compress_evaluation_keys (bool): Indicate if the evaluation keys should be compressed. Default to True.

         Returns:
             Circuit: The compiled Circuit.
@@ -572,6 +574,7 @@ def compile(
             single_precision=False,
             fhe_simulation=False,
             fhe_execution=True,
+            compress_evaluation_keys=compress_evaluation_keys,
         )

         self._is_compiled = True
@@ -1148,6 +1151,7 @@ def compile(
         p_error: Optional[float] = None,
         global_p_error: Optional[float] = None,
         verbose: bool = False,
+        compress_evaluation_keys: bool = True,
     ) -> Circuit:
         # Reset for double compile
         self._is_compiled = False
@@ -1170,6 +1174,7 @@ def compile(
             p_error=p_error,
             global_p_error=global_p_error,
             verbose=verbose,
+            compress_evaluation_keys=compress_evaluation_keys,
        )

        # Make sure that no avoidable TLUs are found in the built-in model
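
For the scikit-learn style estimators covered by base.py, a minimal end-to-end sketch; the model choice and the synthetic data are assumptions for illustration:

    # Sketch: compiling a built-in model with and without compressed keys.
    import numpy
    from concrete.ml.sklearn import LogisticRegression

    X = numpy.random.uniform(size=(100, 4))
    y = (X.sum(axis=1) > 2.0).astype(numpy.int64)

    model = LogisticRegression()
    model.fit(X, y)

    # Default in this commit: evaluation keys are compressed
    circuit = model.compile(X)

    # Hypothetical opt-out, e.g. to compare key sizes across the two modes
    circuit_uncompressed = model.compile(X, compress_evaluation_keys=False)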
src/concrete/ml/torch/compile.py (12 additions, 0 deletions)
@@ -141,6 +141,7 @@ def _compile_torch_or_onnx_model(
     verbose: bool = False,
     inputs_encryption_status: Optional[Sequence[str]] = None,
     reduce_sum_copy=False,
+    compress_evaluation_keys: bool = True,
 ) -> QuantizedModule:
     """Compile a torch module or ONNX model into an FHE equivalent.
@@ -174,6 +175,7 @@ def _compile_torch_or_onnx_model(
             for each input. By default all arguments will be encrypted.
         reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
             bit-width propagation
+        compress_evaluation_keys (bool): Indicate if the evaluation keys should be compressed. Default to True.

     Returns:
         QuantizedModule: The resulting compiled QuantizedModule.
@@ -220,6 +222,7 @@ def _compile_torch_or_onnx_model(
         global_p_error=global_p_error,
         verbose=verbose,
         inputs_encryption_status=inputs_encryption_status,
+        compress_evaluation_keys=compress_evaluation_keys,
     )

     return quantized_module
@@ -240,6 +243,7 @@ def compile_torch_model(
     verbose: bool = False,
     inputs_encryption_status: Optional[Sequence[str]] = None,
     reduce_sum_copy: bool = False,
+    compress_evaluation_keys: bool = True,
 ) -> QuantizedModule:
     """Compile a torch module into an FHE equivalent.
@@ -274,6 +278,7 @@ def compile_torch_model(
             for each input. By default all arguments will be encrypted.
         reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
             bit-width propagation
+        compress_evaluation_keys (bool): Indicate if the evaluation keys should be compressed. Default to True.

     Returns:
         QuantizedModule: The resulting compiled QuantizedModule.
@@ -304,6 +309,7 @@ def compile_torch_model(
         verbose=verbose,
         inputs_encryption_status=inputs_encryption_status,
         reduce_sum_copy=reduce_sum_copy,
+        compress_evaluation_keys=compress_evaluation_keys,
     )


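A usage sketch for compile_torch_model; the tiny network, the n_bits value, and the input set are illustrative assumptions:

    # Sketch: compiling a small torch module while opting out of the new
    # default key compression.
    import numpy
    import torch
    from concrete.ml.torch.compile import compile_torch_model

    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    inputset = numpy.random.uniform(size=(100, 4)).astype(numpy.float32)

    quantized_module = compile_torch_model(
        TinyNet(),
        inputset,
        n_bits=4,
        compress_evaluation_keys=False,  # opt out of the default
    )
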
@@ -322,6 +328,7 @@ def compile_onnx_model(
     verbose: bool = False,
     inputs_encryption_status: Optional[Sequence[str]] = None,
     reduce_sum_copy: bool = False,
+    compress_evaluation_keys: bool = True,
 ) -> QuantizedModule:
     """Compile an ONNX model into an FHE equivalent.
@@ -356,6 +363,7 @@ def compile_onnx_model(
             for each input. By default all arguments will be encrypted.
         reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
             bit-width propagation
+        compress_evaluation_keys (bool): Indicate if the evaluation keys should be compressed. Default to True.

     Returns:
         QuantizedModule: The resulting compiled QuantizedModule.
@@ -382,6 +390,7 @@ def compile_onnx_model(
         verbose=verbose,
         inputs_encryption_status=inputs_encryption_status,
         reduce_sum_copy=reduce_sum_copy,
+        compress_evaluation_keys=compress_evaluation_keys,
     )


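The ONNX entry point mirrors the torch one; a sketch assuming a model exported to model.onnx beforehand and the same placeholder input set as above:

    # Sketch: the same toggle through compile_onnx_model; the file path
    # and inputset are placeholders.
    import onnx
    from concrete.ml.torch.compile import compile_onnx_model

    onnx_model = onnx.load("model.onnx")
    quantized_module = compile_onnx_model(
        onnx_model,
        inputset,
        n_bits=4,
        compress_evaluation_keys=True,  # explicit, though True is the default
    )
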
@@ -400,6 +409,7 @@ def compile_brevitas_qat_model(
     verbose: bool = False,
     inputs_encryption_status: Optional[Sequence[str]] = None,
     reduce_sum_copy: bool = False,
+    compress_evaluation_keys: bool = True,
 ) -> QuantizedModule:
     """Compile a Brevitas Quantization Aware Training model.
@@ -436,6 +446,7 @@ def compile_brevitas_qat_model(
             for each input. By default all arguments will be encrypted.
         reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
             bit-width propagation
+        compress_evaluation_keys (bool): Indicate if the evaluation keys should be compressed. Default to True.

     Returns:
         QuantizedModule: The resulting compiled QuantizedModule.
@@ -531,6 +542,7 @@ def compile_brevitas_qat_model(
         verbose=verbose,
         inputs_encryption_status=inputs_encryption_status,
         reduce_sum_copy=reduce_sum_copy,
+        compress_evaluation_keys=compress_evaluation_keys,
    )

    # Remove the tempfile if we used one
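
Finally, the Brevitas QAT entry point; qat_model stands in for a network built with brevitas quantized layers, whose bit-widths come from the model itself rather than from an n_bits argument:

    # Sketch: the same option on compile_brevitas_qat_model; qat_model and
    # inputset are placeholders.
    from concrete.ml.torch.compile import compile_brevitas_qat_model

    quantized_module = compile_brevitas_qat_model(
        qat_model,
        inputset,
        compress_evaluation_keys=False,  # hypothetical opt-out
    )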
