diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py
index 9cebb3bdcf..9be80500c3 100644
--- a/src/concrete/ml/onnx/convert.py
+++ b/src/concrete/ml/onnx/convert.py
@@ -183,53 +183,6 @@ def get_equivalent_numpy_forward_from_onnx(
     # Optimize ONNX graph
     # List of all currently supported onnx optimizer passes
     # From https://github.com/onnx/optimizer/blob/master/onnxoptimizer/pass_registry.h
-    # onnx_passes = [
-    #     'adjust_add',
-    #     'rename_input_output',
-    #     'set_unique_name_for_nodes',
-    #     'nop',
-    #     'eliminate_nop_cast',
-    #     'eliminate_nop_dropout',
-    #     'eliminate_nop_flatten',
-    #     'extract_constant_to_initializer',
-    #     'eliminate_if_with_const_cond',
-    #     'eliminate_nop_monotone_argmax',
-    #     'eliminate_nop_pad',
-    #     'eliminate_nop_concat',
-    #     'eliminate_nop_split',
-    #     'eliminate_nop_expand',
-    #     'eliminate_shape_gather',
-    #     'eliminate_slice_after_shape',
-    #     'eliminate_nop_transpose',
-    #     'fuse_add_bias_into_conv',
-    #     'fuse_bn_into_conv',
-    #     'fuse_consecutive_concats',
-    #     'fuse_consecutive_log_softmax',
-    #     'fuse_consecutive_reduce_unsqueeze',
-    #     'fuse_consecutive_squeezes',
-    #     'fuse_consecutive_transposes',
-    #     'fuse_matmul_add_bias_into_gemm',
-    #     'fuse_pad_into_conv',
-    #     'fuse_pad_into_pool',
-    #     'fuse_transpose_into_gemm',
-    #     'replace_einsum_with_matmul',
-    #     'lift_lexical_references',
-    #     'split_init',
-    #     'split_predict',
-    #     'fuse_concat_into_reshape',
-    #     'eliminate_nop_reshape',
-    #     'eliminate_nop_with_unit',
-    #     'eliminate_common_subexpression',
-    #     'fuse_qkv',
-    #     'fuse_consecutive_unsqueezes',
-    #     'eliminate_deadend',
-    #     'eliminate_identity',
-    #     'eliminate_shape_op',
-    #     'fuse_consecutive_slices',
-    #     'eliminate_unused_initializer',
-    #     'eliminate_duplicate_initializer',
-    #     'adjust_slice_and_matmul'
-    # ]
     onnx_passes = [
         "fuse_matmul_add_bias_into_gemm",
         "eliminate_nop_pad",
@@ -243,9 +196,7 @@ def get_equivalent_numpy_forward_from_onnx(
     # ONNX optimizer does not optimize Mat-Mult + Bias pattern into GEMM if the input isn't a matrix
     # We manually do the optimization for this case
     equivalent_onnx_model = fuse_matmul_bias_to_gemm(equivalent_onnx_model)
-    with open("debug.onnx", "wb") as file:
-        file.write(equivalent_onnx_model.SerializeToString())
-    # checker.check_model(equivalent_onnx_model)
+    checker.check_model(equivalent_onnx_model)
 
     # Check supported operators
     required_onnx_operators = set(get_op_type(node) for node in equivalent_onnx_model.graph.node)
diff --git a/src/concrete/ml/quantization/post_training.py b/src/concrete/ml/quantization/post_training.py
index c32bec5b1a..fcbe6abd5c 100644
--- a/src/concrete/ml/quantization/post_training.py
+++ b/src/concrete/ml/quantization/post_training.py
@@ -788,9 +788,7 @@ def _process_initializer(self, n_bits: int, values: numpy.ndarray):
             QuantizedArray: a quantized tensor with integer values on n_bits bits
         """
 
-        if isinstance(values, numpy.ndarray) and numpy.issubdtype(
-            values.dtype, numpy.integer
-        ):  # pragma:no cover
+        if isinstance(values, numpy.ndarray) and numpy.issubdtype(values.dtype, numpy.integer):
             return values.view(RawOpOutput)
 
         assert isinstance(values, (numpy.ndarray, float))
diff --git a/src/concrete/ml/torch/compile.py b/src/concrete/ml/torch/compile.py
index 9e83e24d97..6130da146b 100644
--- a/src/concrete/ml/torch/compile.py
+++ b/src/concrete/ml/torch/compile.py
@@ -81,7 +81,10 @@ def build_quantized_module(
         convert_torch_tensor_or_numpy_array_to_numpy_array(val)
         for val in to_tuple(torch_inputset)
     )
-    # No batch dimension (i.e., 0 instead of [0]) because else GEMM onnx pass can't be applied
+    # Tracing needs to be done with a batch size of 1 since we compile our models to FHE with
+    # this batch size. The input set contains many examples to determine a representative
+    # bit-width, but for tracing we only take a single one. We need the ONNX tracing batch size
+    # to match the batch size used during FHE inference, which can only be 1 for the moment.
     dummy_input_for_tracing = tuple(
         torch.from_numpy(val[[0], ::]).float() for val in inputset_as_numpy_tuple
     )
diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py
index 861015f3ed..59f6c8a627 100644
--- a/src/concrete/ml/torch/hybrid_model.py
+++ b/src/concrete/ml/torch/hybrid_model.py
@@ -16,11 +16,12 @@
 from torch import nn
 from transformers import Conv1D
 
+from ..common.utils import MAX_BITWIDTH_BACKWARD_COMPATIBLE
 from ..deployment.fhe_client_server import FHEModelClient, FHEModelDev
 from .compile import QuantizedModule, compile_torch_model
 
 
-class FHEMode(enum.Enum):
+class HybridFHEMode(enum.Enum):
     """Simple enum for different modes of execution of HybridModel."""
 
     DISABLE = "disable"  # Use torch weights
@@ -99,8 +100,7 @@ def __init__(
         self.calibration_data: List = []
         self.uid = str(uuid.uuid4())
         self.private_q_module: Optional[QuantizedModule] = None
-        # TODO: figure out if this is good
-        self.fhe_local_mode: FHEMode = FHEMode.CALIBRATE
+        self.fhe_local_mode: HybridFHEMode = HybridFHEMode.CALIBRATE
         self.clients: Dict[str, Tuple[str, FHEModelClient]] = {}
         self.path_to_keys: Optional[Path] = None
         self.path_to_clients: Optional[Path] = None
@@ -120,6 +120,7 @@ def init_fhe_client(
         Raises:
            ValueError: if anything goes wrong with the server.
        """
+        # Handle paths
         self.path_to_clients = path_to_client
         if self.path_to_clients is None:
             self.path_to_clients = Path() / "clients"
@@ -129,6 +130,8 @@ def init_fhe_client(
             self.path_to_keys = Path() / "keys"
             self.path_to_keys.mkdir(exist_ok=True)
 
+        # List all shapes supported by the server
+        # This is needed until we have generic shape support in Concrete Python
         assert self.module_name is not None
         shapes_response = requests.get(
             f"{self.server_remote_address}/list_shapes",
@@ -139,6 +142,8 @@ def init_fhe_client(
             raise ValueError(
                 f"Couldn't get shapes from server:\n{shapes_response.content.decode('utf-8')}"
             )
+
+        # For each supported shape, we need to get the FHE client from the server
         shapes = shapes_response.json()
         for shape in shapes:
             client_response = requests.get(
@@ -170,6 +175,7 @@ def init_fhe_client(
             print(f"Evaluation keys size: {len(serialized_evaluation_keys) / (10**6):.2f} MB")
             assert isinstance(serialized_evaluation_keys, bytes)
             assert self.module_name is not None
+            # Upload the key to the server
             response = requests.post(
                 f"{self.server_remote_address}/add_key",
                 data={
@@ -181,11 +187,21 @@ def init_fhe_client(
             )
             assert response.status_code == 200, response.content.decode("utf-8")
             uid = response.json()["uid"]
+            # We store the key uid and the client in the object
+            # If we observe memory issues due to this we can always move
+            # towards client lazy loading with caching as done on the server.
             self.clients[shape] = (uid, client)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass of the remote module.
 
+        To change the behavior of this forward function, one must change the fhe_local_mode
+        attribute. Choices are:
+        - disable: forward using torch module
+        - remote: forward with fhe client-server
+        - simulate: forward with local fhe simulation
+        - calibrate: forward for calibration
+
         Args:
             x (torch.Tensor): The input tensor.
 
@@ -193,34 +209,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             (torch.Tensor): The output tensor.
 
         Raises:
-            ValueError: if fhe_mode is not supported
+            ValueError: if fhe_local_mode is not supported
         """
         # - disable: torch module
         # - remote: client-server
         # - simulate: compiled simulation
         # - calibrate: calibration
-        if self.fhe_local_mode not in {FHEMode.DISABLE, FHEMode.CALIBRATE, FHEMode.REMOTE, None}:
+        if self.fhe_local_mode not in {
+            HybridFHEMode.DISABLE,
+            HybridFHEMode.CALIBRATE,
+            HybridFHEMode.REMOTE,
+            None,
+        }:
             # Using quantized module
             assert self.private_q_module is not None
             y = torch.Tensor(
                 self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value)
             )
-        elif self.fhe_local_mode == FHEMode.DISABLE:
+        elif self.fhe_local_mode == HybridFHEMode.DISABLE:
             # Calling torch
             assert self.private_module is not None
             y = self.private_module.forward(
                 x.detach(),
             )
             assert isinstance(y, torch.Tensor)
-        elif self.fhe_local_mode == FHEMode.CALIBRATE:
+        elif self.fhe_local_mode == HybridFHEMode.CALIBRATE:
             # Calling torch + gathering calibration data
             assert self.private_module is not None
             self.calibration_data.append(x.detach())
             y = self.private_module(x)
             assert isinstance(y, torch.Tensor)
         # TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3869
-        elif self.fhe_local_mode == FHEMode.REMOTE:  # pragma:no cover
+        elif self.fhe_local_mode == HybridFHEMode.REMOTE:  # pragma:no cover
             # Remote call
             y = self.remote_call(x)
         else:  # pragma:no cover
@@ -237,12 +258,15 @@ def remote_call(self, x: torch.Tensor) -> torch.Tensor:  # pragma:no cover
         Returns:
             torch.Tensor: The result of the FHE computation
         """
-        # TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3869
-        # implement server call and client initialization
+        # Store tensor device and move to CPU for FHE encryption
         base_device = x.device
         x = x.to(device="cpu")
+
+        # We need to iterate over elements in the batch since
+        # we don't support batch inference
         inferences = []
         for index in range(len(x)):
+            # Select the sample, get its shape, and encrypt it
             clear_input = x[[index], :].detach().numpy()
             input_shape = tuple(clear_input.shape)
             repr_input_shape = str(input_shape[1:])
@@ -260,6 +284,7 @@ def remote_call(self, x: torch.Tensor) -> torch.Tensor:  # pragma:no cover
             assert self.module_name is not None
             if self.verbose:
                 print("Infering ...")
+            # Inference using the FHE server
             inference_query = requests.post(
                 f"{self.server_remote_address}/compute",
                 files={
@@ -276,16 +301,30 @@ def remote_call(self, x: torch.Tensor) -> torch.Tensor:  # pragma:no cover
             end = time.time()
             if self.verbose:
                 print(f"Inference done in {end - start} seconds")
-            # Unpack the results
+            # Deserialize and decrypt the result
             assert inference_query.status_code == 200, inference_query.content.decode("utf-8")
             encrypted_result = inference_query.content
             decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_result)[0]
             inferences.append(decrypted_prediction)
+
+        # Concatenate results and move them back to the proper device
         return torch.Tensor(numpy.array(inferences)).to(device=base_device)
 
 
+# Add support for QAT models
+# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3992
 class HybridFHEModel:
-    """Convert a model to a hybrid model."""
+    """Convert a model to a hybrid model.
+
+    This is done by replacing the targeted modules with RemoteModule instances.
+    This will modify the model in place.
+
+    Args:
+        model (nn.Module): The model to modify (in-place modification).
+        module_names (Union[str, List[str]]): The module name(s) to run on the FHE server.
+        server_remote_address (str): The remote address of the FHE server.
+        model_name (str): Model name identifier.
+        verbose (int): Whether logs should be printed when interacting with the FHE server.
+    """
 
     def __init__(
         self,
@@ -348,7 +387,7 @@ def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
         """
         # Set the fhe mode in each remote module
         for module in self.remote_modules.values():
-            module.fhe_local_mode = FHEMode(fhe)
+            module.fhe_local_mode = HybridFHEMode(fhe)
         x = self.model(x)
         return x
 
@@ -392,9 +431,9 @@ def init_client(
     def compile_model(
         self,
         x: torch.Tensor,
-        n_bits: int = 8,
-        rounding_threshold_bits: Optional[int] = 8,
-        p_error: float = 0.01,
+        n_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
+        rounding_threshold_bits: Optional[int] = None,
+        p_error: Optional[float] = None,
         configuration: Optional[Configuration] = None,
     ):
         """Compiles the specific layers to FHE.
@@ -413,7 +452,7 @@ def compile_model(
         # We do a forward pass where we accumulate inputs to use for compilation
         for name in self.module_names:
             # default is "calibrate"
-            self.remote_modules[name].fhe_local_mode = FHEMode.CALIBRATE
+            self.remote_modules[name].fhe_local_mode = HybridFHEMode.CALIBRATE
         self.model(x)
 
         self.configuration = configuration
diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py
index be4952e171..65f143ef75 100644
--- a/tests/torch/test_compile_torch.py
+++ b/tests/torch/test_compile_torch.py
@@ -1056,7 +1056,7 @@ def test_net_has_no_tlu(
     if num_inputs > 1 and use_qat:
         return
 
-    use_conv = isinstance(input_shape, tuple) and len(input_shape) == 3
+    use_conv = isinstance(input_shape, tuple) and len(input_shape) > 1
 
     net = module(use_conv, use_qat, input_shape, n_bits)
     net.eval()
diff --git a/tests/torch/test_hybrid_converter.py b/tests/torch/test_hybrid_converter.py
index a0a727cd0b..99b784e238 100644
--- a/tests/torch/test_hybrid_converter.py
+++ b/tests/torch/test_hybrid_converter.py
@@ -26,7 +26,7 @@ def run_hybrid_model_test(
     # Create a hybrid model
     hybrid_model = HybridFHEModel(model, module_names)
     hybrid_model.compile_model(
-        inputs, n_bits=8, rounding_threshold_bits=8, configuration=configuration
+        inputs, p_error=0.01, n_bits=8, rounding_threshold_bits=8, configuration=configuration
     )
 
     # Check we can run the simulate locally
@@ -92,7 +92,7 @@ def run_hybrid_model_test(
 def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy):
     """Test GPT2 hybrid."""
 
-    # Get GPT2 from Huggingface
+    # Get GPT2 from Hugging Face
     model_name = "gpt2"
     model = GPT2LMHeadModel.from_pretrained(model_name)
     tokenizer = GPT2Tokenizer.from_pretrained(model_name)
@@ -106,7 +106,7 @@ def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy):
 def test_gpt2_hybrid_mlp_module_not_found():
     """Test GPT2 hybrid."""
 
-    # Get GPT2 from Huggingface
+    # Get GPT2 from Hugging Face
     model_name = "gpt2"
 
     model = GPT2LMHeadModel.from_pretrained(model_name)
diff --git a/use_case_examples/hybrid_model/compile_hybrid_llm.py b/use_case_examples/hybrid_model/compile_hybrid_llm.py
index f88cf6de21..ec93f24597 100644
--- a/use_case_examples/hybrid_model/compile_hybrid_llm.py
+++ b/use_case_examples/hybrid_model/compile_hybrid_llm.py
@@ -1,5 +1,6 @@
 """Showcase for the hybrid model converter."""
+import json
 import os
 from copy import deepcopy
 from pathlib import Path
 
@@ -34,7 +35,7 @@ def compile_model(
     hybrid_model.compile_model(
         inputs,
         n_bits=8,
-        # setting it to None is not enough -> weird
+        # We only run linear layers in FHE, so rounding is not needed
         rounding_threshold_bits=None,
         configuration=configuration,
     )
@@ -43,6 +44,7 @@ def compile_model(
     logits_simulate = hybrid_model(inputs, fhe="simulate").logits
     logits_disable = hybrid_model(inputs, fhe="disable").logits
     logits_original = model(inputs).logits
+    # Check that the disabled mode and the original model produce the same logits
     assert torch.allclose(logits_disable, logits_original, atol=1e-7), "Outputs do not match!"
 
     # Compare the topk accuracy of the FHE simulate circuit vs. the original.
@@ -84,7 +86,7 @@ def compile_model(
     device = "cpu"
     print(f"Using device: {device}")
 
-    # Get GPT2 from Huggingface
+    # Get GPT2 from Hugging Face
     model_name = "gpt2"
     model_name_no_special_char = model_name.replace("/", "_")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -98,7 +100,10 @@ def compile_model(
         "model_name": model_name,
         "model_name_no_special_char": model_name_no_special_char,
         "configuration": config,
+        "index": config_index,
     }
+    with open("configuration.json", "w") as file:
+        json.dump(configuration, file)
 
     # In this case we compile for only one sample
     # We might want to compile for multiple samples
diff --git a/use_case_examples/hybrid_model/infer_hybrid_llm_generate.py b/use_case_examples/hybrid_model/infer_hybrid_llm_generate.py
index e6953ded15..c831894c56 100644
--- a/use_case_examples/hybrid_model/infer_hybrid_llm_generate.py
+++ b/use_case_examples/hybrid_model/infer_hybrid_llm_generate.py
@@ -1,4 +1,5 @@
 """Showcase for the hybrid model converter."""
+import json
 import time
 from pathlib import Path
 
@@ -6,17 +7,14 @@
 from torch.backends import mps
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, TextStreamer
 
-from concrete.ml.torch.hybrid_model import FHEMode, HybridFHEModel
+from concrete.ml.torch.hybrid_model import HybridFHEMode, HybridFHEModel
 
 if __name__ == "__main__":
-    configs = [
-        ("transformer.h.0.mlp", 0.934),
-        (["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.42),
-        ("transformer.h.0.mlp.c_proj", 0.986),
-        ("transformer.h.0.attn.c_proj", 0.986),
-    ]
-    config_index = 3
-    config = configs[config_index][0]
+    # Use the configuration dumped during compilation
+    with open("configuration.json", "r") as file:
+        configuration = json.load(file)
+    config = configuration["configuration"]
+    config_index = configuration["index"]
 
     device = "cpu"
     if torch.cuda.is_available():
@@ -25,9 +23,9 @@
         device = "mps"
     print(f"Using device: {device}")
 
-    # Get GPT2 from Huggingface
-    # TODO: migrate to auto-model with model_name
+    # Get GPT2 from Hugging Face
     model_name = "gpt2"
+    # Avoid having "/" in the string
     model_name_no_special_char = model_name.replace("/", "_")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(
@@ -47,7 +45,7 @@
     path_to_clients = Path(__file__).parent / "clients"
     hybrid_model.init_client(path_to_clients=path_to_clients)
     for module in hybrid_model.remote_modules.values():
-        module.fhe_local_mode = FHEMode.REMOTE
+        module.fhe_local_mode = HybridFHEMode.REMOTE
 
     # Run example
     while True:
diff --git a/use_case_examples/hybrid_model/load_and_analyze_data.py b/use_case_examples/hybrid_model/load_and_analyze_data.py
deleted file mode 100644
index 7fe8b1851e..0000000000
--- a/use_case_examples/hybrid_model/load_and_analyze_data.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import json
-from collections import Counter
-
-import matplotlib.pyplot as plt
-from datasets import load_dataset
-from tqdm import tqdm
-
-
-def main():
-    """
-    Load wikipedia dataset and plot lenghts of text histogram.
-    For now this considers only the number of characters but we could also consider some stats like
-    the number of tokens, unique tokens, etc ...
-    """
-    dataset = load_dataset("wikipedia", "20220301.en")
-    lengths = [len(sample["text"]) for sample in tqdm(dataset["train"])]
-    count = Counter(lengths)
-    print(count)
-    with open("wikipedia_counts.json", "w") as file:
-        json.dump(count, file)
-    with open("wikipedia_values.json", "w") as file:
-        json.dump(lengths, file)
-
-    # Matplotlib plot
-    plt.subplots()
-    plt.hist(lengths, bins=1000)
-    plt.yscale("log")
-    plt.savefig("lengths.png")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/use_case_examples/hybrid_model/serve_model.py b/use_case_examples/hybrid_model/serve_model.py
index 9609465205..62f2c1867b 100644
--- a/use_case_examples/hybrid_model/serve_model.py
+++ b/use_case_examples/hybrid_model/serve_model.py
@@ -40,9 +40,11 @@ def underscore_str_to_tuple(tup):
     MODELS_PATH = Path(os.environ.get("PATH_TO_MODELS", FILE_FOLDER / Path("model")))
     PORT = os.environ.get("PORT", "5000")
     MODULES = defaultdict(dict)
-    # Populate modules -> could be done dynamically on each query tbh
+    # Populate modules at the beginning
+    # This could also be done dynamically on each query if needed
+    # We build the following mapping:
+    # model_name -> module_name -> input_shape -> some information
     for model_path in MODELS_PATH.iterdir():  # Model
-        # TODO: change with a struct/obj
         model_name = model_path.name
         MODULES[model_name] = defaultdict(dict)
         for module_path in model_path.iterdir():  # Module
@@ -145,8 +147,6 @@ async def add_key(
         uid = str(uuid.uuid4())
         key_bytes = await key.read()
         dump_key(key_bytes, uid)
-        # TODO: we should probably store for which circuit the key was generated for
-        # such that we can raise an error if the targeted keys does not match the correct circuit
         return {"uid": uid}
 
     @app.post("/compute")
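
Usage sketch (not part of the patch): the pieces touched above — the `HybridFHEMode` enum, the `fhe=` argument of `HybridFHEModel.__call__`, and the new `compile_model` defaults — fit together as follows. The `SmallMLP` module, tensor shapes and parameter values are made-up illustrations; only the `HybridFHEModel` / `HybridFHEMode` API shown in `hybrid_model.py` above is taken from the diff.

```python
import torch
from torch import nn

from concrete.ml.torch.hybrid_model import HybridFHEModel


class SmallMLP(nn.Module):
    """Hypothetical toy model: only `fc2` is outsourced to the FHE side."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, 2)

    def forward(self, x):
        return self.fc3(torch.relu(self.fc2(torch.relu(self.fc1(x)))))


model = SmallMLP()

# Replace `fc2` by a RemoteModule (the model is modified in place)
hybrid_model = HybridFHEModel(model, module_names="fc2")

# compile_model() first runs the model in CALIBRATE mode to gather inputs of the
# targeted module, then compiles that module to FHE. With the new defaults
# (n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE, rounding_threshold_bits=None, p_error=None),
# values are passed explicitly here.
inputs = torch.randn(100, 10)
hybrid_model.compile_model(inputs, n_bits=8, rounding_threshold_bits=None, p_error=0.01)

# The `fhe` argument sets fhe_local_mode on every RemoteModule via HybridFHEMode(fhe):
# "disable" runs the original torch sub-module, "simulate" runs the compiled circuit
# in simulation.
x = torch.randn(5, 10)
y_disable = hybrid_model(x, fhe="disable")
y_simulate = hybrid_model(x, fhe="simulate")
print((y_disable - y_simulate).abs().max())
```

This mirrors `run_hybrid_model_test` in `tests/torch/test_hybrid_converter.py`, which now passes `p_error=0.01` explicitly because the default changed to `None`.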
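The client/server path exercised by `infer_hybrid_llm_generate.py` and `serve_model.py` follows the same object. The sketch below is an assumption-heavy outline rather than a recipe from the patch: it presumes a `serve_model.py` instance is already running with the compiled artifacts for this model, reuses the hypothetical `SmallMLP` from the previous sketch, and uses placeholder values for the server address and client path.

```python
from pathlib import Path

import torch

from concrete.ml.torch.hybrid_model import HybridFHEMode, HybridFHEModel

# `SmallMLP` comes from the previous sketch; the address and path are placeholders
# (serve_model.py listens on the PORT environment variable, "5000" by default).
model = SmallMLP()
hybrid_model = HybridFHEModel(
    model, module_names="fc2", server_remote_address="http://localhost:5000"
)

# Fetch one FHE client per input shape listed by the server (/list_shapes, /get_client)
# and upload the evaluation keys (/add_key), as done in init_fhe_client above
hybrid_model.init_client(path_to_clients=Path("clients"))

# Switch every RemoteModule to client-server execution, as infer_hybrid_llm_generate.py does
for module in hybrid_model.remote_modules.values():
    module.fhe_local_mode = HybridFHEMode.REMOTE

# The wrapped model can now be called directly (or through `generate` for an LLM);
# each RemoteModule encrypts its input, posts it to /compute and decrypts the result
y_remote = model(torch.randn(1, 10))
```

Note that `remote_call` processes the batch one sample at a time, which is why the loop in the diff encrypts `x[[index], :]` slices individually before posting them to the server.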