chore: implement hybrid model demo with GPT-2
fd0r committed Sep 18, 2023
1 parent 9520589 commit a3a2972
Showing 19 changed files with 959 additions and 99 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -168,3 +168,9 @@ execution_time_of_individual_pytest_files.txt
 # Docs: Advance Examples data
 docs/advanced_examples/data/
 
+
+# Hybrid model artifacts
+use_case_examples/hybrid_model/clients/
+use_case_examples/hybrid_model/compiled_models/
+use_case_examples/hybrid_model/keys/
+use_case_examples/hybrid_model/user_keys/
23 changes: 13 additions & 10 deletions src/concrete/ml/deployment/fhe_client_server.py
@@ -54,11 +54,14 @@ def load(self):
             versions = json.load(file)
 
         errors = []
-        packages_to_check = {"concrete-python"}
+        packages_to_check = {"concrete-python", "concrete-ml"}
         for package_name, package_version in versions.items():
             if package_name not in packages_to_check:
                 continue
-            current_version = version(package_name)
+            if package_name == "concrete-ml":
+                current_version = CML_VERSION
+            else:
+                current_version = version(package_name)
             if package_version != current_version:  # pragma: no cover
                 errors.append((package_name, package_version, current_version))
         if errors:  # pragma: no cover
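
For reference, a standalone sketch of the check this hunk implements (CML_VERSION is the Concrete ML version constant used in this file; the value below is an illustrative stand-in):

    from importlib.metadata import version

    CML_VERSION = "1.1.0"  # illustrative stand-in for the module's constant

    def check_package_versions(saved_versions: dict) -> list:
        """Return (package, saved, current) triples for every mismatched package."""
        errors = []
        packages_to_check = {"concrete-python", "concrete-ml"}
        for package_name, package_version in saved_versions.items():
            if package_name not in packages_to_check:
                continue
            # concrete-ml is compared against the version constant rather than
            # the installed package metadata
            if package_name == "concrete-ml":
                current_version = CML_VERSION
            else:
                current_version = version(package_name)
            if package_version != current_version:
                errors.append((package_name, package_version, current_version))
        return errors

    # check_package_versions({"concrete-ml": "1.0.0"}) -> [("concrete-ml", "1.0.0", "1.1.0")]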
@@ -190,13 +193,10 @@ def save(self, via_mlir: bool = False):
         # Add versions
         versions_path = Path(self.path_dir).joinpath("versions.json")
         versions = {
-            package_name: version(package_name)
-            for package_name in ["concrete-ml", "concrete-python"]
+            "concrete-python": version("concrete-python"),
+            "concrete-ml": CML_VERSION,
+            "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
         }
-        versions[
-            "python"
-        ] = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
 
         with open(versions_path, "w", encoding="utf-8") as file:
             json.dump(fp=file, obj=versions)
 
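With this change, the versions.json written next to a saved model would contain entries like the following (version values purely illustrative):

    {"concrete-python": "2.0.0", "concrete-ml": "1.1.0", "python": "3.9.17"}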
@@ -251,11 +251,14 @@ def load(self):  # pylint: disable=no-value-for-parameter
             versions = json.load(file)
 
         errors = []
-        packages_to_check = {"concrete-python"}
+        packages_to_check = {"concrete-python", "concrete-ml"}
         for package_name, package_version in versions.items():
             if package_name not in packages_to_check:
                 continue
-            current_version = version(package_name)
+            if package_name == "concrete-ml":
+                current_version = CML_VERSION
+            else:
+                current_version = version(package_name)
         if package_version != current_version:  # pragma: no cover
             errors.append((package_name, package_version, current_version))
         if errors:  # pragma: no cover
185 changes: 170 additions & 15 deletions src/concrete/ml/onnx/convert.py
@@ -2,23 +2,112 @@
 
 import tempfile
 from pathlib import Path
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Tuple, Union
 
 import numpy
 import onnx
+import onnxoptimizer
 import torch
-from onnx import checker
+from onnx import checker, helper
 
 from .onnx_utils import IMPLEMENTED_ONNX_OPS, execute_onnx_with_numpy, get_op_type
 
 OPSET_VERSION_FOR_ONNX_EXPORT = 14
 
 
+# pylint: disable=too-many-nested-blocks
+def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
+    """Fuse sequence of matmul -> add into a gemm node.
+
+    Args:
+        onnx_model (onnx.ModelProto): An ONNX model to optimize using MatMul + Add -> Gemm
+
+    Returns:
+        onnx.ModelProto: the optimized ONNX model
+    """
+    # Convert nodes to list to avoid modifying iterable during iteration
+    nodes_list = list(onnx_model.graph.node)
+
+    # Iterate through graph nodes
+    for node in nodes_list:
+        # Check if node is a MatMul node
+        if node.op_type == "MatMul":
+            # Store MatMul node output name
+            matmul_node_output_name = node.output[0]
+
+            bias_node_name = None
+            add_node = None  # store the reference to the add_node
+
+            # Find the Add node which adds the bias
+            for potential_add_node in nodes_list:
+                if (
+                    potential_add_node.op_type == "Add"
+                    and matmul_node_output_name in potential_add_node.input
+                ):
+                    # Store Add node
+                    add_node = potential_add_node
+                    # Find the input to the Add node which is not the MatMul output
+                    for input_name in add_node.input:
+                        if input_name != matmul_node_output_name:
+                            bias_node_name = input_name
+
+            # If no bias_node_name has been assigned, continue to the next node
+            if not bias_node_name:
+                continue
+            assert add_node is not None
+
+            # Create a Gemm node which combines the MatMul and Add operations
+            gemm_node = helper.make_node(
+                "Gemm",  # op_type
+                [node.input[0], node.input[1], bias_node_name],  # inputs
+                [add_node.output[0]],  # outputs
+                name="Gemm_Node",
+                alpha=1.0,
+                beta=1.0,  # attributes
+            )
+
+            # Replace the MatMul and Add nodes with the Gemm node
+            mat_mult_node_index = nodes_list.index(node)
+            add_node_index = nodes_list.index(add_node)
+            gemm_node_index = max(mat_mult_node_index, add_node_index)
+
+            onnx_model.graph.node.remove(node)
+            onnx_model.graph.node.remove(add_node)
+            onnx_model.graph.node.insert(gemm_node_index, gemm_node)
+
+            # Update connections in the graph
+            for potential_next_node in onnx_model.graph.node:
+                # Check if this node was connected to the add_node
+                if add_node.output[0] in potential_next_node.input:
+                    # Replace the reference to the old add_node output with the gemm_node output
+                    idx = list(potential_next_node.input).index(add_node.output[0])
+                    potential_next_node.input[idx] = gemm_node.output[0]
+
+            # Update the model's output if necessary
+            for model_output in onnx_model.graph.output:
+                if model_output.name == add_node.output[0]:
+                    model_output.name = gemm_node.output[0]
+
+    return onnx_model
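
For illustration, a minimal sketch that exercises this new fusion on a hand-built MatMul + Add graph (shapes and names are arbitrary; the import assumes the function lands in concrete.ml.onnx.convert as above):

    import numpy
    import onnx
    from onnx import helper, numpy_helper

    from concrete.ml.onnx.convert import fuse_matmul_bias_to_gemm

    # Build X @ W + B as two separate nodes
    X = helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])
    Y = helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 2])
    W = numpy_helper.from_array(numpy.ones((4, 2), dtype=numpy.float32), name="W")
    B = numpy_helper.from_array(numpy.zeros((2,), dtype=numpy.float32), name="B")
    matmul = helper.make_node("MatMul", ["X", "W"], ["mm_out"])
    add = helper.make_node("Add", ["mm_out", "B"], ["Y"])
    graph = helper.make_graph([matmul, add], "matmul_add", [X], [Y], initializer=[W, B])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])

    # The two nodes should be fused into a single Gemm node
    fused = fuse_matmul_bias_to_gemm(model)
    print([node.op_type for node in fused.graph.node])  # expected: ['Gemm']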


 def get_equivalent_numpy_forward_and_onnx_model(
     torch_module: torch.nn.Module,
     dummy_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
-    output_onnx_file: Optional[Union[Path, str]] = None,
-) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.GraphProto]:
+    output_onnx_file: Union[None, Path, str] = None,
+) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
     """Get the numpy equivalent forward of the provided torch Module.
 
     Args:
@@ -34,30 +123,28 @@ def get_equivalent_numpy_forward_and_onnx_model(
         execute the equivalent numpy code to the passed torch_module and the generated ONNX
         model.
     """
 
     output_onnx_file_path = Path(
         tempfile.mkstemp(suffix=".onnx")[1] if output_onnx_file is None else output_onnx_file
     )
     use_tempfile: bool = output_onnx_file is None
 
     # Export to ONNX
     torch.onnx.export(
         torch_module,
         dummy_input,
         str(output_onnx_file_path),
         opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
     )
     equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path))
 
-    checker.check_model(equivalent_onnx_model)
-
     # Remove the tempfile if we used one
     if use_tempfile:
         output_onnx_file_path.unlink()
 
-    # The model was checked just above
-    equivalent_numpy_forward = get_equivalent_numpy_forward(
-        equivalent_onnx_model, check_model=False
+    equivalent_numpy_forward, equivalent_onnx_model = get_equivalent_numpy_forward(
+        equivalent_onnx_model, check_model=True
     )
+    with output_onnx_file_path.open("wb") as file:
+        file.write(equivalent_onnx_model.SerializeToString())
 
     return (
         equivalent_numpy_forward,
@@ -68,7 +155,7 @@ def get_equivalent_numpy_forward_and_onnx_model(
 def get_equivalent_numpy_forward(
     onnx_model: onnx.ModelProto,
     check_model: bool = True,
-) -> Callable[..., Tuple[numpy.ndarray, ...]]:
+) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
     """Get the numpy equivalent forward of the provided ONNX model.
 
     Args:
@@ -87,14 +174,82 @@ def get_equivalent_numpy_forward(
     """
     if check_model:
         checker.check_model(onnx_model)
-    required_onnx_operators = set(get_op_type(node) for node in onnx_model.graph.node)
-    unsupported_operators = required_onnx_operators - IMPLEMENTED_ONNX_OPS
 
+    # Optimize the ONNX graph. The full list of optimizer passes currently supported
+    # by onnxoptimizer is documented in
+    # https://github.com/onnx/optimizer/blob/master/onnxoptimizer/pass_registry.h
+    onnx_passes = [
+        "fuse_matmul_add_bias_into_gemm",
+        "eliminate_nop_pad",
+        "fuse_pad_into_conv",
+        "extract_constant_to_initializer",
+        "eliminate_unused_initializer",
+    ]
+    equivalent_onnx_model = onnxoptimizer.optimize(onnx_model, onnx_passes)
+
+    # Custom optimization: the ONNX optimizer does not fuse the MatMul + Bias pattern
+    # into Gemm when the input is not a matrix, so we apply that fusion manually
+    equivalent_onnx_model = fuse_matmul_bias_to_gemm(equivalent_onnx_model)
+    checker.check_model(equivalent_onnx_model)
+
+    # Check supported operators
+    required_onnx_operators = set(get_op_type(node) for node in equivalent_onnx_model.graph.node)
+    unsupported_operators = required_onnx_operators - IMPLEMENTED_ONNX_OPS
     if len(unsupported_operators) > 0:
         raise ValueError(
             "The following ONNX operators are required to convert the torch model to numpy but are "
             f"not currently implemented: {', '.join(sorted(unsupported_operators))}.\n"
             f"Available ONNX operators: {', '.join(sorted(IMPLEMENTED_ONNX_OPS))}"
         )
 
-    return lambda *args: execute_onnx_with_numpy(onnx_model.graph, *args)
+    # Return a lambda that runs the numpy-equivalent execution of the ONNX graph
+    return (
+        lambda *args: execute_onnx_with_numpy(equivalent_onnx_model.graph, *args)
+    ), equivalent_onnx_model
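
As a sanity check, a minimal end-to-end sketch of the updated API (module, shapes, and tolerance are illustrative; per the new signatures, the callable and the optimized ONNX model are now returned together):

    import numpy
    import torch

    from concrete.ml.onnx.convert import get_equivalent_numpy_forward_and_onnx_model

    module = torch.nn.Linear(4, 2)
    dummy_input = torch.randn(1, 4)

    numpy_forward, onnx_model = get_equivalent_numpy_forward_and_onnx_model(module, dummy_input)

    # The numpy forward should match the torch forward on the same input
    torch_output = module(dummy_input).detach().numpy()
    numpy_output = numpy_forward(dummy_input.numpy())[0]
    assert numpy.allclose(torch_output, numpy_output, atol=1e-6)

    # After optimization, the Linear layer typically shows up as a single Gemm node
    print([node.op_type for node in onnx_model.graph.node])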
4 changes: 3 additions & 1 deletion src/concrete/ml/quantization/post_training.py
@@ -788,7 +788,9 @@ def _process_initializer(self, n_bits: int, values: numpy.ndarray):
             QuantizedArray: a quantized tensor with integer values on n_bits bits
         """
 
-        if isinstance(values, numpy.ndarray) and numpy.issubdtype(values.dtype, numpy.integer):
+        if isinstance(values, numpy.ndarray) and numpy.issubdtype(
+            values.dtype, numpy.integer
+        ):  # pragma:no cover
             return values.view(RawOpOutput)
 
         assert isinstance(values, (numpy.ndarray, float))
2 changes: 1 addition & 1 deletion src/concrete/ml/sklearn/tree_to_numpy.py
@@ -291,6 +291,6 @@ def tree_to_numpy(
     # but also rounding the threshold such that they are now integers
     q_y = tree_values_preprocessing(onnx_model, framework, output_n_bits)
 
-    _tree_inference = get_equivalent_numpy_forward(onnx_model)
+    _tree_inference, onnx_model = get_equivalent_numpy_forward(onnx_model)
 
     return (_tree_inference, [q_y.quantizer], onnx_model)