-
Notifications
You must be signed in to change notification settings - Fork 154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: expose statuses to compile torch #261
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
import copy | ||
import re | ||
from functools import partial | ||
from typing import Any, Dict, Generator, Iterable, List, Optional, TextIO, Tuple, Union | ||
from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, TextIO, Tuple, Union | ||
|
||
import numpy | ||
import onnx | ||
|
@@ -577,6 +577,7 @@ def compile( | |
p_error: Optional[float] = None, | ||
global_p_error: Optional[float] = None, | ||
verbose: bool = False, | ||
inputs_encryption_status: Optional[Sequence[str]] = None, | ||
) -> Circuit: | ||
"""Compile the module's forward function. | ||
|
||
|
@@ -598,9 +599,15 @@ def compile( | |
error to a default value. | ||
verbose (bool): Indicate if compilation information should be printed | ||
during compilation. Default to False. | ||
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', | ||
'encrypted') for each input. | ||
|
||
Returns: | ||
Circuit: The compiled Circuit. | ||
|
||
Raises: | ||
ValueError: if inputs_encryption_status does not match with the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's a bit more than that right ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well it's if the configuration is wrong, how would you formulate it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would usually very quickly describe what are bad configurations |
||
parameters of the quantized module | ||
""" | ||
inputs = to_tuple(inputs) | ||
|
||
|
@@ -621,9 +628,39 @@ def compile( | |
self._clear_forward, self.ordered_module_input_names | ||
) | ||
|
||
if inputs_encryption_status is None: | ||
inputs_encryption_status = tuple( | ||
"encrypted" for _ in orig_args_to_proxy_func_args.values() | ||
) | ||
else: | ||
if len(inputs_encryption_status) < len(orig_args_to_proxy_func_args.values()): | ||
raise ValueError( | ||
f"Missing arguments from '{inputs_encryption_status}', expected " | ||
f"{len(orig_args_to_proxy_func_args.values())} arguments." | ||
) | ||
if len(inputs_encryption_status) > len(orig_args_to_proxy_func_args.values()): | ||
raise ValueError( | ||
f"Too many arguments in '{inputs_encryption_status}', expected " | ||
f"{len(orig_args_to_proxy_func_args.values())} arguments." | ||
) | ||
if not all(value in {"clear", "encrypted"} for value in inputs_encryption_status): | ||
raise ValueError( | ||
f"Unexpected status from '{inputs_encryption_status}'," | ||
" expected 'clear' or 'encrypted'." | ||
) | ||
if not any(value == "encrypted" for value in inputs_encryption_status): | ||
raise ValueError( | ||
f"At least one input should be encrypted but got {inputs_encryption_status}" | ||
) | ||
|
||
assert inputs_encryption_status is not None # For mypy | ||
inputs_encryption_status_dict = dict( | ||
zip(orig_args_to_proxy_func_args.values(), inputs_encryption_status) | ||
) | ||
|
||
compiler = Compiler( | ||
forward_proxy, | ||
{arg_name: "encrypted" for arg_name in orig_args_to_proxy_func_args.values()}, | ||
parameter_encryption_statuses=inputs_encryption_status_dict, | ||
) | ||
|
||
# Quantize the inputs | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
import tempfile | ||
import warnings | ||
from pathlib import Path | ||
from typing import Dict, Optional, Tuple, Union | ||
from typing import Dict, Optional, Sequence, Tuple, Union | ||
|
||
import numpy | ||
import onnx | ||
|
@@ -118,6 +118,7 @@ def _compile_torch_or_onnx_model( | |
p_error: Optional[float] = None, | ||
global_p_error: Optional[float] = None, | ||
verbose: bool = False, | ||
inputs_encryption_status: Optional[Sequence[str]] = None, | ||
) -> QuantizedModule: | ||
"""Compile a torch module or ONNX into an FHE equivalent. | ||
|
||
|
@@ -142,6 +143,8 @@ def _compile_torch_or_onnx_model( | |
global_p_error (Optional[float]): probability of error of the full circuit. In FHE | ||
simulation `global_p_error` is set to 0 | ||
verbose (bool): whether to show compilation information | ||
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same what? There is a comment on the fact that by default everything is encrypted There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as in "same comment as above" in order to not avoid making changes in both places if there are some |
||
for each input. By default all arguments will be encrypted. | ||
|
||
Returns: | ||
QuantizedModule: The resulting compiled QuantizedModule. | ||
|
@@ -186,6 +189,7 @@ def _compile_torch_or_onnx_model( | |
p_error=p_error, | ||
global_p_error=global_p_error, | ||
verbose=verbose, | ||
inputs_encryption_status=inputs_encryption_status, | ||
) | ||
|
||
return quantized_module | ||
|
@@ -199,11 +203,12 @@ def compile_torch_model( | |
configuration: Optional[Configuration] = None, | ||
artifacts: Optional[DebugArtifacts] = None, | ||
show_mlir: bool = False, | ||
n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE, | ||
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE, | ||
rounding_threshold_bits: Optional[int] = None, | ||
p_error: Optional[float] = None, | ||
global_p_error: Optional[float] = None, | ||
verbose: bool = False, | ||
inputs_encryption_status: Optional[Sequence[str]] = None, | ||
) -> QuantizedModule: | ||
"""Compile a torch module into an FHE equivalent. | ||
|
||
|
@@ -229,6 +234,8 @@ def compile_torch_model( | |
global_p_error (Optional[float]): probability of error of the full circuit. In FHE | ||
simulation `global_p_error` is set to 0 | ||
verbose (bool): whether to show compilation information | ||
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted') | ||
for each input. By default all arguments will be encrypted. | ||
|
||
Returns: | ||
QuantizedModule: The resulting compiled QuantizedModule. | ||
|
@@ -261,6 +268,7 @@ def compile_torch_model( | |
p_error=p_error, | ||
global_p_error=global_p_error, | ||
verbose=verbose, | ||
inputs_encryption_status=inputs_encryption_status, | ||
) | ||
|
||
|
||
|
@@ -272,11 +280,12 @@ def compile_onnx_model( | |
configuration: Optional[Configuration] = None, | ||
artifacts: Optional[DebugArtifacts] = None, | ||
show_mlir: bool = False, | ||
n_bits: Union[int, Dict] = MAX_BITWIDTH_BACKWARD_COMPATIBLE, | ||
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE, | ||
rounding_threshold_bits: Optional[int] = None, | ||
p_error: Optional[float] = None, | ||
global_p_error: Optional[float] = None, | ||
verbose: bool = False, | ||
inputs_encryption_status: Optional[Sequence[str]] = None, | ||
) -> QuantizedModule: | ||
"""Compile a torch module into an FHE equivalent. | ||
|
||
|
@@ -302,6 +311,8 @@ def compile_onnx_model( | |
global_p_error (Optional[float]): probability of error of the full circuit. In FHE | ||
simulation `global_p_error` is set to 0 | ||
verbose (bool): whether to show compilation information | ||
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted') | ||
for each input. By default all arguments will be encrypted. | ||
|
||
Returns: | ||
QuantizedModule: The resulting compiled QuantizedModule. | ||
|
@@ -326,14 +337,15 @@ def compile_onnx_model( | |
p_error=p_error, | ||
global_p_error=global_p_error, | ||
verbose=verbose, | ||
inputs_encryption_status=inputs_encryption_status, | ||
) | ||
|
||
|
||
# pylint: disable-next=too-many-arguments | ||
def compile_brevitas_qat_model( | ||
torch_model: torch.nn.Module, | ||
torch_inputset: Dataset, | ||
n_bits: Optional[Union[int, dict]] = None, | ||
n_bits: Optional[Union[int, Dict[str, int]]] = None, | ||
configuration: Optional[Configuration] = None, | ||
artifacts: Optional[DebugArtifacts] = None, | ||
show_mlir: bool = False, | ||
|
@@ -342,6 +354,7 @@ def compile_brevitas_qat_model( | |
global_p_error: Optional[float] = None, | ||
output_onnx_file: Union[None, Path, str] = None, | ||
verbose: bool = False, | ||
inputs_encryption_status: Optional[Sequence[str]] = None, | ||
) -> QuantizedModule: | ||
"""Compile a Brevitas Quantization Aware Training model. | ||
|
||
|
@@ -374,6 +387,8 @@ def compile_brevitas_qat_model( | |
output_onnx_file (str): temporary file to store ONNX model. If None a temporary file | ||
is generated | ||
verbose (bool): whether to show compilation information | ||
inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted') | ||
for each input. By default all arguments will be encrypted. | ||
|
||
Returns: | ||
QuantizedModule: The resulting compiled QuantizedModule. | ||
|
@@ -476,6 +491,7 @@ def compile_brevitas_qat_model( | |
p_error=p_error, | ||
global_p_error=global_p_error, | ||
verbose=verbose, | ||
inputs_encryption_status=inputs_encryption_status, | ||
) | ||
|
||
# Remove the tempfile if we used one | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would still mention that everything is encrypted if the parameter is set to None !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's mentionned in all the user facing APIs, I don't think many users will use this one tbh
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean it was mostly for making the docstring consistent with what this parameters does, as we now provide a way to chose what inputs to encrypt, to avoid confusion that we don't chose which ones are encrypted or not based on whatever other rules. But sure if no one else pushes for this then go for it