From a3a297298ea5b13cfe7264faf201561415a2efee Mon Sep 17 00:00:00 2001 From: Luis Montero Date: Wed, 2 Aug 2023 11:04:03 +0200 Subject: [PATCH] chore: implement hybrid model demo with GPT-2 --- .gitignore | 6 + .../ml/deployment/fhe_client_server.py | 23 +- src/concrete/ml/onnx/convert.py | 185 ++++++++++- src/concrete/ml/quantization/post_training.py | 4 +- src/concrete/ml/sklearn/tree_to_numpy.py | 2 +- src/concrete/ml/torch/compile.py | 23 +- src/concrete/ml/torch/hybrid_model.py | 314 +++++++++++++++--- src/concrete/ml/torch/numpy_module.py | 3 +- tests/torch/test_compile_torch.py | 3 +- tests/torch/test_hybrid_converter.py | 3 +- use_case_examples/hybrid_model/README.md | 14 + use_case_examples/hybrid_model/compile.sh | 6 + .../hybrid_model/compile_hybrid_llm.py | 140 ++++++++ .../hybrid_model/infer_hybrid_llm_generate.py | 89 +++++ .../hybrid_model/load_and_analyze_data.py | 32 ++ .../hybrid_model/requirements.txt | 3 + use_case_examples/hybrid_model/serve.sh | 13 + use_case_examples/hybrid_model/serve_model.py | 193 +++++++++++ use_case_examples/llm/QGPT2Evaluate.ipynb | 2 +- 19 files changed, 959 insertions(+), 99 deletions(-) create mode 100644 use_case_examples/hybrid_model/README.md create mode 100644 use_case_examples/hybrid_model/compile.sh create mode 100644 use_case_examples/hybrid_model/compile_hybrid_llm.py create mode 100644 use_case_examples/hybrid_model/infer_hybrid_llm_generate.py create mode 100644 use_case_examples/hybrid_model/load_and_analyze_data.py create mode 100644 use_case_examples/hybrid_model/requirements.txt create mode 100644 use_case_examples/hybrid_model/serve.sh create mode 100644 use_case_examples/hybrid_model/serve_model.py diff --git a/.gitignore b/.gitignore index 8575460e70..e9b69af5e9 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,9 @@ execution_time_of_individual_pytest_files.txt # Docs: Advance Examples data docs/advanced_examples/data/ + +# Hybrid model artifacts +use_case_examples/hybrid_model/clients/ +use_case_examples/hybrid_model/compiled_models/ +use_case_examples/hybrid_model/keys/ +use_case_examples/hybrid_model/user_keys/ \ No newline at end of file diff --git a/src/concrete/ml/deployment/fhe_client_server.py b/src/concrete/ml/deployment/fhe_client_server.py index 8feb481dd3..48ac74598f 100644 --- a/src/concrete/ml/deployment/fhe_client_server.py +++ b/src/concrete/ml/deployment/fhe_client_server.py @@ -54,11 +54,14 @@ def load(self): versions = json.load(file) errors = [] - packages_to_check = {"concrete-python"} + packages_to_check = {"concrete-python", "concrete-ml"} for package_name, package_version in versions.items(): if package_name not in packages_to_check: continue - current_version = version(package_name) + if package_name == "concrete-ml": + current_version = CML_VERSION + else: + current_version = version(package_name) if package_version != current_version: # pragma: no cover errors.append((package_name, package_version, current_version)) if errors: # pragma: no cover @@ -190,13 +193,10 @@ def save(self, via_mlir: bool = False): # Add versions versions_path = Path(self.path_dir).joinpath("versions.json") versions = { - package_name: version(package_name) - for package_name in ["concrete-ml", "concrete-python"] + "concrete-python": version("concrete-python"), + "concrete-ml": CML_VERSION, + "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", } - versions[ - "python" - ] = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" - with open(versions_path, "w", encoding="utf-8") as file: json.dump(fp=file, obj=versions) @@ -251,11 +251,14 @@ def load(self): # pylint: disable=no-value-for-parameter versions = json.load(file) errors = [] - packages_to_check = {"concrete-python"} + packages_to_check = {"concrete-python", "concrete-ml"} for package_name, package_version in versions.items(): if package_name not in packages_to_check: continue - current_version = version(package_name) + if package_name == "concrete-ml": + current_version = CML_VERSION + else: + current_version = version(package_name) if package_version != current_version: # pragma: no cover errors.append((package_name, package_version, current_version)) if errors: # pragma: no cover diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py index d1e2fcd0e8..22b90edacd 100644 --- a/src/concrete/ml/onnx/convert.py +++ b/src/concrete/ml/onnx/convert.py @@ -2,23 +2,112 @@ import tempfile from pathlib import Path -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Tuple, Union import numpy import onnx +import onnxoptimizer import torch -from onnx import checker +from onnx import checker, helper from .onnx_utils import IMPLEMENTED_ONNX_OPS, execute_onnx_with_numpy, get_op_type OPSET_VERSION_FOR_ONNX_EXPORT = 14 +# pylint: disable=too-many-nested-blocks +def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto): + """Fuse sequence of matmul -> add into a gemm node. + + Args: + onnx_model (onnx.ModelProto): A onnx model to optimize using Mat-Mult + Add -> Gemm + + Returns: + onnx.ModelProto: the optimized onnx model + + """ + # Convert nodes to list to avoid modifying iterable during iteration + nodes_list = list(onnx_model.graph.node) + + # Iterate through graph nodes + for node in nodes_list: + # Check if node is a MatMul node + if node.op_type == "MatMul": + # Store MatMul node output name + matmul_node_output_name = node.output[0] + + bias_node_name = None + add_node = None # store the reference to the add_node + + # Find the Add node which adds the bias + for potential_add_node in nodes_list: + if ( + potential_add_node.op_type == "Add" + and matmul_node_output_name in potential_add_node.input + ): + # Store Add node + add_node = potential_add_node + # Find the input node to the Add node which is not the MatMul node + for input_name in add_node.input: + if input_name != matmul_node_output_name: + bias_node_name = input_name + + # If no bias_node_name has been assigned, continue to the next node + if not bias_node_name: + continue + assert bias_node_name is not None + assert add_node is not None + + # Create a GEMM node which combines the MatMul and Add operations + gemm_node = helper.make_node( + "Gemm", # op_type + [node.input[0], node.input[1], bias_node_name], # inputs + [add_node.output[0]], # outputs + name="Gemm_Node", + alpha=1.0, + beta=1.0, # attributes + ) + # gemm_node.attribute = [ + # onnx.AttributeProto(f=1.0, name="alpha", type=onnx.AttributeProto.FLOAT), + # onnx.AttributeProto(f=1.0, name="beta", type=onnx.AttributeProto.FLOAT), + # onnx.AttributeProto(f=1, name="transB", type=onnx.AttributeProto.INT), + # ] + + # Replace the MatMul and Add nodes with the GEMM node + mat_mult_node_index = nodes_list.index(node) + add_node_index = nodes_list.index(add_node) + gemm_node_index = max(mat_mult_node_index, add_node_index) + + onnx_model.graph.node.remove(node) + onnx_model.graph.node.remove(add_node) + onnx_model.graph.node.insert(gemm_node_index, gemm_node) + + # Update connections in the graph + for potential_next_node in onnx_model.graph.node: + # check if this node was connected to the add_node + if add_node.output[0] in potential_next_node.input: + # replace the reference to the old add_node output with the gemm_node output + idx = list(potential_next_node.input).index(add_node.output[0]) + potential_next_node.input[idx] = gemm_node.output[0] + + # Update the model's output if necessary + for model_output in onnx_model.graph.output: + if model_output.name == add_node.output[0]: + model_output.name = gemm_node.output[0] + + # Update inputs and initializers + # onnx_model.graph.input.append( + # helper.make_tensor_value_info(bias_node_name, onnx.TensorProto.FLOAT, [1]) + # ) + + return onnx_model + + def get_equivalent_numpy_forward_and_onnx_model( torch_module: torch.nn.Module, dummy_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]], - output_onnx_file: Optional[Union[Path, str]] = None, -) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.GraphProto]: + output_onnx_file: Union[None, Path, str] = None, +) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]: """Get the numpy equivalent forward of the provided torch Module. Args: @@ -34,12 +123,12 @@ def get_equivalent_numpy_forward_and_onnx_model( execute the equivalent numpy code to the passed torch_module and the generated ONNX model. """ - output_onnx_file_path = Path( tempfile.mkstemp(suffix=".onnx")[1] if output_onnx_file is None else output_onnx_file ) use_tempfile: bool = output_onnx_file is None + # Export to ONNX torch.onnx.export( torch_module, dummy_input, @@ -47,17 +136,15 @@ def get_equivalent_numpy_forward_and_onnx_model( opset_version=OPSET_VERSION_FOR_ONNX_EXPORT, ) equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path)) - - checker.check_model(equivalent_onnx_model) - # Remove the tempfile if we used one if use_tempfile: output_onnx_file_path.unlink() - # The model was checked just above - equivalent_numpy_forward = get_equivalent_numpy_forward( - equivalent_onnx_model, check_model=False + equivalent_numpy_forward, equivalent_onnx_model = get_equivalent_numpy_forward( + equivalent_onnx_model, check_model=True ) + with output_onnx_file_path.open("wb") as file: + file.write(equivalent_onnx_model.SerializeToString()) return ( equivalent_numpy_forward, @@ -68,7 +155,7 @@ def get_equivalent_numpy_forward_and_onnx_model( def get_equivalent_numpy_forward( onnx_model: onnx.ModelProto, check_model: bool = True, -) -> Callable[..., Tuple[numpy.ndarray, ...]]: +) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]: """Get the numpy equivalent forward of the provided ONNX model. Args: @@ -87,9 +174,74 @@ def get_equivalent_numpy_forward( """ if check_model: checker.check_model(onnx_model) - required_onnx_operators = set(get_op_type(node) for node in onnx_model.graph.node) - unsupported_operators = required_onnx_operators - IMPLEMENTED_ONNX_OPS + # Optimize ONNX graph + # List of all currently supported onnx optimizer passes + # From https://github.com/onnx/optimizer/blob/master/onnxoptimizer/pass_registry.h + # onnx_passes = [ + # 'adjust_add', + # 'rename_input_output', + # 'set_unique_name_for_nodes', + # 'nop', + # 'eliminate_nop_cast', + # 'eliminate_nop_dropout', + # 'eliminate_nop_flatten', + # 'extract_constant_to_initializer', + # 'eliminate_if_with_const_cond', + # 'eliminate_nop_monotone_argmax', + # 'eliminate_nop_pad', + # 'eliminate_nop_concat', + # 'eliminate_nop_split', + # 'eliminate_nop_expand', + # 'eliminate_shape_gather', + # 'eliminate_slice_after_shape', + # 'eliminate_nop_transpose', + # 'fuse_add_bias_into_conv', + # 'fuse_bn_into_conv', + # 'fuse_consecutive_concats', + # 'fuse_consecutive_log_softmax', + # 'fuse_consecutive_reduce_unsqueeze', + # 'fuse_consecutive_squeezes', + # 'fuse_consecutive_transposes', + # 'fuse_matmul_add_bias_into_gemm', + # 'fuse_pad_into_conv', + # 'fuse_pad_into_pool', + # 'fuse_transpose_into_gemm', + # 'replace_einsum_with_matmul', + # 'lift_lexical_references', + # 'split_init', + # 'split_predict', + # 'fuse_concat_into_reshape', + # 'eliminate_nop_reshape', + # 'eliminate_nop_with_unit', + # 'eliminate_common_subexpression', + # 'fuse_qkv', + # 'fuse_consecutive_unsqueezes', + # 'eliminate_deadend', + # 'eliminate_identity', + # 'eliminate_shape_op', + # 'fuse_consecutive_slices', + # 'eliminate_unused_initializer', + # 'eliminate_duplicate_initializer', + # 'adjust_slice_and_matmul' + # ] + onnx_passes = [ + "fuse_matmul_add_bias_into_gemm", + "eliminate_nop_pad", + "fuse_pad_into_conv", + "extract_constant_to_initializer", + "eliminate_unused_initializer", + ] + equivalent_onnx_model = onnxoptimizer.optimize(onnx_model, onnx_passes) + # Custom optimization + # ONNX optimizer does not optimize Mat-Mult + Bias pattern into GEMM if the input isn't a matrix + # We manually do the optimization for this case + equivalent_onnx_model = fuse_matmul_bias_to_gemm(equivalent_onnx_model) + checker.check_model(equivalent_onnx_model) + + # Check supported operators + required_onnx_operators = set(get_op_type(node) for node in equivalent_onnx_model.graph.node) + unsupported_operators = required_onnx_operators - IMPLEMENTED_ONNX_OPS if len(unsupported_operators) > 0: raise ValueError( "The following ONNX operators are required to convert the torch model to numpy but are " @@ -97,4 +249,7 @@ def get_equivalent_numpy_forward( f"Available ONNX operators: {', '.join(sorted(IMPLEMENTED_ONNX_OPS))}" ) - return lambda *args: execute_onnx_with_numpy(onnx_model.graph, *args) + # Return lambda of numpy equivalent of onnx execution + return ( + lambda *args: execute_onnx_with_numpy(equivalent_onnx_model.graph, *args) + ), equivalent_onnx_model diff --git a/src/concrete/ml/quantization/post_training.py b/src/concrete/ml/quantization/post_training.py index fcbe6abd5c..c32bec5b1a 100644 --- a/src/concrete/ml/quantization/post_training.py +++ b/src/concrete/ml/quantization/post_training.py @@ -788,7 +788,9 @@ def _process_initializer(self, n_bits: int, values: numpy.ndarray): QuantizedArray: a quantized tensor with integer values on n_bits bits """ - if isinstance(values, numpy.ndarray) and numpy.issubdtype(values.dtype, numpy.integer): + if isinstance(values, numpy.ndarray) and numpy.issubdtype( + values.dtype, numpy.integer + ): # pragma:no cover return values.view(RawOpOutput) assert isinstance(values, (numpy.ndarray, float)) diff --git a/src/concrete/ml/sklearn/tree_to_numpy.py b/src/concrete/ml/sklearn/tree_to_numpy.py index 45edb91f6c..807aad7887 100644 --- a/src/concrete/ml/sklearn/tree_to_numpy.py +++ b/src/concrete/ml/sklearn/tree_to_numpy.py @@ -291,6 +291,6 @@ def tree_to_numpy( # but also rounding the threshold such that they are now integers q_y = tree_values_preprocessing(onnx_model, framework, output_n_bits) - _tree_inference = get_equivalent_numpy_forward(onnx_model) + _tree_inference, onnx_model = get_equivalent_numpy_forward(onnx_model) return (_tree_inference, [q_y.quantizer], onnx_model) diff --git a/src/concrete/ml/torch/compile.py b/src/concrete/ml/torch/compile.py index 0457d8d4ba..9e83e24d97 100644 --- a/src/concrete/ml/torch/compile.py +++ b/src/concrete/ml/torch/compile.py @@ -3,7 +3,7 @@ import tempfile import warnings from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import numpy import onnx @@ -55,7 +55,7 @@ def build_quantized_module( model: Union[torch.nn.Module, onnx.ModelProto], torch_inputset: Dataset, import_qat: bool = False, - n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE, + n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE, rounding_threshold_bits: Optional[int] = None, ) -> QuantizedModule: """Build a quantized module from a Torch or ONNX model. @@ -81,10 +81,7 @@ def build_quantized_module( convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset) ) - # Tracing needs to be done with the batch size of 1 since we compile our models to FHE with - # this batch size. The input set contains many examples, to determine a representative - # bit-width, but for tracing we only take a single one. We need the ONNX tracing batch size to - # match the batch size during FHE inference which can only be 1 for the moment. + # No batch dimension (i.e., 0 instead of [0]) because else GEMM onnx pass can't be applied dummy_input_for_tracing = tuple( torch.from_numpy(val[[0], ::]).float() for val in inputset_as_numpy_tuple ) @@ -113,7 +110,7 @@ def _compile_torch_or_onnx_model( configuration: Optional[Configuration] = None, artifacts: Optional[DebugArtifacts] = None, show_mlir: bool = False, - n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE, + n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE, rounding_threshold_bits: Optional[int] = None, p_error: Optional[float] = None, global_p_error: Optional[float] = None, @@ -272,7 +269,7 @@ def compile_onnx_model( configuration: Optional[Configuration] = None, artifacts: Optional[DebugArtifacts] = None, show_mlir: bool = False, - n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE, + n_bits: Union[int, Dict] = MAX_BITWIDTH_BACKWARD_COMPATIBLE, rounding_threshold_bits: Optional[int] = None, p_error: Optional[float] = None, global_p_error: Optional[float] = None, @@ -340,7 +337,7 @@ def compile_brevitas_qat_model( rounding_threshold_bits: Optional[int] = None, p_error: Optional[float] = None, global_p_error: Optional[float] = None, - output_onnx_file: Union[Path, str] = None, + output_onnx_file: Union[None, Path, str] = None, verbose: bool = False, ) -> QuantizedModule: """Compile a Brevitas Quantization Aware Training model. @@ -378,7 +375,6 @@ def compile_brevitas_qat_model( Returns: QuantizedModule: The resulting compiled QuantizedModule. """ - inputset_as_numpy_tuple = tuple( convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset) ) @@ -418,8 +414,11 @@ def compile_brevitas_qat_model( # https://github.com/onnx/optimizer/blob/master/onnxoptimizer/pass_registry.h # In the export function, the `args` parameter is used instead of the `input_shape` one in # order to be able to handle multi-inputs models - exporter.onnx_passes.append("eliminate_nop_pad") - exporter.onnx_passes.append("fuse_pad_into_conv") + exporter.onnx_passes += [ + "eliminate_nop_pad", + "fuse_pad_into_conv", + "fuse_matmul_add_bias_into_gemm", + ] onnx_model = exporter.export( torch_model, args=dummy_input_for_tracing, diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py index d6c473ca55..861015f3ed 100644 --- a/src/concrete/ml/torch/hybrid_model.py +++ b/src/concrete/ml/torch/hybrid_model.py @@ -1,16 +1,44 @@ """Implement the conversion of a torch model to a hybrid fhe/torch inference.""" +import ast +import enum +import io +import sys +import time import uuid from pathlib import Path -from typing import List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union +import numpy +import requests import torch from concrete.fhe import Configuration from torch import nn from transformers import Conv1D from ..deployment.fhe_client_server import FHEModelClient, FHEModelDev -from .compile import compile_torch_model +from .compile import QuantizedModule, compile_torch_model + + +class FHEMode(enum.Enum): + """Simple enum for different modes of execution of HybridModel.""" + + DISABLE = "disable" # Use torch weights + REMOTE = "remote" # Use remote FHE server + SIMULATE = "simulate" # Use FHE simulation + CALIBRATE = "calibrate" # Use calibration (to run before FHE compilation) + + +def tuple_to_underscore_str(tup: Tuple) -> str: + """Convert a tuple to a string representation. + + Args: + tup (Tuple): a tuple to change into string representation + + Returns: + str: a string representing the tuple + """ + return repr(tup).replace("(", "p_").replace(")", "_p").replace(", ", "_") # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3858 @@ -53,35 +81,107 @@ def convert_conv1d_to_linear(layer_or_module): return layer_or_module +# pylint: disable-next=too-many-instance-attributes class RemoteModule(nn.Module): """A wrapper class for the modules to be done remotely with FHE.""" def __init__( self, - module=None, - server_remote_address=None, + module: Optional[nn.Module] = None, + server_remote_address: Optional[str] = None, + module_name: Optional[str] = None, + model_name: Optional[str] = None, + verbose: int = 0, ): super().__init__() - self.private_module = module - self.server_remote_address = server_remote_address - self.calibration_data = [] + self.private_module: Optional[nn.Module] = module + self.server_remote_address: Optional[str] = server_remote_address + self.calibration_data: List = [] self.uid = str(uuid.uuid4()) - self.private_q_module = None - self.fhe_local_mode = "disable" - self.client: Optional[FHEModelClient] = None - self.path_to_keys = None - self.path_to_client = None - - def init_fhe_client(self, path_to_client: str, path_to_keys: str): + self.private_q_module: Optional[QuantizedModule] = None + # TODO: figure out if this is good + self.fhe_local_mode: FHEMode = FHEMode.CALIBRATE + self.clients: Dict[str, Tuple[str, FHEModelClient]] = {} + self.path_to_keys: Optional[Path] = None + self.path_to_clients: Optional[Path] = None + self.module_name: Optional[str] = module_name + self.model_name: Optional[str] = model_name + self.verbose = verbose + + def init_fhe_client( + self, path_to_client: Optional[Path] = None, path_to_keys: Optional[Path] = None + ): # pragma:no cover """Set the clients keys. Args: path_to_client (str): Path where the client.zip is located. path_to_keys (str): Path where keys are located. + + Raises: + ValueError: if anything goes wrong with the server. """ - # TODO: here we need to load fhe client.zip with FHEModelClient. - # Either by getting it from the server with the self.uid or - # directly getting it when downloading the model from HF. + self.path_to_clients = path_to_client + if self.path_to_clients is None: + self.path_to_clients = Path() / "clients" + self.path_to_clients.mkdir(exist_ok=True) + self.path_to_keys = path_to_keys + if self.path_to_keys is None: + self.path_to_keys = Path() / "keys" + self.path_to_keys.mkdir(exist_ok=True) + + assert self.module_name is not None + shapes_response = requests.get( + f"{self.server_remote_address}/list_shapes", + data={"module_name": self.module_name, "model_name": self.model_name}, + ) + if shapes_response.status_code != 200: + # Add link to request content + raise ValueError( + f"Couldn't get shapes from server:\n{shapes_response.content.decode('utf-8')}" + ) + shapes = shapes_response.json() + for shape in shapes: + client_response = requests.get( + f"{self.server_remote_address}/get_client", + data={ + "module_name": self.module_name, + "model_name": self.model_name, + "input_shape": shape, + }, + ) + if client_response.status_code != 200: + # Add link to request content + raise ValueError( + f"Couldn't get client from server:\n{client_response.content.decode('utf-8')}" + ) + path_to_client = self.path_to_clients / tuple_to_underscore_str(ast.literal_eval(shape)) + path_to_client.mkdir(exist_ok=True) + with open(path_to_client / "client.zip", "wb") as file: + file.write(client_response.content) + # Create the client + client = FHEModelClient( + path_dir=str(path_to_client.resolve()), key_dir=str(self.path_to_keys.resolve()) + ) + # The client first need to create the private and evaluation keys. + client.generate_private_and_evaluation_keys() + # Get the serialized evaluation keys + serialized_evaluation_keys = client.get_serialized_evaluation_keys() + if self.verbose: + print(f"Evaluation keys size: {len(serialized_evaluation_keys) / (10**6):.2f} MB") + assert isinstance(serialized_evaluation_keys, bytes) + assert self.module_name is not None + response = requests.post( + f"{self.server_remote_address}/add_key", + data={ + "module_name": self.module_name, + "model_name": self.model_name, + "input_shape": shape, + }, + files={"key": io.BytesIO(initial_bytes=serialized_evaluation_keys)}, + ) + assert response.status_code == 200, response.content.decode("utf-8") + uid = response.json()["uid"] + self.clients[shape] = (uid, client) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass of the remote module. @@ -91,30 +191,97 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: (torch.Tensor): The output tensor. + + Raises: + ValueError: if fhe_mode is not supported """ - if self.fhe_local_mode != "disable": - # for mypy - assert self.private_module is not None + # - disable: torch module + # - remote: client-server + # - simulate: compiled simulation + # - calibrate: calibration + + if self.fhe_local_mode not in {FHEMode.DISABLE, FHEMode.CALIBRATE, FHEMode.REMOTE, None}: + # Using quantized module assert self.private_q_module is not None - y = self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode) - y = torch.Tensor(y) - elif self.private_module is not None: - if isinstance(x, torch.Tensor): - self.calibration_data.append(x.detach()) + y = torch.Tensor( + self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value) + ) + elif self.fhe_local_mode == FHEMode.DISABLE: + # Calling torch + assert self.private_module is not None + y = self.private_module.forward( + x.detach(), + ) + assert isinstance(y, torch.Tensor) + elif self.fhe_local_mode == FHEMode.CALIBRATE: + # Calling torch + gathering calibration data + assert self.private_module is not None + self.calibration_data.append(x.detach()) y = self.private_module(x) + assert isinstance(y, torch.Tensor) # TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3869 - # else: - # y = self.remote_call(x) + elif self.fhe_local_mode == FHEMode.REMOTE: # pragma:no cover + # Remote call + y = self.remote_call(x) + else: # pragma:no cover + # Shouldn't happen + raise ValueError(f"{self.fhe_local_mode} is not recognized") return y - def remote_call(self, x: torch.Tensor): + def remote_call(self, x: torch.Tensor) -> torch.Tensor: # pragma:no cover """Call the remote server to get the private module inference. Args: x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The result of the FHE computation """ # TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3869 # implement server call and client initialization + base_device = x.device + x = x.to(device="cpu") + inferences = [] + for index in range(len(x)): + clear_input = x[[index], :].detach().numpy() + input_shape = tuple(clear_input.shape) + repr_input_shape = str(input_shape[1:]) + assert isinstance(clear_input, numpy.ndarray) + assert repr_input_shape in self.clients + key_id, client = self.clients[repr_input_shape] + assert client is not None + encrypted_input = client.quantize_encrypt_serialize(clear_input) + assert isinstance(encrypted_input, bytes) + if self.verbose: + print( + f"Encrypted input size: {sys.getsizeof(encrypted_input) / 1024 / 1024:.2f} MB" + ) + start = time.time() + assert self.module_name is not None + if self.verbose: + print("Infering ...") + inference_query = requests.post( + f"{self.server_remote_address}/compute", + files={ + "model_input": io.BytesIO(encrypted_input), + }, + data={ + "uid": key_id, + "module_name": self.module_name, + "model_name": self.model_name, + "input_shape": repr_input_shape, + }, + stream=True, + ) + end = time.time() + if self.verbose: + print(f"Inference done in {end - start} seconds") + # Unpack the results + assert inference_query.status_code == 200, inference_query.content.decode("utf-8") + encrypted_result = inference_query.content + decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_result)[0] + inferences.append(decrypted_prediction) + return torch.Tensor(numpy.array(inferences)).to(device=base_device) class HybridFHEModel: @@ -125,33 +292,45 @@ def __init__( model: nn.Module, module_names: Union[str, List[str]], server_remote_address=None, + model_name: str = "model", + verbose: int = 0, ): self.model = model self.module_names = [module_names] if isinstance(module_names, str) else module_names self.server_remote_address = server_remote_address - self.private_modules = { + self.private_modules: Dict[str, nn.Module] = { name: self._get_module_by_name(self.model, name) for name in self.module_names } - self.remote_modules: dict = {} + self.remote_modules: Dict[str, RemoteModule] = {} self.private_q_modules: dict = {} - self.configuration: Configuration = None + self.configuration: Optional[Configuration] = None + self.model_name = model_name + self.verbose = verbose self._replace_modules() def _replace_modules(self): """Replace the private modules in the model with remote layers.""" - for name in self.module_names: + for module_name in self.module_names: # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3858 # Conv1d introduce reshaping operations which adds more TLU - self.private_modules[name] = convert_conv1d_to_linear(self.private_modules[name]) + self.private_modules[module_name] = convert_conv1d_to_linear( + self.private_modules[module_name] + ) - remote_module = RemoteModule(self.private_modules[name], self.server_remote_address) + remote_module = RemoteModule( + module=self.private_modules[module_name], + server_remote_address=self.server_remote_address, + module_name=module_name, + model_name=self.model_name, + verbose=self.verbose, + ) - self.remote_modules[name] = remote_module + self.remote_modules[module_name] = remote_module # Now we need to replace the module in its parent module. - *path, last = name.split(".") + *path, last = module_name.split(".") parent_module = ( self._get_module_by_name(self.model, ".".join(path)) if path else self.model ) @@ -167,10 +346,9 @@ def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor: Returns: (torch.Tensor): The output tensor. """ - # Set the fhe mode in each remote module for module in self.remote_modules.values(): - module.fhe_local_mode = fhe + module.fhe_local_mode = FHEMode(fhe) x = self.model(x) return x @@ -188,28 +366,36 @@ def _get_module_by_name(model: nn.Module, name: str) -> Union[RemoteModule, nn.M Raises: ValueError: If no module found for the given name. """ + # TODO: Shouldn't this search recursively in name modules of name modules? for module_name, module in model.named_modules(): if module_name == name: return module raise ValueError(f"No module found for name {name}") - def init_client(self, path_to_client: str, path_to_keys: str): + def init_client( + self, path_to_clients: Optional[Path] = None, path_to_keys: Optional[Path] = None + ): # pragma:no cover """Initialize client for all remote modules. Args: - path_to_client (str): Path to the client.zip files. - path_to_keys (str): Path to the keys folder. + path_to_clients (Optional[Path]): Path to the client.zip files. + path_to_keys (Optional[Path]): Path to the keys folder. """ - # TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3869 - # implement client initialization + if path_to_clients is None: + path_to_clients = Path("clients") + path_to_clients.mkdir(exist_ok=True) + for module_name, module in self.remote_modules.items(): + path_to_client = path_to_clients / module_name + path_to_client.mkdir(exist_ok=True) + module.init_fhe_client(path_to_client=path_to_client, path_to_keys=path_to_keys) def compile_model( self, x: torch.Tensor, n_bits: int = 8, - rounding_threshold_bits: int = 8, - p_error=0.01, - configuration: Configuration = None, + rounding_threshold_bits: Optional[int] = 8, + p_error: float = 0.01, + configuration: Optional[Configuration] = None, ): """Compiles the specific layers to FHE. @@ -224,7 +410,12 @@ def compile_model( configuration (Configuration): A concrete Configuration object specifying the FHE encryption parameters. If not specified, a default configuration is used. """ + # We do a forward pass where we accumulate inputs to use for compilation + for name in self.module_names: + # default is "calibrate" + self.remote_modules[name].fhe_local_mode = FHEMode.CALIBRATE self.model(x) + self.configuration = configuration for name in self.module_names: @@ -244,32 +435,45 @@ def compile_model( self.remote_modules[name].private_q_module = self.private_q_modules[name] - def _save_fhe_circuit(self, path: Path): + def _save_fhe_circuit(self, path: Path, via_mlir=False): """Private method that saves the FHE circuits. Args: path (Path): The directory where the FHE circuit will be saved. + via_mlir (bool): if fhe circuits should be serialized using via_mlir option + useful for cross-platform (compile on one architecture and run on another) """ - path = Path(path) - for name in self.module_names: + model_path = Path(path) + for module_name in self.module_names: + input_shapes = [ + tuple(elt.dim_value for elt in onnx_input.type.tensor_type.shape.dim) + for onnx_input in self.private_q_modules[ # pylint: disable=protected-access + self.module_names[0] + ]._onnx_model.graph.input + ] + assert len(input_shapes) == 1, "Multi-input circuits not supported yet" + model_module_path = model_path.resolve() / module_name + model_module_path.mkdir(exist_ok=True) + model_module_shape_path = model_module_path / tuple_to_underscore_str(input_shapes[0]) model_dev = FHEModelDev( - str(path.resolve()) + f"/{name}_fhe_circuit", - self.private_q_modules[name], + str(model_module_shape_path.resolve()), + self.private_q_modules[module_name], ) - model_dev.save() + model_dev.save(via_mlir=via_mlir) - def save_and_clear_private_info(self, path: Path): + def save_and_clear_private_info(self, path: Path, via_mlir=False): """Save the PyTorch model to the provided path and also saves the corresponding FHE circuit. Args: path (Path): The directory where the model and the FHE circuit will be saved. + via_mlir (bool): if fhe circuits should be serialized using via_mlir option + useful for cross-platform (compile on one architecture and run on another) """ path = Path(path) path.mkdir(parents=True, exist_ok=True) for name in self.module_names: module = self._get_module_by_name(self.model, name) - # Remove private information for attr in ["private_module", "calibration_data", "private_q_module"]: if hasattr(module, attr): @@ -280,7 +484,7 @@ def save_and_clear_private_info(self, path: Path): torch.save(self.model, model_path.resolve()) # Save the FHE circuit in the same directory - self._save_fhe_circuit(path) + self._save_fhe_circuit(path, via_mlir=via_mlir) def publish_to_hub(self): """Allow the user to push the model and FHE required files to HF Hub.""" diff --git a/src/concrete/ml/torch/numpy_module.py b/src/concrete/ml/torch/numpy_module.py index 0d84fba1c7..36e2576910 100644 --- a/src/concrete/ml/torch/numpy_module.py +++ b/src/concrete/ml/torch/numpy_module.py @@ -55,8 +55,7 @@ def __init__( + f"but it is {onnx_model_opset_version}", ) - self._onnx_model = model - self.numpy_forward = get_equivalent_numpy_forward(model) + self.numpy_forward, self._onnx_model = get_equivalent_numpy_forward(model) else: raise ValueError( f"model must be a torch.nn.Module or an onnx.ModelProto, got {type(model).__name__}" diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py index be5f50af61..be4952e171 100644 --- a/tests/torch/test_compile_torch.py +++ b/tests/torch/test_compile_torch.py @@ -1056,7 +1056,7 @@ def test_net_has_no_tlu( if num_inputs > 1 and use_qat: return - use_conv = isinstance(input_shape, tuple) and len(input_shape) > 1 + use_conv = isinstance(input_shape, tuple) and len(input_shape) == 3 net = module(use_conv, use_qat, input_shape, n_bits) net.eval() @@ -1107,6 +1107,7 @@ def decorate_name(self): n_bits=n_bits, ) + assert quantized_numpy_module.fhe_circuit is not None mlir = quantized_numpy_module.fhe_circuit.mlir # Check if a TLU is present or not, depending on whether we force a TLU to be present diff --git a/tests/torch/test_hybrid_converter.py b/tests/torch/test_hybrid_converter.py index 1d7557f1b4..8b288dbee1 100644 --- a/tests/torch/test_hybrid_converter.py +++ b/tests/torch/test_hybrid_converter.py @@ -72,9 +72,10 @@ def run_hybrid_model_test( module_names = module_names if isinstance(module_names, list) else [module_names] + # TODO: fix it -> broken due to shape handling # List of files to check files = ["model.pth"] + [ - f"{module_name}_fhe_circuit/{file_name}" + f"{module_name}/{file_name}" for module_name in module_names for file_name in ["client.zip", "server.zip", "versions.json"] ] diff --git a/use_case_examples/hybrid_model/README.md b/use_case_examples/hybrid_model/README.md new file mode 100644 index 0000000000..f4017ea305 --- /dev/null +++ b/use_case_examples/hybrid_model/README.md @@ -0,0 +1,14 @@ +# Hybrid model + +This use case example showcases how to partially run layers in FHE. + +In this case we apply a fully connected layer of a GPT-2 model in FHE. + +## How to run this use-case + +0. Install additional requirements using `python -m pip install -r requirements.txt` +1. Compile GPT-2 model using `bash compile.sh` script +1. Run FHE server using `bash serve.sh` +1. Run FHE client using `python infer_hybrid_llm_generate.py` + - You will first be asked about the number of tokens that you want to generate + - Then you will be able to enter your prompt diff --git a/use_case_examples/hybrid_model/compile.sh b/use_case_examples/hybrid_model/compile.sh new file mode 100644 index 0000000000..acbf2f7777 --- /dev/null +++ b/use_case_examples/hybrid_model/compile.sh @@ -0,0 +1,6 @@ +#!/bin/bash +VIA_MLIR=0 +for INDEX in 2 3 +do + INDEX=$INDEX VIA_MLIR=$VIA_MLIR python compile_hybrid_llm.py +done diff --git a/use_case_examples/hybrid_model/compile_hybrid_llm.py b/use_case_examples/hybrid_model/compile_hybrid_llm.py new file mode 100644 index 0000000000..f88cf6de21 --- /dev/null +++ b/use_case_examples/hybrid_model/compile_hybrid_llm.py @@ -0,0 +1,140 @@ +"""Showcase for the hybrid model converter.""" + +import os +from copy import deepcopy +from pathlib import Path +from typing import List, Union + +import torch +from concrete.fhe import Configuration, ParameterSelectionStrategy +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from concrete.ml.torch.hybrid_model import HybridFHEModel + + +def compile_model( + model_name: str, + model: torch.nn.Module, + inputs: torch.Tensor, + module_names: Union[str, List], + expected_accuracy, + models_dir: Path, +): + """Run the test for any model with its private module names.""" + + # Enable multi params/precision + configuration = Configuration( + single_precision=False, + parameter_selection_strategy=ParameterSelectionStrategy.MULTI, + ) + + # Create a hybrid model + hybrid_model = HybridFHEModel(model, module_names) + hybrid_model.compile_model( + inputs, + n_bits=8, + # setting it to None is not enough -> weird + rounding_threshold_bits=None, + configuration=configuration, + ) + + # Sanity checks + logits_simulate = hybrid_model(inputs, fhe="simulate").logits + logits_disable = hybrid_model(inputs, fhe="disable").logits + logits_original = model(inputs).logits + # Ensure logits_disable and logits_original return the same output for the logits + assert torch.allclose(logits_disable, logits_original, atol=1e-7), "Outputs do not match!" + # Compare the topk accuracy of the FHE simulate circuit vs. the original. + k = 100 + # Get the topk indices for logits_disable and logits_simulate + topk_disable = logits_disable.topk(k, dim=-1).indices + topk_simulate = logits_simulate.topk(k, dim=-1).indices + # Prepare tensors for broadcasting + expanded_simulate = topk_simulate.unsqueeze(-1) + expanded_disable = topk_disable.unsqueeze(-2) + # Compute if elements of topk_simulate are in topk_disable for each token + is_in = (expanded_simulate == expanded_disable).any(-1) + # Compute average of these counts (the accuracy) + accuracy = is_in.float().mean() + # Make sure accuracy is above a certain threshold + if accuracy >= expected_accuracy: + print("Expected accuracy GPT2 hybrid not matched.") + + # Compilation + models_dir.mkdir(exist_ok=True) + model_dir = models_dir / model_name + print(f"Saving to {model_dir}") + via_mlir = bool(int(os.environ.get("VIA_MLIR", 0))) + hybrid_model.save_and_clear_private_info(model_dir, via_mlir=via_mlir) + + +if __name__ == "__main__": + configs = [ + ("transformer.h.0.mlp", 0.934), # Full MLP + (["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.42), # Two full MLPs + ("transformer.h.0.mlp.c_proj", 0.986), # only projection in MLP + ("transformer.h.0.attn.c_proj", 0.986), # only projection in MLP + ] + config_index = int(os.environ.get("INDEX", 2)) + config = configs[config_index][0] + expected_accuracy = configs[config_index][1] + + # Compilation should be done on CPU + device = "cpu" + print(f"Using device: {device}") + + # Get GPT2 from Huggingface + model_name = "gpt2" + model_name_no_special_char = model_name.replace("/", "_") + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map=device, + trust_remote_code=True, + ) + + configuration = { + "model_name": model_name, + "model_name_no_special_char": model_name_no_special_char, + "configuration": config, + } + + # In this case we compile for only one sample + # We might want to compile for multiple samples + # To do this the easiest solution is to compile on contexts of different sizes. + # They should all have the same lengths + # We might hack something based on HuggingFace dataset with some truncation + # Without truncation or selection it would require some knowledge of the tokenizer + max_context_size = 20 + num_samples = 50 + + dataset = load_dataset("wikipedia", "20220301.en") + print(model) + models_dir = Path(__file__).parent / os.environ.get("MODELS_DIR_NAME", "compiled_models") + models_dir.mkdir(exist_ok=True) + + # Compile for different shapes + for context_size in range(1, max_context_size): + prompts = [] + counter = 0 + for sample in dataset["train"]: + encoded = tokenizer.encode(sample["text"], return_tensors="pt") + if encoded.shape[1] >= context_size: + counter += 1 + prompts.append(encoded[:, :context_size]) + if counter == num_samples: + break + compile_inputset = torch.cat(prompts).to(device) + print(context_size, "compilation") + assert isinstance(model, torch.nn.Module) + + # We modify the model in place, so to compile multiple times we need to deepcopy the model + compile_model( + f"{model_name}_{config_index}", + deepcopy(model), + compile_inputset, + config, + expected_accuracy, + models_dir=models_dir, + ) diff --git a/use_case_examples/hybrid_model/infer_hybrid_llm_generate.py b/use_case_examples/hybrid_model/infer_hybrid_llm_generate.py new file mode 100644 index 0000000000..e6953ded15 --- /dev/null +++ b/use_case_examples/hybrid_model/infer_hybrid_llm_generate.py @@ -0,0 +1,89 @@ +"""Showcase for the hybrid model converter.""" +import time +from pathlib import Path + +import torch +from torch.backends import mps +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, TextStreamer + +from concrete.ml.torch.hybrid_model import FHEMode, HybridFHEModel + +if __name__ == "__main__": + configs = [ + ("transformer.h.0.mlp", 0.934), + (["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.42), + ("transformer.h.0.mlp.c_proj", 0.986), + ("transformer.h.0.attn.c_proj", 0.986), + ] + config_index = 3 + config = configs[config_index][0] + + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + if mps.is_available(): + device = "mps" + print(f"Using device: {device}") + + # Get GPT2 from Huggingface + # TODO: migrate to auto-model with model_name + model_name = "gpt2" + model_name_no_special_char = model_name.replace("/", "_") + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map=device, + trust_remote_code=True, + ) + + # Modify model to use remote FHE server instead of local weights + hybrid_model = HybridFHEModel( + model, + config, + server_remote_address="http://0.0.0.0:8000", + model_name=f"{model_name}_{config_index}", + verbose=False, + ) + path_to_clients = Path(__file__).parent / "clients" + hybrid_model.init_client(path_to_clients=path_to_clients) + for module in hybrid_model.remote_modules.values(): + module.fhe_local_mode = FHEMode.REMOTE + + # Run example + while True: + # Take inputs + num_tokens = input("Number of tokens:\n") + if not num_tokens: + num_tokens = 5 + else: + num_tokens = int(num_tokens) + prompt = input("Prompt:\n") + if not prompt: + prompt = "Computations on encrypted data can help" + + # Encode and send to device + input_ids = tokenizer.encode(prompt, return_tensors="pt") + assert isinstance(input_ids, torch.Tensor) + input_ids = input_ids.to(device=device) + + print("*" * 10) + print("*" * 10) + print(f"{input_ids.shape[1]} tokens in '{prompt}'") + print("*" * 10) + print("*" * 10) + + # Print words as they are generated + streamer = TextStreamer(tokenizer=tokenizer) + start = time.time() + output_ids = model.generate( + input_ids, max_new_tokens=num_tokens, use_cache=True, streamer=streamer + ) + end = time.time() + generated = tokenizer.decode(output_ids[0]) + + print(f"{end - start} seconds to generate") + print("*" * 10) + print("*" * 10) + print(generated) + print("*" * 10) + print("*" * 10) diff --git a/use_case_examples/hybrid_model/load_and_analyze_data.py b/use_case_examples/hybrid_model/load_and_analyze_data.py new file mode 100644 index 0000000000..7fe8b1851e --- /dev/null +++ b/use_case_examples/hybrid_model/load_and_analyze_data.py @@ -0,0 +1,32 @@ +import json +from collections import Counter + +import matplotlib.pyplot as plt +from datasets import load_dataset +from tqdm import tqdm + + +def main(): + """ + Load wikipedia dataset and plot lenghts of text histogram. + For now this considers only the number of characters but we could also consider some stats like + the number of tokens, unique tokens, etc ... + """ + dataset = load_dataset("wikipedia", "20220301.en") + lengths = [len(sample["text"]) for sample in tqdm(dataset["train"])] + count = Counter(lengths) + print(count) + with open("wikipedia_counts.json", "w") as file: + json.dump(count, file) + with open("wikipedia_values.json", "w") as file: + json.dump(lengths, file) + + # Matplotlib plot + plt.subplots() + plt.hist(lengths, bins=1000) + plt.yscale("log") + plt.savefig("lengths.png") + + +if __name__ == "__main__": + main() diff --git a/use_case_examples/hybrid_model/requirements.txt b/use_case_examples/hybrid_model/requirements.txt new file mode 100644 index 0000000000..4ec68127ac --- /dev/null +++ b/use_case_examples/hybrid_model/requirements.txt @@ -0,0 +1,3 @@ +datasets==2.14.4 +apache_beam==2.49.0 +mwparserfromhell==0.6.4 diff --git a/use_case_examples/hybrid_model/serve.sh b/use_case_examples/hybrid_model/serve.sh new file mode 100644 index 0000000000..9d7b8f4a48 --- /dev/null +++ b/use_case_examples/hybrid_model/serve.sh @@ -0,0 +1,13 @@ +#!/bin/bash +uname_str=$(uname) +echo "${uname_str}" +if [[ $uname_str != "Darwin" ]]; then + echo "Not Darwin" + # tune the cpu-list according to the resources you want to allocate to it + PATH_TO_MODELS="compiled_models" PORT=8000 taskset --cpu-list 0-12 python serve_model.py + # No-limit + # PATH_TO_MODELS="compiled_models" PORT=8000 python serve_model.py +else + echo "Darwin" + PATH_TO_MODELS="compiled_models" PORT=8000 python serve_model.py +fi diff --git a/use_case_examples/hybrid_model/serve_model.py b/use_case_examples/hybrid_model/serve_model.py new file mode 100644 index 0000000000..9609465205 --- /dev/null +++ b/use_case_examples/hybrid_model/serve_model.py @@ -0,0 +1,193 @@ +"""Hybrid Model Deployment Server. + +Routes: + - Get all names + - Get client.zip + - Add a key + - Compute +""" + +import ast +import io +import os +import time +import uuid +from collections import defaultdict +from functools import lru_cache +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import uvicorn +from fastapi import FastAPI, Form, HTTPException, UploadFile +from fastapi.responses import FileResponse, StreamingResponse +from loguru import logger + +# No relative import here because when not used in the package itself +from concrete.ml.deployment import FHEModelServer + + +def underscore_str_to_tuple(tup): + return ast.literal_eval(tup.replace("p_", "(").replace("_p", ")").replace("_", ", ")) + + +if __name__ == "__main__": + app = FastAPI(debug=False) + # Model-name -> Module-Name -> Input-shape + + FILE_FOLDER = Path(__file__).parent + KEY_PATH = Path(os.environ.get("KEY_PATH", FILE_FOLDER / Path("user_keys"))) + KEY_PATH.mkdir(exist_ok=True) + MODELS_PATH = Path(os.environ.get("PATH_TO_MODELS", FILE_FOLDER / Path("model"))) + PORT = os.environ.get("PORT", "5000") + MODULES = defaultdict(dict) + # Populate modules -> could be done dynamically on each query tbh + for model_path in MODELS_PATH.iterdir(): # Model + # TODO: change with a struct/obj + model_name = model_path.name + MODULES[model_name] = defaultdict(dict) + for module_path in model_path.iterdir(): # Module + if not module_path.is_dir(): + continue + module_name = module_path.name + MODULES[model_name][module_name] = defaultdict(dict) + for input_shape_path in module_path.iterdir(): + if not input_shape_path.is_dir(): + continue + input_shape = str(underscore_str_to_tuple(input_shape_path.name)) + MODULES[model_name][module_name][input_shape] = { + "path": input_shape_path.resolve(), + "module_name": module_name, + "model_name": model_name, + "shape": input_shape, + } + + @lru_cache(maxsize=None) + def load_key(uid) -> bytes: + with open(KEY_PATH / str(uid), "rb") as file: + return file.read() + + def dump_key(key_bytes: bytes, uid: Union[uuid.UUID, str]) -> None: + with open(KEY_PATH / str(uid), "wb") as file: + file.write(key_bytes) + + @lru_cache(maxsize=None) + def get_circuit(model_name, module_name, input_shape): + return FHEModelServer(str(MODULES[model_name][module_name][input_shape]["path"])) + + def check_inputs(model_name: str, module_name: Optional[str], input_shape: Optional[Tuple]): + if model_name not in MODULES: + raise HTTPException( + status_code=500, + detail=f"provided names '{model_name}' does not match any known name", + ) + if module_name is not None and module_name not in MODULES[model_name]: + raise HTTPException( + status_code=500, + detail=f"provided names '{module_name}' does not match any known name" + f"{list(MODULES[model_name].keys())}", + ) + if input_shape is not None and input_shape not in MODULES[model_name][module_name]: + raise HTTPException( + status_code=500, + detail=f"provided names '{module_name}' does not match any known name" + f"{list(MODULES[model_name][module_name].keys())}", + ) + + @app.get("/list_models") + def list_models(): + return MODULES + + @app.get("/list_modules") + def list_modules(model_name: str = Form()): + check_inputs(model_name, None, None) + return MODULES[model_name] + + @app.get("/list_shapes") + def list_shapes(model_name: str = Form(), module_name: str = Form()): + check_inputs(model_name, module_name, None) + return MODULES[model_name][module_name] + + @app.get("/get_client") + def get_client(model_name: str = Form(), module_name: str = Form(), input_shape: str = Form()): + """Get client. + + Returns: + FileResponse: client.zip + + Raises: + HTTPException: if the file can't be find locally + """ + check_inputs(model_name, module_name, input_shape) + path_to_client = ( + MODULES[model_name][module_name][str(input_shape)]["path"] / "client.zip" + ).resolve() + if not path_to_client.exists(): + raise HTTPException(status_code=500, detail="Could not find client.") + return FileResponse(path_to_client, media_type="application/zip") + + @app.post("/add_key") + async def add_key( + key: UploadFile, + model_name: str = Form(), + module_name: str = Form(), + input_shape: str = Form(), + ): + """Add public key. + + Arguments: + key (UploadFile): public key + + Returns: + Dict[str, str] + - uid: uid a personal uid + """ + check_inputs(model_name, module_name, input_shape) + uid = str(uuid.uuid4()) + key_bytes = await key.read() + dump_key(key_bytes, uid) + # TODO: we should probably store for which circuit the key was generated for + # such that we can raise an error if the targeted keys does not match the correct circuit + return {"uid": uid} + + @app.post("/compute") + async def compute( + model_input: UploadFile, + uid: str = Form(), + model_name: str = Form(), + module_name: str = Form(), + input_shape: str = Form(), + ): # noqa: B008 + """Compute the circuit over encrypted input. + + Arguments: + model_input (UploadFile): input of the circuit + uid (str): uid of the public key to use + + Returns: + StreamingResponse: the result of the circuit + """ + check_inputs(model_name, module_name, input_shape) + start = time.time() + key_bytes = load_key(uid) + end = time.time() + logger.info(f"It took {end - start} seconds to load the key") + + start = time.time() + fhe = get_circuit(model_name, module_name, input_shape) + end = time.time() + logger.info(f"It took {end - start} seconds to load the circuit") + + start = time.time() + encrypted_results = fhe.run( + serialized_encrypted_quantized_data=await model_input.read(), + serialized_evaluation_keys=key_bytes, + ) + end = time.time() + logger.info(f"fhe inference of input of shape {input_shape} took {end - start}") + logger.info(f"Results size is {len(encrypted_results)/(1024**2)} Mb") + start = time.time() + return StreamingResponse( + io.BytesIO(encrypted_results), + ) + + uvicorn.run(app, host="0.0.0.0", port=int(PORT)) diff --git a/use_case_examples/llm/QGPT2Evaluate.ipynb b/use_case_examples/llm/QGPT2Evaluate.ipynb index a67580b04f..5c560095ec 100644 --- a/use_case_examples/llm/QGPT2Evaluate.ipynb +++ b/use_case_examples/llm/QGPT2Evaluate.ipynb @@ -695,5 +695,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 }