chore: implement hybrid model demo with GPT-2
fd0r committed Sep 14, 2023
1 parent 9520589 commit 4504d5a
Showing 16 changed files with 840 additions and 84 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -168,3 +168,9 @@ execution_time_of_individual_pytest_files.txt
# Docs: Advance Examples data
docs/advanced_examples/data/


# Hybrid model artifacts
use_case_examples/hybrid_model/clients/
use_case_examples/hybrid_model/compiled_models/
use_case_examples/hybrid_model/keys/
use_case_examples/hybrid_model/user_keys/
23 changes: 13 additions & 10 deletions src/concrete/ml/deployment/fhe_client_server.py
@@ -54,11 +54,14 @@ def load(self):
versions = json.load(file)

errors = []
packages_to_check = {"concrete-python"}
packages_to_check = {"concrete-python", "concrete-ml"}
for package_name, package_version in versions.items():
if package_name not in packages_to_check:
continue
current_version = version(package_name)
if package_name == "concrete-ml":
current_version = CML_VERSION
else:
current_version = version(package_name)
if package_version != current_version: # pragma: no cover
errors.append((package_name, package_version, current_version))
if errors: # pragma: no cover
@@ -190,13 +193,10 @@ def save(self, via_mlir: bool = False):
# Add versions
versions_path = Path(self.path_dir).joinpath("versions.json")
versions = {
package_name: version(package_name)
for package_name in ["concrete-ml", "concrete-python"]
"concrete-python": version("concrete-python"),
"concrete-ml": CML_VERSION,
"python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
}
versions[
"python"
] = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"

with open(versions_path, "w", encoding="utf-8") as file:
json.dump(fp=file, obj=versions)

@@ -251,11 +251,14 @@ def load(self): # pylint: disable=no-value-for-parameter
versions = json.load(file)

errors = []
packages_to_check = {"concrete-python"}
packages_to_check = {"concrete-python", "concrete-ml"}
for package_name, package_version in versions.items():
if package_name not in packages_to_check:
continue
current_version = version(package_name)
if package_name == "concrete-ml":
current_version = CML_VERSION
else:
current_version = version(package_name)
if package_version != current_version: # pragma: no cover
errors.append((package_name, package_version, current_version))
if errors: # pragma: no cover
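For reference, the version handling changed above can be read as one small, self-contained sketch (the file path and the CML_VERSION value below are placeholders, not taken from the repository): save() now records the Concrete ML version from the CML_VERSION constant rather than querying the installed distribution, and load() compares the stored entry against that same constant.

# Minimal sketch of the version round-trip introduced above. CML_VERSION is a
# placeholder value here; in fhe_client_server.py it is the package version
# constant imported at module level.
import json
import sys
from importlib.metadata import version

CML_VERSION = "1.1.0"  # placeholder for the sketch

def save_versions(path: str) -> None:
    versions = {
        "concrete-python": version("concrete-python"),
        "concrete-ml": CML_VERSION,
        "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
    }
    with open(path, "w", encoding="utf-8") as file:
        json.dump(fp=file, obj=versions)

def check_versions(path: str) -> None:
    with open(path, "r", encoding="utf-8") as file:
        versions = json.load(file)
    errors = []
    packages_to_check = {"concrete-python", "concrete-ml"}
    for package_name, package_version in versions.items():
        if package_name not in packages_to_check:
            continue
        # concrete-ml is checked against the constant, other packages against
        # the installed distribution version.
        current = CML_VERSION if package_name == "concrete-ml" else version(package_name)
        if package_version != current:
            errors.append((package_name, package_version, current))
    if errors:
        raise ValueError(f"Version mismatch on loading: {errors}")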
61 changes: 56 additions & 5 deletions src/concrete/ml/onnx/convert.py
@@ -2,10 +2,11 @@

import tempfile
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Tuple, Union

import numpy
import onnx
import onnxoptimizer
import torch
from onnx import checker

@@ -17,8 +18,8 @@
def get_equivalent_numpy_forward_and_onnx_model(
torch_module: torch.nn.Module,
dummy_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
output_onnx_file: Optional[Union[Path, str]] = None,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.GraphProto]:
output_onnx_file: Union[None, Path, str] = None,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
"""Get the numpy equivalent forward of the provided torch Module.
Args:
@@ -34,7 +35,6 @@ def get_equivalent_numpy_forward_and_onnx_model(
execute the equivalent numpy code to the passed torch_module and the generated ONNX
model.
"""

output_onnx_file_path = Path(
tempfile.mkstemp(suffix=".onnx")[1] if output_onnx_file is None else output_onnx_file
)
@@ -47,7 +47,58 @@
opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
)
equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path))

# List of all currently supported onnx passes
# onnx_passes = [
# 'adjust_add',
# 'rename_input_output',
# 'set_unique_name_for_nodes',
# 'nop',
# 'eliminate_nop_cast',
# 'eliminate_nop_dropout',
# 'eliminate_nop_flatten',
# 'extract_constant_to_initializer',
# 'eliminate_if_with_const_cond',
# 'eliminate_nop_monotone_argmax',
# 'eliminate_nop_pad',
# 'eliminate_nop_concat',
# 'eliminate_nop_split',
# 'eliminate_nop_expand',
# 'eliminate_shape_gather',
# 'eliminate_slice_after_shape',
# 'eliminate_nop_transpose',
# 'fuse_add_bias_into_conv',
# 'fuse_bn_into_conv',
# 'fuse_consecutive_concats',
# 'fuse_consecutive_log_softmax',
# 'fuse_consecutive_reduce_unsqueeze',
# 'fuse_consecutive_squeezes',
# 'fuse_consecutive_transposes',
# 'fuse_matmul_add_bias_into_gemm',
# 'fuse_pad_into_conv',
# 'fuse_pad_into_pool',
# 'fuse_transpose_into_gemm',
# 'replace_einsum_with_matmul',
# 'lift_lexical_references',
# 'split_init',
# 'split_predict',
# 'fuse_concat_into_reshape',
# 'eliminate_nop_reshape',
# 'eliminate_nop_with_unit',
# 'eliminate_common_subexpression',
# 'fuse_qkv',
# 'fuse_consecutive_unsqueezes',
# 'eliminate_deadend',
# 'eliminate_identity',
# 'eliminate_shape_op',
# 'fuse_consecutive_slices',
# 'eliminate_unused_initializer',
# 'eliminate_duplicate_initializer',
# 'adjust_slice_and_matmul'
# ]
onnx_passes = ["fuse_matmul_add_bias_into_gemm"]
equivalent_onnx_model = onnxoptimizer.optimize(equivalent_onnx_model, onnx_passes)
with output_onnx_file_path.open("wb") as file:
file.write(equivalent_onnx_model.SerializeToString())
checker.check_model(equivalent_onnx_model)

# Remove the tempfile if we used one
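The new export path in convert.py runs a single onnxoptimizer pass on the exported graph before re-serializing and validating it. A minimal standalone sketch of that step is shown below; the toy module, input shape, and opset version are placeholders rather than the repository's constants.

# Sketch of the post-export optimization step: export a toy module to ONNX,
# apply the "fuse_matmul_add_bias_into_gemm" pass so MatMul + Add pairs are
# fused into a single Gemm node, then rewrite the file and check the model.
# Module, input shape and opset_version are illustrative assumptions.
import tempfile
from pathlib import Path

import onnx
import onnxoptimizer
import torch
from onnx import checker

module = torch.nn.Linear(4, 2)
dummy_input = torch.randn(4)  # no batch dimension, mirroring the tracing change in compile.py
onnx_path = Path(tempfile.mkstemp(suffix=".onnx")[1])

torch.onnx.export(module, dummy_input, str(onnx_path), opset_version=14)

model = onnx.load_model(str(onnx_path))
model = onnxoptimizer.optimize(model, ["fuse_matmul_add_bias_into_gemm"])

with onnx_path.open("wb") as file:
    file.write(model.SerializeToString())
checker.check_model(model)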
4 changes: 3 additions & 1 deletion src/concrete/ml/quantization/post_training.py
@@ -788,7 +788,9 @@ def _process_initializer(self, n_bits: int, values: numpy.ndarray):
QuantizedArray: a quantized tensor with integer values on n_bits bits
"""

if isinstance(values, numpy.ndarray) and numpy.issubdtype(values.dtype, numpy.integer):
if isinstance(values, numpy.ndarray) and numpy.issubdtype(
values.dtype, numpy.integer
): # pragma:no cover
return values.view(RawOpOutput)

assert isinstance(values, (numpy.ndarray, float))
19 changes: 8 additions & 11 deletions src/concrete/ml/torch/compile.py
@@ -3,7 +3,7 @@
import tempfile
import warnings
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import numpy
import onnx
@@ -55,7 +55,7 @@ def build_quantized_module(
model: Union[torch.nn.Module, onnx.ModelProto],
torch_inputset: Dataset,
import_qat: bool = False,
n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
) -> QuantizedModule:
"""Build a quantized module from a Torch or ONNX model.
@@ -81,12 +81,9 @@
convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
)

# Tracing needs to be done with the batch size of 1 since we compile our models to FHE with
# this batch size. The input set contains many examples, to determine a representative
# bit-width, but for tracing we only take a single one. We need the ONNX tracing batch size to
# match the batch size during FHE inference which can only be 1 for the moment.
# No batch dimension (i.e., 0 instead of [0]) because else GEMM onnx pass can't be applied
dummy_input_for_tracing = tuple(
torch.from_numpy(val[[0], ::]).float() for val in inputset_as_numpy_tuple
torch.from_numpy(val[0, ::]).float() for val in inputset_as_numpy_tuple
)

# Create corresponding numpy model
@@ -113,7 +110,7 @@ def _compile_torch_or_onnx_model(
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
@@ -272,7 +269,7 @@ def compile_onnx_model(
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
@@ -340,7 +337,7 @@ def compile_brevitas_qat_model(
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
output_onnx_file: Union[Path, str] = None,
output_onnx_file: Union[None, Path, str] = None,
verbose: bool = False,
) -> QuantizedModule:
"""Compile a Brevitas Quantization Aware Training model.
@@ -384,7 +381,7 @@
)

dummy_input_for_tracing = tuple(
torch.from_numpy(val[[0], ::]).float() for val in inputset_as_numpy_tuple
torch.from_numpy(val[0, ::]).float() for val in inputset_as_numpy_tuple
)

output_onnx_file_path = Path(
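The tracing change in build_quantized_module and compile_brevitas_qat_model (val[[0], ::] replaced by val[0, ::]) drops the leading batch dimension from the tracing input, which is what allows the GEMM fusion pass applied in convert.py to take effect. A short illustration of the indexing difference, with an arbitrary example shape:

# Fancy indexing with a list keeps a batch axis of size 1, plain integer
# indexing removes it. The array shape below is illustrative only.
import numpy

val = numpy.zeros((32, 10))   # e.g. an input set of 32 examples, 10 features
print(val[[0], ::].shape)     # (1, 10): previous tracing input, with batch axis
print(val[0, ::].shape)       # (10,):   new tracing input, no batch axis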
