chore: implement hybrid model demo with GPT-2
fd0r committed Sep 13, 2023
1 parent 9520589 commit 5c9cfe8
Showing 14 changed files with 820 additions and 80 deletions.
21 changes: 12 additions & 9 deletions src/concrete/ml/deployment/fhe_client_server.py
@@ -54,11 +54,14 @@ def load(self):
             versions = json.load(file)
 
         errors = []
-        packages_to_check = {"concrete-python"}
+        packages_to_check = {"concrete-python", "concrete-ml"}
         for package_name, package_version in versions.items():
             if package_name not in packages_to_check:
                 continue
-            current_version = version(package_name)
+            if package_name == "concrete-ml":
+                current_version = CML_VERSION
+            else:
+                current_version = version(package_name)
             if package_version != current_version:  # pragma: no cover
                 errors.append((package_name, package_version, current_version))
         if errors:  # pragma: no cover
@@ -190,13 +193,10 @@ def save(self, via_mlir: bool = False):
         # Add versions
         versions_path = Path(self.path_dir).joinpath("versions.json")
         versions = {
-            package_name: version(package_name)
-            for package_name in ["concrete-ml", "concrete-python"]
+            "concrete-python": version("concrete-python"),
+            "concrete-ml": CML_VERSION,
+            "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
         }
-        versions[
-            "python"
-        ] = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
 
         with open(versions_path, "w", encoding="utf-8") as file:
             json.dump(fp=file, obj=versions)
 
@@ -255,7 +255,10 @@ def load(self):  # pylint: disable=no-value-for-parameter
         for package_name, package_version in versions.items():
             if package_name not in packages_to_check:
                 continue
-            current_version = version(package_name)
+            if package_name == "concrete-ml":
+                current_version = CML_VERSION
+            else:
+                current_version = version(package_name)
             if package_version != current_version:  # pragma: no cover
                 errors.append((package_name, package_version, current_version))
         if errors:  # pragma: no cover
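For context, the version check that both load() implementations now share can be sketched as a standalone snippet. The CML_VERSION value and the raised error below are illustrative assumptions, not the library's exact code.

# Standalone sketch of the versions.json compatibility check shown above.
# CML_VERSION and the ValueError are assumptions for illustration only.
import json
from importlib.metadata import version

CML_VERSION = "1.2.0"  # hypothetical pinned Concrete ML version

def check_saved_versions(versions_path: str) -> None:
    """Compare versions recorded at save time with the current environment."""
    with open(versions_path, "r", encoding="utf-8") as file:
        saved_versions = json.load(file)

    errors = []
    packages_to_check = {"concrete-python", "concrete-ml"}
    for package_name, package_version in saved_versions.items():
        if package_name not in packages_to_check:
            continue
        # concrete-ml is compared against the pinned constant, other packages
        # against the installed distribution version.
        current = CML_VERSION if package_name == "concrete-ml" else version(package_name)
        if package_version != current:
            errors.append((package_name, package_version, current))
    if errors:
        raise ValueError(f"Version mismatch between save and load environments: {errors}")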
61 changes: 56 additions & 5 deletions src/concrete/ml/onnx/convert.py
@@ -2,10 +2,11 @@
 
 import tempfile
 from pathlib import Path
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Tuple, Union
 
 import numpy
 import onnx
+import onnxoptimizer
 import torch
 from onnx import checker
@@ -17,8 +18,8 @@
 def get_equivalent_numpy_forward_and_onnx_model(
     torch_module: torch.nn.Module,
     dummy_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
-    output_onnx_file: Optional[Union[Path, str]] = None,
-) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.GraphProto]:
+    output_onnx_file: Union[None, Path, str] = None,
+) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
     """Get the numpy equivalent forward of the provided torch Module.
 
     Args:
@@ -34,7 +35,6 @@ def get_equivalent_numpy_forward_and_onnx_model(
             execute the equivalent numpy code to the passed torch_module and the generated ONNX
             model.
     """
-
     output_onnx_file_path = Path(
         tempfile.mkstemp(suffix=".onnx")[1] if output_onnx_file is None else output_onnx_file
    )
@@ -47,7 +47,58 @@
         opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
     )
     equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path))
 
+    # List of all currently supported onnx passes
+    # onnx_passes = [
+    #     'adjust_add',
+    #     'rename_input_output',
+    #     'set_unique_name_for_nodes',
+    #     'nop',
+    #     'eliminate_nop_cast',
+    #     'eliminate_nop_dropout',
+    #     'eliminate_nop_flatten',
+    #     'extract_constant_to_initializer',
+    #     'eliminate_if_with_const_cond',
+    #     'eliminate_nop_monotone_argmax',
+    #     'eliminate_nop_pad',
+    #     'eliminate_nop_concat',
+    #     'eliminate_nop_split',
+    #     'eliminate_nop_expand',
+    #     'eliminate_shape_gather',
+    #     'eliminate_slice_after_shape',
+    #     'eliminate_nop_transpose',
+    #     'fuse_add_bias_into_conv',
+    #     'fuse_bn_into_conv',
+    #     'fuse_consecutive_concats',
+    #     'fuse_consecutive_log_softmax',
+    #     'fuse_consecutive_reduce_unsqueeze',
+    #     'fuse_consecutive_squeezes',
+    #     'fuse_consecutive_transposes',
+    #     'fuse_matmul_add_bias_into_gemm',
+    #     'fuse_pad_into_conv',
+    #     'fuse_pad_into_pool',
+    #     'fuse_transpose_into_gemm',
+    #     'replace_einsum_with_matmul',
+    #     'lift_lexical_references',
+    #     'split_init',
+    #     'split_predict',
+    #     'fuse_concat_into_reshape',
+    #     'eliminate_nop_reshape',
+    #     'eliminate_nop_with_unit',
+    #     'eliminate_common_subexpression',
+    #     'fuse_qkv',
+    #     'fuse_consecutive_unsqueezes',
+    #     'eliminate_deadend',
+    #     'eliminate_identity',
+    #     'eliminate_shape_op',
+    #     'fuse_consecutive_slices',
+    #     'eliminate_unused_initializer',
+    #     'eliminate_duplicate_initializer',
+    #     'adjust_slice_and_matmul'
+    # ]
+    onnx_passes = ["fuse_matmul_add_bias_into_gemm"]
+    equivalent_onnx_model = onnxoptimizer.optimize(equivalent_onnx_model, onnx_passes)
+    with output_onnx_file_path.open("wb") as file:
+        file.write(equivalent_onnx_model.SerializeToString())
     checker.check_model(equivalent_onnx_model)
 
     # Remove the tempfile if we used one
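As a rough illustration of the export-then-optimize flow added in convert.py (not the exact Concrete ML code), the snippet below exports a small torch module to ONNX and applies the single fuse_matmul_add_bias_into_gemm pass. The placeholder model and the opset value are assumptions.

# Hedged sketch of the export + onnxoptimizer flow added above.
# TinyLinear and opset_version=14 are placeholders, not Concrete ML's values.
import tempfile

import onnx
import onnxoptimizer
import torch

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

onnx_path = tempfile.mkstemp(suffix=".onnx")[1]
# Export with a 1-D dummy input (no batch dimension), as the compile.py change below does
torch.onnx.export(TinyLinear(), torch.randn(4), onnx_path, opset_version=14)

model = onnx.load_model(onnx_path)
# Ask onnxoptimizer to fuse a MatMul followed by Add into a single Gemm node
model = onnxoptimizer.optimize(model, ["fuse_matmul_add_bias_into_gemm"])
onnx.checker.check_model(model)
print([node.op_type for node in model.graph.node])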
17 changes: 7 additions & 10 deletions src/concrete/ml/torch/compile.py
@@ -3,7 +3,7 @@
 import tempfile
 import warnings
 from pathlib import Path
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy
 import onnx
@@ -55,7 +55,7 @@ def build_quantized_module(
     model: Union[torch.nn.Module, onnx.ModelProto],
     torch_inputset: Dataset,
     import_qat: bool = False,
-    n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
+    n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
     rounding_threshold_bits: Optional[int] = None,
 ) -> QuantizedModule:
     """Build a quantized module from a Torch or ONNX model.
@@ -81,12 +81,9 @@
         convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
     )
 
-    # Tracing needs to be done with the batch size of 1 since we compile our models to FHE with
-    # this batch size. The input set contains many examples, to determine a representative
-    # bit-width, but for tracing we only take a single one. We need the ONNX tracing batch size to
-    # match the batch size during FHE inference which can only be 1 for the moment.
+    # No batch dimension (i.e. 0 instead of [0]) because else GEMM onnx pass is not applied
     dummy_input_for_tracing = tuple(
-        torch.from_numpy(val[[0], ::]).float() for val in inputset_as_numpy_tuple
+        torch.from_numpy(val[0, ::]).float() for val in inputset_as_numpy_tuple
     )
 
     # Create corresponding numpy model
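The indexing change in this hunk is easy to miss; the small sketch below only illustrates the shape difference between the two forms (the GEMM-pass consequence is taken from the comment in the diff).

# val[[0], ::] keeps a batch dimension of 1, val[0, ::] drops it; per the
# comment above, dropping it is what lets the GEMM ONNX pass be applied.
import numpy

val = numpy.zeros((100, 10))
print(val[[0], ::].shape)  # (1, 10) -> batch dimension kept (old behaviour)
print(val[0, ::].shape)    # (10,)   -> batch dimension dropped (new behaviour)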
@@ -113,7 +110,7 @@ def _compile_torch_or_onnx_model(
     configuration: Optional[Configuration] = None,
     artifacts: Optional[DebugArtifacts] = None,
     show_mlir: bool = False,
-    n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
+    n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
     rounding_threshold_bits: Optional[int] = None,
     p_error: Optional[float] = None,
     global_p_error: Optional[float] = None,
Expand Down Expand Up @@ -272,7 +269,7 @@ def compile_onnx_model(
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
Expand Down Expand Up @@ -340,7 +337,7 @@ def compile_brevitas_qat_model(
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
output_onnx_file: Union[Path, str] = None,
output_onnx_file: Union[None, Path, str] = None,
verbose: bool = False,
) -> QuantizedModule:
"""Compile a Brevitas Quantization Aware Training model.
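The n_bits annotations widened in this file accept either a single integer or a per-category dictionary; a possible usage sketch follows. The dictionary keys are assumptions for illustration and are not confirmed by this diff.

# Hedged usage sketch for the Union[int, Dict[str, int]] n_bits parameter.
# The dictionary keys below are assumed, not taken from this commit.
import numpy
import torch

from concrete.ml.torch.compile import compile_torch_model

model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
inputset = numpy.random.uniform(-1, 1, size=(100, 10))

# Single integer: one bit-width used everywhere.
quantized_module = compile_torch_model(model, inputset, n_bits=3)

# Dictionary form (assumed keys): per-category bit-widths.
quantized_module = compile_torch_model(
    model,
    inputset,
    n_bits={"model_inputs": 4, "op_inputs": 3, "op_weights": 3, "model_outputs": 4},
)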