
Commit

chore: implement hybrid model demo with GPT-2
fd0r committed Sep 21, 2023
1 parent c7f4cf8 commit f1d1490
Showing 16 changed files with 930 additions and 117 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -168,3 +168,9 @@ execution_time_of_individual_pytest_files.txt
# Docs: Advance Examples data
docs/advanced_examples/data/


# Hybrid model artifacts
use_case_examples/hybrid_model/clients/
use_case_examples/hybrid_model/compiled_models/
use_case_examples/hybrid_model/keys/
use_case_examples/hybrid_model/user_keys/
23 changes: 13 additions & 10 deletions src/concrete/ml/deployment/fhe_client_server.py
@@ -54,11 +54,14 @@ def load(self):
versions = json.load(file)

errors = []
packages_to_check = {"concrete-python"}
packages_to_check = {"concrete-python", "concrete-ml"}
for package_name, package_version in versions.items():
if package_name not in packages_to_check:
continue
current_version = version(package_name)
if package_name == "concrete-ml":
current_version = CML_VERSION
else:
current_version = version(package_name)
if package_version != current_version: # pragma: no cover
errors.append((package_name, package_version, current_version))
if errors: # pragma: no cover
@@ -190,13 +193,10 @@ def save(self, via_mlir: bool = False):
# Add versions
versions_path = Path(self.path_dir).joinpath("versions.json")
versions = {
package_name: version(package_name)
for package_name in ["concrete-ml", "concrete-python"]
"concrete-python": version("concrete-python"),
"concrete-ml": CML_VERSION,
"python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
}
versions[
"python"
] = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"

with open(versions_path, "w", encoding="utf-8") as file:
json.dump(fp=file, obj=versions)

@@ -251,11 +251,14 @@ def load(self): # pylint: disable=no-value-for-parameter
versions = json.load(file)

errors = []
packages_to_check = {"concrete-python"}
packages_to_check = {"concrete-python", "concrete-ml"}
for package_name, package_version in versions.items():
if package_name not in packages_to_check:
continue
current_version = version(package_name)
if package_name == "concrete-ml":
current_version = CML_VERSION
else:
current_version = version(package_name)
if package_version != current_version: # pragma: no cover
errors.append((package_name, package_version, current_version))
if errors: # pragma: no cover
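Both `load` methods in this file now also check the concrete-ml version, comparing the value recorded at `save` time against the CML_VERSION constant instead of the installed distribution's metadata. A simplified, self-contained sketch of that check (check_versions is a hypothetical helper name and the "1.1.0" string is only a stand-in for the real constant):

import json
from importlib.metadata import version

CML_VERSION = "1.1.0"  # stand-in for the concrete-ml version constant used in the file

def check_versions(versions_path):
    """Return a (package, saved, current) tuple for every mismatched package."""
    with open(versions_path, "r", encoding="utf-8") as file:
        saved_versions = json.load(file)
    errors = []
    for package_name, saved_version in saved_versions.items():
        if package_name not in {"concrete-python", "concrete-ml"}:
            continue
        current = CML_VERSION if package_name == "concrete-ml" else version(package_name)
        if saved_version != current:
            errors.append((package_name, saved_version, current))
    return errors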
148 changes: 131 additions & 17 deletions src/concrete/ml/onnx/convert.py
@@ -2,23 +2,116 @@

import tempfile
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Tuple, Union

import numpy
import onnx
import onnxoptimizer
import torch
from onnx import checker
from onnx import checker, helper

from .onnx_utils import IMPLEMENTED_ONNX_OPS, execute_onnx_with_numpy, get_op_type

OPSET_VERSION_FOR_ONNX_EXPORT = 14


def get_equivalent_numpy_forward_and_onnx_model(
# pylint: disable=too-many-branches
def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
"""Fuse sequence of matmul -> add into a gemm node.
Args:
onnx_model (onnx.ModelProto): A onnx model to optimize using Mat-Mult + Add -> Gemm
Returns:
onnx.ModelProto: the optimized onnx model
"""
# Convert nodes to list to avoid modifying iterable during iteration
nodes_list = list(onnx_model.graph.node)

# Iterate through graph nodes
for matmul_node in nodes_list:
# Run only if the node is a MatMul node
if matmul_node.op_type != "MatMul":
continue
# Store MatMul node output name
matmul_node_output_name = matmul_node.output[0]
assert len(matmul_node.output) == 1

# Make sure that only one node uses the output of the mat-mult
mat_mult_output_use_node = []
for other_node in onnx_model.graph.node:
if other_node is matmul_node:
continue
if matmul_node_output_name in other_node.input:
mat_mult_output_use_node.append(other_node)
if len(mat_mult_output_use_node) != 1:
continue

# Check that following node is Add
add_node = mat_mult_output_use_node[0]
if add_node.op_type != "Add":
continue
assert len(add_node.output) == 1

# Find other Add input
bias_other_input_node_name = None
for input_name in add_node.input:
if input_name != matmul_node_output_name:
bias_other_input_node_name = input_name
assert bias_other_input_node_name is not None

# Only merge if the input of the add node is an initializer
# otherwise there might be some scaling issues
initializer_names = [elt.name for elt in onnx_model.graph.initializer]
if bias_other_input_node_name not in initializer_names:
continue

# Create a GEMM node which combines the MatMul and Add operations
gemm_node = helper.make_node(
"Gemm", # op_type
[matmul_node.input[0], matmul_node.input[1], bias_other_input_node_name], # inputs
[add_node.output[0]], # outputs
name="Gemm_Node",
alpha=1.0,
beta=1.0, # attributes
)
assert len(gemm_node.output) == 1

# Replace the MatMul and Add nodes with the GEMM node
# The graph needs to keep being topologically sorted
mat_mult_node_index = list(onnx_model.graph.node).index(matmul_node)
add_node_index = list(onnx_model.graph.node).index(add_node)
gemm_node_index = max(mat_mult_node_index, add_node_index)

onnx_model.graph.node.insert(gemm_node_index, gemm_node)
onnx_model.graph.node.remove(add_node)
onnx_model.graph.node.remove(matmul_node)

# Update connections in the graph
for potential_next_node in onnx_model.graph.node:
# Check if this node was connected to the add_node
if add_node.output[0] not in potential_next_node.input:
continue

# replace the reference to the old add_node output with the gemm_node output
for idx, potential_next_node_input in enumerate(potential_next_node.input):
if potential_next_node_input == add_node.output[0]:
potential_next_node.input[idx] = gemm_node.output[0]

# Update the model's output if necessary
for model_output in onnx_model.graph.output:
if model_output.name == add_node.output[0]:
model_output.name = gemm_node.output[0]

return onnx_model
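
A usage sketch (not part of this diff) that exercises the fusion on a hypothetical toy graph in which the bias is stored as an initializer:

import numpy
from onnx import TensorProto, helper, numpy_helper

x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2])
w = numpy_helper.from_array(numpy.ones((4, 2), dtype=numpy.float32), name="W")
b = numpy_helper.from_array(numpy.zeros((2,), dtype=numpy.float32), name="B")
matmul = helper.make_node("MatMul", ["X", "W"], ["XW"])
add = helper.make_node("Add", ["XW", "B"], ["Y"])
graph = helper.make_graph([matmul, add], "toy", [x], [y], initializer=[w, b])
toy_model = helper.make_model(graph)

# The MatMul + Add pair should be replaced by a single Gemm node
fused = fuse_matmul_bias_to_gemm(toy_model)
assert any(node.op_type == "Gemm" for node in fused.graph.node)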


def get_equivalent_numpy_forward_from_torch(
torch_module: torch.nn.Module,
dummy_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
output_onnx_file: Optional[Union[Path, str]] = None,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.GraphProto]:
output_onnx_file: Union[None, Path, str] = None,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
"""Get the numpy equivalent forward of the provided torch Module.
Args:
@@ -34,41 +127,39 @@ def get_equivalent_numpy_forward_and_onnx_model(
execute the equivalent numpy code to the passed torch_module and the generated ONNX
model.
"""

output_onnx_file_path = Path(
tempfile.mkstemp(suffix=".onnx")[1] if output_onnx_file is None else output_onnx_file
)
use_tempfile: bool = output_onnx_file is None

# Export to ONNX
torch.onnx.export(
torch_module,
dummy_input,
str(output_onnx_file_path),
opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
)
equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path))

checker.check_model(equivalent_onnx_model)

# Remove the tempfile if we used one
if use_tempfile:
output_onnx_file_path.unlink()

# The model was checked just above
equivalent_numpy_forward = get_equivalent_numpy_forward(
equivalent_onnx_model, check_model=False
equivalent_numpy_forward, equivalent_onnx_model = get_equivalent_numpy_forward_from_onnx(
equivalent_onnx_model, check_model=True
)
with output_onnx_file_path.open("wb") as file:
file.write(equivalent_onnx_model.SerializeToString())

return (
equivalent_numpy_forward,
equivalent_onnx_model,
)


def get_equivalent_numpy_forward(
def get_equivalent_numpy_forward_from_onnx(
onnx_model: onnx.ModelProto,
check_model: bool = True,
) -> Callable[..., Tuple[numpy.ndarray, ...]]:
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
"""Get the numpy equivalent forward of the provided ONNX model.
Args:
@@ -87,14 +178,37 @@ def get_equivalent_numpy_forward(
"""
if check_model:
checker.check_model(onnx_model)
required_onnx_operators = set(get_op_type(node) for node in onnx_model.graph.node)
unsupported_operators = required_onnx_operators - IMPLEMENTED_ONNX_OPS
checker.check_model(onnx_model)

# Optimize ONNX graph
# List of all currently supported onnx optimizer passes
# From https://github.com/onnx/optimizer/blob/master/onnxoptimizer/pass_registry.h
onnx_passes = [
"fuse_matmul_add_bias_into_gemm",
"eliminate_nop_pad",
"fuse_pad_into_conv",
"extract_constant_to_initializer",
"eliminate_unused_initializer",
]
equivalent_onnx_model = onnxoptimizer.optimize(onnx_model, onnx_passes)
checker.check_model(equivalent_onnx_model)
# Custom optimization
# ONNX optimizer does not optimize Mat-Mult + Bias pattern into GEMM if the input isn't a matrix
# We manually do the optimization for this case
equivalent_onnx_model = fuse_matmul_bias_to_gemm(equivalent_onnx_model)
checker.check_model(equivalent_onnx_model)

# Check supported operators
required_onnx_operators = set(get_op_type(node) for node in equivalent_onnx_model.graph.node)
unsupported_operators = required_onnx_operators - IMPLEMENTED_ONNX_OPS
if len(unsupported_operators) > 0:
raise ValueError(
"The following ONNX operators are required to convert the torch model to numpy but are "
f"not currently implemented: {', '.join(sorted(unsupported_operators))}.\n"
f"Available ONNX operators: {', '.join(sorted(IMPLEMENTED_ONNX_OPS))}"
)

return lambda *args: execute_onnx_with_numpy(onnx_model.graph, *args)
# Return lambda of numpy equivalent of onnx execution
return (
lambda *args: execute_onnx_with_numpy(equivalent_onnx_model.graph, *args)
), equivalent_onnx_model
4 changes: 2 additions & 2 deletions src/concrete/ml/sklearn/tree_to_numpy.py
@@ -12,7 +12,7 @@
get_onnx_opset_version,
is_regressor_or_partial_regressor,
)
from ..onnx.convert import OPSET_VERSION_FOR_ONNX_EXPORT, get_equivalent_numpy_forward
from ..onnx.convert import OPSET_VERSION_FOR_ONNX_EXPORT, get_equivalent_numpy_forward_from_onnx
from ..onnx.onnx_model_manipulations import clean_graph_at_node_op_type, remove_node_types
from ..quantization import QuantizedArray
from ..quantization.quantizers import UniformQuantizer
@@ -291,6 +291,6 @@ def tree_to_numpy(
# but also rounding the threshold such that they are now integers
q_y = tree_values_preprocessing(onnx_model, framework, output_n_bits)

_tree_inference = get_equivalent_numpy_forward(onnx_model)
_tree_inference, onnx_model = get_equivalent_numpy_forward_from_onnx(onnx_model)

return (_tree_inference, [q_y.quantizer], onnx_model)
18 changes: 10 additions & 8 deletions src/concrete/ml/torch/compile.py
@@ -3,7 +3,7 @@
import tempfile
import warnings
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import numpy
import onnx
@@ -55,7 +55,7 @@ def build_quantized_module(
model: Union[torch.nn.Module, onnx.ModelProto],
torch_inputset: Dataset,
import_qat: bool = False,
n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
) -> QuantizedModule:
"""Build a quantized module from a Torch or ONNX model.
@@ -113,7 +113,7 @@ def _compile_torch_or_onnx_model(
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
@@ -272,7 +272,7 @@ def compile_onnx_model(
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_bits: Union[int, Dict] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
@@ -340,7 +340,7 @@ def compile_brevitas_qat_model(
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
output_onnx_file: Union[Path, str] = None,
output_onnx_file: Union[None, Path, str] = None,
verbose: bool = False,
) -> QuantizedModule:
"""Compile a Brevitas Quantization Aware Training model.
@@ -378,7 +378,6 @@
Returns:
QuantizedModule: The resulting compiled QuantizedModule.
"""

inputset_as_numpy_tuple = tuple(
convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
)
@@ -418,8 +417,11 @@
# https://github.com/onnx/optimizer/blob/master/onnxoptimizer/pass_registry.h
# In the export function, the `args` parameter is used instead of the `input_shape` one in
# order to be able to handle multi-inputs models
exporter.onnx_passes.append("eliminate_nop_pad")
exporter.onnx_passes.append("fuse_pad_into_conv")
exporter.onnx_passes += [
"eliminate_nop_pad",
"fuse_pad_into_conv",
"fuse_matmul_add_bias_into_gemm",
]
onnx_model = exporter.export(
torch_model,
args=dummy_input_for_tracing,
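
With `n_bits` now typed as `Union[int, Dict[str, int]]`, callers can pass either a single bit-width or a per-category mapping. A sketch of both forms (the dictionary keys shown are assumptions for illustration; the accepted names are defined by Concrete ML's quantization options):

import torch
from concrete.ml.torch.compile import compile_torch_model

torch_model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU())
inputset = torch.randn(100, 4)

# Single bit-width applied everywhere
quantized_module = compile_torch_model(torch_model, inputset, n_bits=4)

# Per-category bit-widths; "op_inputs" and "op_weights" are illustrative key names
quantized_module = compile_torch_model(
    torch_model, inputset, n_bits={"op_inputs": 4, "op_weights": 4}
)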